I came across this question while looking for the same answer, and it took me a while to work out because I use the model subclassing API in TF 2.0 by default (as in the tutorial at https://www.tensorflow.org/tutorials/quickstart/advanced).
If you are in a similar situation, all you need to do is assign the intermediate output you want to an attribute of the model object inside call. Then keep test_step without the @tf.function decorator and create a decorated copy of it, say val_step, for efficiently computing validation performance during training. As a short example, I have modified a few functions from the linked tutorial accordingly, assuming we want the output right after flattening.
def call(self, x):
    x = self.conv1(x)
    x = self.flatten(x)
    self.intermediate = x  # assign it as an object attribute for accessing later
    x = self.d1(x)
    return self.d2(x)

# Remove the @tf.function decorator from test_step so prediction runs eagerly
def test_step(images, labels):
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)
    test_loss(t_loss)
    test_accuracy(labels, predictions)

# Create a decorated val_step for internal use during training
@tf.function
def val_step(images, labels):
    return test_step(images, labels)
Now, when you run prediction after training through the un-decorated test_step (which calls model(images, training=False) eagerly), you can access the intermediate output as model.intermediate. It is an EagerTensor, so its value is available simply as model.intermediate.numpy(). However, if you don't remove the @tf.function decorator from test_step, model.intermediate would instead be a symbolic graph Tensor, whose value is not so straightforward to obtain.