I'm learning the Keras API in TensorFlow (2.3). In this guide on the TensorFlow website, I found an example of a custom loss function:
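```python
import tensorflow as tf


def custom_mean_squared_error(y_true, y_pred):
    # Averages over every element, batch axis included -> one scalar per batch.
    return tf.math.reduce_mean(tf.square(y_true - y_pred))
```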
The `reduce_mean` call in this custom loss function returns a scalar.
Is it correct to define a loss function like this? As far as I know, the first dimension of both `y_true` and `y_pred` is the batch size. I think the loss function should return a loss value for every sample in the batch, i.e. an array of shape `(batch_size,)`. The function above, however, returns a single value for the whole batch.
Is the example above wrong? Could anyone help me with this?
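For comparison, a per-sample version only needs `axis=-1`, so the mean is taken over the last (feature) axis and the batch axis is kept. This is a minimal sketch assuming TF 2.x; the name `per_sample_mean_squared_error` and the toy tensors are hypothetical.

```python
import tensorflow as tf


def per_sample_mean_squared_error(y_true, y_pred):
    # Average only over the last axis, so the batch axis survives and
    # the result has shape (batch_size,): one loss value per sample.
    return tf.math.reduce_mean(tf.square(y_true - y_pred), axis=-1)


# Toy batch of 3 samples with 2 outputs each (hypothetical data).
y_true = tf.constant([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
y_pred = tf.constant([[0.0, 0.0], [1.0, 1.0], [0.0, 0.0]])

per_sample = per_sample_mean_squared_error(y_true, y_pred)
print(per_sample.shape.as_list())  # [3]
```

Passing this function as `model.compile(loss=per_sample_mean_squared_error)` leaves the averaging over the batch to Keras. Its values should match the built-in `tf.keras.losses.mean_squared_error`, which also averages over the last axis only.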
P.S. Why do I think the loss function should return an array rather than a single value?
I read the source code of the `Model` class. When you pass a loss function (note: a function, not a `Loss` class) to `Model.compile()`, that function is used to construct a `LossesContainer` object, which is stored in `Model.compiled_loss`. The loss function passed to the `LossesContainer` constructor is then used again to construct a `LossFunctionWrapper` object, which is stored in `LossesContainer._losses`.
According to the source code of the `LossFunctionWrapper` class, the overall loss value for a training batch is computed by the `LossFunctionWrapper.__call__()` method (inherited from the `Loss` class), i.e. it returns a single loss value for the whole batch. However, `LossFunctionWrapper.__call__()` first calls `LossFunctionWrapper.call()` to obtain an array of losses, one for every sample in the training batch, and then averages these losses to get the single value for the whole batch. It is inside `LossFunctionWrapper.call()` that the loss function passed to `Model.compile()` is actually called.
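To make this two-step flow concrete, here is a simplified, hypothetical model of the wrapper written with NumPy. `SimplifiedLossWrapper`, `per_sample_mse` and `whole_batch_mse` are made-up names; this is not TensorFlow's `LossFunctionWrapper`, and the reduction only roughly follows the default "sum over batch size" behaviour (sum of the weighted per-sample losses divided by their count). It also shows one practical consequence: sample weights are applied per sample, so they can only be matched to samples if the user function keeps the batch dimension.

```python
import numpy as np


def per_sample_mse(y_true, y_pred):
    # User loss that keeps the batch axis: shape (batch_size,).
    return np.mean(np.square(y_true - y_pred), axis=-1)


def whole_batch_mse(y_true, y_pred):
    # User loss like the guide's example: one scalar for the whole batch.
    return np.mean(np.square(y_true - y_pred))


class SimplifiedLossWrapper:
    """Hypothetical, simplified stand-in for Keras' loss wrapping.

    call() delegates to the user function; __call__() weights the
    result per sample and reduces it to a single scalar.
    """

    def __init__(self, fn):
        self.fn = fn

    def call(self, y_true, y_pred):
        # Step 1: per-sample losses from the user-supplied function.
        return np.asarray(self.fn(y_true, y_pred), dtype=float)

    def __call__(self, y_true, y_pred, sample_weight=None):
        losses = self.call(y_true, y_pred)
        if sample_weight is not None:
            # Weights line up with samples only if the batch axis was kept;
            # a scalar loss is simply broadcast onto every weight.
            losses = losses * np.asarray(sample_weight, dtype=float)
        # Step 2: "sum over batch size" style reduction to one scalar.
        return np.sum(losses) / losses.size


y_true = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
y_pred = np.zeros((3, 2))
weights = [1.0, 1.0, 0.0]  # hypothetical weights that ignore the last sample

per_sample_wrapper = SimplifiedLossWrapper(per_sample_mse)
whole_batch_wrapper = SimplifiedLossWrapper(whole_batch_mse)
print(per_sample_wrapper(y_true, y_pred, weights))
print(whole_batch_wrapper(y_true, y_pred, weights))
```

In this simplified model, both functions give the same number (7/6) without weights, which is why the scalar version appears to work. With the weights `[1, 1, 0]`, only the per-sample version actually drops the last sample (0.5 versus 7/9).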
That's why I think a custom loss function should return an array of losses instead of a single scalar value. Likewise, if we write a custom `Loss` class for `Model.compile()`, the `call()` method of that class should also return an array rather than a single value.
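A minimal sketch of such a `Loss` subclass, assuming TF 2.x; the class name `CustomMeanSquaredError` and the toy tensors are hypothetical. Only `call()` is overridden, and it keeps the batch dimension; the inherited `__call__()` then applies any `sample_weight` and the reduction to a scalar.

```python
import tensorflow as tf


class CustomMeanSquaredError(tf.keras.losses.Loss):
    """Hypothetical custom Loss whose call() returns one value per sample."""

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Mean over the last axis only -> shape (batch_size,).
        # Weighting and batch reduction are left to Loss.__call__().
        return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


y_true = tf.constant([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
y_pred = tf.zeros((3, 2))

loss_fn = CustomMeanSquaredError()
print(loss_fn.call(y_true, y_pred))  # per-sample losses, shape (3,)
print(loss_fn(y_true, y_pred))       # reduced to a single scalar
```

An instance can be passed as `model.compile(loss=CustomMeanSquaredError())`. With the default reduction, its scalar result should match the built-in `tf.keras.losses.MeanSquaredError` on the same inputs.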
I opened an issue on GitHub, and it was confirmed that a custom loss function is required to return one loss value per sample. The example in the guide will need to be updated to reflect this.