"""Loss functions for neural network training.This module contains implementations of common loss functions used in neuralnetwork training, such as Mean Squared Error and Cross Entropy.Classes: MeanSquaredError: Calculates the Mean Squared Error loss. CrossEntropy: Calculates the Cross Entropy loss."""importloggingfromtricycle.contextimportTRICYCLE_CONTEXTfromtricycle.opsimportOpfromtricycle.tensorimportTensorlogger=logging.getLogger(__name__)
class MeanSquaredError(Op):
    """Calculates Mean Squared Error loss.

    This class implements the Mean Squared Error (MSE) loss function, which
    measures the average squared difference between the predicted and true
    values.

    Attributes:
        diff: The difference between predicted and true values.
        divisor: A scaling factor for the loss calculation.
    """
    def backward(self, grad: Tensor) -> Tensor:
        """Computes the backward pass for Mean Squared Error loss.

        Args:
            grad: A Tensor containing the gradient from the previous layer.

        Returns:
            A Tensor containing the computed gradients.
        """
        xp = grad.xp

        if TRICYCLE_CONTEXT.use_mixed_precision:
            grad.array = grad.array.astype(xp.float32)

        out = self.diff * 2 * grad.array * self.divisor

        if TRICYCLE_CONTEXT.use_mixed_precision:
            out = out.astype(xp.float16)

        return Tensor(out)
    def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor:
        """Computes the forward pass for Mean Squared Error loss.

        Args:
            y_true: A Tensor containing the true values.
            y_pred: A Tensor containing the predicted values.

        Returns:
            A Tensor containing the computed MSE loss.

        Raises:
            ValueError: If the computed loss is infinite.
        """
        xp = y_pred.xp

        if TRICYCLE_CONTEXT.use_mixed_precision:
            y_pred.array = y_pred.array.astype(xp.float32)
            y_true.array = y_true.array.astype(xp.float32)

        self.diff = y_pred.array - y_true.array
        self.divisor = 1 / xp.prod(y_pred.shape[-1])

        out = (self.diff**2).sum() * self.divisor

        if TRICYCLE_CONTEXT.use_mixed_precision:
            out *= TRICYCLE_CONTEXT.loss_scale_factor
            out = out.astype(xp.float16)

        if not xp.isfinite(out):
            raise ValueError("Loss is infinite")

        # only y_pred is differentiable: y_true is a constant
        return Tensor(
            out,
            args=(y_pred,),
            back_fns=(self.backward,),
            name="mean_squared_error",
        )
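
# A minimal usage sketch for MeanSquaredError (illustrative, not part of the
# original module). It assumes a plain NumPy backend with full precision
# (TRICYCLE_CONTEXT.use_mixed_precision left False), and calls the Op's
# forward/backward directly rather than through a full training loop:
#
#     import numpy as np
#
#     mse = MeanSquaredError()
#     y_true = Tensor(np.array([1.0, 0.0, 1.0]))
#     y_pred = Tensor(np.array([0.9, 0.1, 0.8]))
#
#     loss = mse.forward(y_true, y_pred)
#     print(loss.array)  # (0.01 + 0.01 + 0.04) / 3 = 0.02
#
#     # Gradient of the mean: 2 * (y_pred - y_true) / n
#     grads = mse.backward(Tensor(np.array(1.0)))
#     print(grads.array)  # [-0.0667,  0.0667, -0.1333]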
class CrossEntropy(Op):
    """Calculates Cross Entropy loss.

    This class implements the Cross Entropy loss function, which is commonly
    used for classification tasks. It computes the loss given logits and
    target indices (as opposed to one-hot encoded tensors).

    Attributes:
        _y_true: The true labels (cached for backward pass).
        _log_softmax_pred: The log softmax of predictions (cached for
            backward pass).
        _out: The computed loss (cached for backward pass).
        _grad: The computed gradients (cached for backward pass).
    """
    def log_softmax(self, tensor: Tensor):
        """Computes the log softmax of the input tensor.

        Args:
            tensor: A Tensor containing the input values.

        Returns:
            The log softmax of the input tensor.
        """
        xp = tensor.xp
        x_max = xp.max(tensor.array, axis=-1, keepdims=True)
        log_sum_exp = x_max + xp.log(
            xp.sum(xp.exp(tensor.array - x_max), axis=-1, keepdims=True)
        )
        return tensor.array - log_sum_exp
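
    # Why subtract x_max first? A worked example (values are illustrative):
    #
    #     log_softmax(x)_i = x_i - log(sum_j exp(x_j))
    #                      = x_i - (max(x) + log(sum_j exp(x_j - max(x))))
    #
    # For x = [1000.0, 0.0], exp(1000.0) overflows to inf, while
    # exp(x - max(x)) = [1.0, exp(-1000.0)] stays finite, giving the
    # correct result log_softmax(x) ~= [0.0, -1000.0].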
    def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor:
        """Computes the forward pass for Cross Entropy loss.

        Args:
            y_true: A Tensor containing the true labels.
            y_pred: A Tensor containing the predicted logits.

        Returns:
            A Tensor containing the computed Cross Entropy loss.

        Raises:
            NotImplementedError: If the input tensor has an unsupported
                number of dimensions.
        """
        xp = y_pred.xp

        # cross entropy reduces a huge matrix to a single number, which makes
        # it really sensitive to errors. To remedy this, we need to use
        # full precision
        if TRICYCLE_CONTEXT.use_mixed_precision:
            y_pred.array = y_pred.array.astype(xp.float32)

        log_softmax_pred = self.log_softmax(y_pred)

        # Cache for backward pass
        self._y_true = y_true.array
        self._log_softmax_pred = log_softmax_pred

        ndim = log_softmax_pred.ndim
        if ndim == 3:
            batch_indices = xp.arange(y_true.shape[0], dtype=int)
            token_indices = xp.arange(y_true.shape[1], dtype=int)
            loss = -log_softmax_pred[
                batch_indices[:, None], token_indices, y_true.array
            ]
        elif ndim == 2:
            indices = xp.arange(y_true.shape[0], dtype=int)
            loss = -log_softmax_pred[indices, y_true.array]
        elif ndim == 1:
            loss = -log_softmax_pred[y_true.array]
        else:
            raise NotImplementedError(
                f"CrossEntropy with predictions with ndim: {ndim} "
                "is not yet supported"
            )

        # Mean loss over all elements
        loss = loss.mean()

        self._out = loss
        if TRICYCLE_CONTEXT.use_mixed_precision:
            self._out = (
                self._out.astype(xp.float16)
                * TRICYCLE_CONTEXT.loss_scale_factor
            )

        return Tensor(
            self._out,
            is_batched=False,
            back_fns=(self.backward,),
            args=(y_pred,),
            name="cross_entropy",
        )
    def backward(self, grad: Tensor) -> Tensor:
        """Computes the backward pass for Cross Entropy loss.

        Args:
            grad: A Tensor containing the gradient from the previous layer.

        Returns:
            A Tensor containing the computed gradients.

        Raises:
            NotImplementedError: If the input tensor has an unsupported
                number of dimensions.
        """
        xp = grad.xp
        ndim = self._log_softmax_pred.ndim

        if TRICYCLE_CONTEXT.use_mixed_precision:
            grad.array = grad.array.astype(xp.float32)

        if ndim == 3:
            batch_indices = xp.arange(self._y_true.shape[0], dtype=int)
            token_indices = xp.arange(self._y_true.shape[1], dtype=int)
            grad_output = xp.exp(self._log_softmax_pred)
            grad_output[
                batch_indices[:, None], token_indices, self._y_true
            ] -= 1
            grad_output *= grad.array / (
                self._y_true.shape[0] * self._y_true.shape[1]
            )
        elif ndim == 2:
            indices = xp.arange(self._y_true.shape[0], dtype=int)
            grad_output = xp.exp(self._log_softmax_pred)
            grad_output[indices, self._y_true] -= 1
            grad_output *= grad.array / self._y_true.shape[0]
        elif ndim == 1:
            grad_output = xp.exp(self._log_softmax_pred)
            grad_output[self._y_true] -= 1
            grad_output *= grad.array
        else:
            raise NotImplementedError(
                f"CrossEntropy with predictions with ndim: {ndim} "
                "is not yet supported"
            )

        self._grad = grad_output

        # remember to convert the gradient back to the right precision
        if TRICYCLE_CONTEXT.use_mixed_precision:
            self._grad = self._grad.astype(xp.float16)

        return Tensor(self._grad, is_batched=grad.is_batched)
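

# A minimal smoke test (not part of the original module), assuming a NumPy
# backend with full precision (TRICYCLE_CONTEXT.use_mixed_precision left
# False). The logits and labels below are illustrative values chosen to
# exercise the ndim == 2 branch.
if __name__ == "__main__":
    import numpy as np

    # Batch of 2 examples over a 4-token vocabulary; labels are indices,
    # not one-hot vectors.
    logits = Tensor(np.array([[2.0, 0.5, 0.1, -1.0], [0.0, 1.5, 0.2, 0.3]]))
    labels = Tensor(np.array([0, 1]))

    ce = CrossEntropy()
    loss = ce.forward(labels, logits)
    print(loss.array)  # mean negative log-likelihood of the true labels

    # With an upstream gradient of 1.0, the gradient w.r.t. the logits is
    # (softmax(logits) - one_hot(labels)) / batch_size
    grads = ce.backward(Tensor(np.array(1.0)))
    print(grads.array.shape)  # (2, 4)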