"""Optimisers for gradient-based optimisation.This module contains various optimiser classes that can be used forgradient-based optimisation of tensors."""fromloggingimportgetLoggerfromwarningsimportwarnfromtricycle.contextimportTRICYCLE_CONTEXTfromtricycle.tensorimportTensorLOGGER=getLogger(__name__)


class Optimiser:
    """Base class for optimisers."""

    def __call__(self, tensor: Tensor) -> Tensor:
        """Apply optimisation to the given tensor.

        Args:
            tensor (Tensor): The tensor to optimise.

        Returns:
            Tensor: The optimised tensor.

        Raises:
            NotImplementedError: This method should be implemented by
                subclasses.
        """
        raise NotImplementedError

    def _reset_grad(self, tensor: Tensor):
        """Reset the gradient information of the tensor.

        Args:
            tensor (Tensor): The tensor to reset.

        Returns:
            Tensor: The tensor with reset gradient information.
        """
        tensor.grad = None
        tensor.args = None
        tensor.back_fns = None
        return tensor
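

# Illustrative sketch (not part of the original module): a minimal subclass
# only needs to apply a weight update and then reset the gradient state.
# `PlainGradientDescent` is a hypothetical name used purely for illustration.
#
#     class PlainGradientDescent(Optimiser):
#         def __init__(self, learning_rate: float):
#             self.learning_rate = learning_rate
#
#         def __call__(self, tensor: Tensor) -> Tensor:
#             assert tensor.grad is not None
#             tensor.array -= self.learning_rate * tensor.grad.array
#             return self._reset_grad(tensor)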


class StochasticGradientDescent(Optimiser):
    """Stochastic Gradient Descent (SGD) optimiser.

    This optimiser implements SGD with optional weight decay and momentum.

    Attributes:
        learning_rate (float): The learning rate for the optimiser.
        weight_decay (float | None): The weight decay factor.
        momentum (float | None): The momentum factor.
        logger: The logger instance.
        momentum_store (dict): Store for momentum values.
    """

    def __init__(
        self,
        learning_rate: float,
        weight_decay: float | None = None,
        momentum: float | None = None,
        logger=LOGGER,
    ):
        """Initialise the SGD optimiser.

        Args:
            learning_rate (float): The learning rate for the optimiser.
            weight_decay (float | None, optional): The weight decay factor.
                Defaults to None.
            momentum (float | None, optional): The momentum factor.
                Defaults to None.
            logger (optional): The logger instance. Defaults to LOGGER.
        """
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.logger = logger
        self.momentum_store = {}

    def update_weight(self, tensor: Tensor):
        """Perform a gradient update on a tensor.

        This method updates the tensor's weights using the calculated
        gradients, optionally applying weight decay and momentum.

        Args:
            tensor (Tensor): The tensor to update.

        Returns:
            Tensor: The updated tensor.
        """
        xp = tensor.xp
        assert tensor.grad is not None

        # we need to do gradient updates in full precision, otherwise we get
        # stability issues
        if TRICYCLE_CONTEXT.use_mixed_precision:
            tensor.grad.array = (
                tensor.grad.array.astype(xp.float32)
                / TRICYCLE_CONTEXT.loss_scale_factor
            )
            if not tensor.array.dtype == xp.float32:
                tensor.array = tensor.array.astype(xp.float32)

        if tensor.grad.is_batched:
            tensor.grad = tensor.grad.from_batched().einsum("z...->...")

        grad = self.learning_rate * tensor.grad.array

        if self.weight_decay is not None:
            wd = self.learning_rate * self.weight_decay * tensor.array
            grad += wd

        if self.momentum is not None and self.momentum > 0:
            if tensor._id not in self.momentum_store:
                last_momentum = xp.zeros(grad.shape)
            else:
                last_momentum = self.momentum_store[tensor._id]
            grad += self.momentum * last_momentum
            self.momentum_store[tensor._id] = grad

        # make sure our gradients aren't underflowing or overflowing
        if not xp.isfinite(grad).all():
            warn(
                "Found nans in gradient, skipping this gradient and "
                "decreasing loss scaling. If this warning persists, "
                "check that your learning rate isn't too high"
            )
            TRICYCLE_CONTEXT.loss_scale_factor /= 2
            self.logger.warning(
                f"New scaling factor: {TRICYCLE_CONTEXT.loss_scale_factor}"
            )
            return tensor

        if (grad == 0).sum() > grad.size * 0.05:
            warn(
                "Found too many 0's in gradient, skipping this gradient and "
                "increasing loss scaling. If this warning persists, "
                "check that your learning rate isn't too low"
            )
            TRICYCLE_CONTEXT.loss_scale_factor *= 2
            self.logger.warning(
                f"New scaling factor: {TRICYCLE_CONTEXT.loss_scale_factor}"
            )
            return tensor

        if TRICYCLE_CONTEXT.use_mixed_precision:
            grad = grad.astype(xp.float32)

        tensor.array -= grad
        tensor.grad.array.fill(0)

        return tensor
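
    # For reference, the effective update applied above to each parameter w
    # with raw gradient g is (illustrative notation only; the weight decay
    # and momentum terms drop out when those options are disabled):
    #
    #     step = learning_rate * (g + weight_decay * w) + momentum * step_prev
    #     w    = w - step
    #
    # where step_prev is the previous step cached in momentum_store.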

    def __call__(self, tensor: Tensor) -> Tensor:
        """Apply the SGD optimisation to the given tensor.

        Args:
            tensor (Tensor): The tensor to optimise.

        Returns:
            Tensor: The optimised tensor.
        """
        return self._reset_grad(self.update_weight(tensor))
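

# Usage sketch (illustrative only; `model`, its `tensors()` iterator and the
# `loss.backward()` call are assumptions about the surrounding library, not
# defined in this module):
#
#     optimiser = StochasticGradientDescent(learning_rate=1e-2, momentum=0.9)
#     loss.backward()
#     for parameter in model.tensors():
#         optimiser(parameter)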


class AdamW(Optimiser):
    """AdamW optimiser.

    This optimiser implements the AdamW algorithm, which is Adam with
    decoupled weight decay.

    Attributes:
        learning_rate (float): The learning rate for the optimiser.
        betas (tuple): The exponential decay rates for the moment estimates.
        eps (float): A small constant for numerical stability.
        weight_decay (float): The weight decay factor.
        timestep (int): The current time step.
        momentum (dict): Store for first moment estimates.
        square_momentum (dict): Store for second moment estimates.
    """

    def __init__(
        self,
        learning_rate=1e-3,
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0.01,
    ):
        """Initialise the AdamW optimiser.

        Args:
            learning_rate (float, optional): The learning rate.
                Defaults to 1e-3.
            betas (tuple, optional): The exponential decay rates for the
                moment estimates. Defaults to (0.9, 0.999).
            eps (float, optional): A small constant for numerical stability.
                Defaults to 1e-6.
            weight_decay (float, optional): The weight decay factor.
                Defaults to 0.01.
        """
        self.learning_rate = learning_rate
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay

        self.timestep = 1
        self.momentum = {}
        self.square_momentum = {}

    def step(self):
        """Increase the time step.

        This method should be called after each optimisation step.
        """
        # we compute the updates dynamically, so we'll need to remember to
        # call this
        self.timestep += 1

    def update_weight(self, tensor: Tensor) -> Tensor:
        """Perform a weight update on a tensor using the AdamW algorithm.

        Args:
            tensor (Tensor): The tensor to update.

        Returns:
            Tensor: The updated tensor.
        """
        key = tensor._id
        xp = tensor.xp
        assert tensor.grad is not None

        grad = tensor.grad.array
        if TRICYCLE_CONTEXT.use_mixed_precision:
            grad = (
                grad.astype(xp.float32) / TRICYCLE_CONTEXT.loss_scale_factor
            )
            if not tensor.array.dtype == xp.float32:
                tensor.array = tensor.array.astype(xp.float32)

        # initialise stores
        if key not in self.momentum:
            self.momentum[key] = xp.zeros_like(grad, dtype=xp.float32)
        if key not in self.square_momentum:
            self.square_momentum[key] = xp.zeros_like(grad, dtype=xp.float32)

        self.momentum[key] = (
            self.betas[0] * self.momentum[key] + (1 - self.betas[0]) * grad
        )
        self.square_momentum[key] = (
            self.betas[1] * self.square_momentum[key]
            + (1 - self.betas[1]) * (grad * grad)
        )

        momentum_estimate = self.momentum[key] / (
            1 - self.betas[0] ** self.timestep
        )
        square_momentum_estimate = self.square_momentum[key] / (
            1 - self.betas[1] ** self.timestep
        )

        combined_grad = self.learning_rate * (
            momentum_estimate / (xp.sqrt(square_momentum_estimate) + self.eps)
            + self.weight_decay * tensor.array
        )

        # make sure our gradients aren't underflowing or overflowing
        if not xp.isfinite(combined_grad).all():
            warn(
                "Found nans in gradient, skipping this gradient and "
                "decreasing loss scaling. If this warning persists, "
                "check that your learning rate isn't too high"
            )
            TRICYCLE_CONTEXT.loss_scale_factor /= 2
            LOGGER.warning(
                f"New scaling factor: {TRICYCLE_CONTEXT.loss_scale_factor}"
            )
            return tensor

        if (combined_grad == 0).sum() > combined_grad.size * 0.05:
            warn(
                "Found too many 0's in gradient, skipping this gradient and "
                "increasing loss scaling. If this warning persists, "
                "check that your learning rate isn't too low"
            )
            TRICYCLE_CONTEXT.loss_scale_factor *= 2
            LOGGER.warning(
                f"New scaling factor: {TRICYCLE_CONTEXT.loss_scale_factor}"
            )
            return tensor

        if TRICYCLE_CONTEXT.use_mixed_precision:
            combined_grad = combined_grad.astype(xp.float32)

        tensor.array -= combined_grad
        tensor.grad.array.fill(0)

        return tensor
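
    # For reference, the update above implements (illustrative notation only):
    #
    #     m_t   = beta1 * m_{t-1} + (1 - beta1) * g
    #     v_t   = beta2 * v_{t-1} + (1 - beta2) * g**2
    #     m_hat = m_t / (1 - beta1**t)
    #     v_hat = v_t / (1 - beta2**t)
    #     w     = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
    #
    # i.e. Adam's bias-corrected moment estimates plus decoupled weight decay.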

    def __call__(self, tensor: Tensor) -> Tensor:
        """Apply the AdamW optimisation to the given tensor.

        Args:
            tensor (Tensor): The tensor to optimise.

        Returns:
            Tensor: The optimised tensor.
        """
        return self._reset_grad(self.update_weight(tensor))
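

# Usage sketch (illustrative only; `model`, its `tensors()` iterator and the
# `loss.backward()` call are assumptions about the surrounding library, not
# defined in this module). `step()` is called once per optimisation step,
# after every parameter has been updated, so the bias correction uses a
# consistent timestep across all tensors:
#
#     optimiser = AdamW(learning_rate=1e-3, weight_decay=0.01)
#     loss.backward()
#     for parameter in model.tensors():
#         optimiser(parameter)
#     optimiser.step()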