optimisers

Optimisers for gradient-based optimisation.

This module contains various optimiser classes that can be used for gradient-based optimisation of tensors.

class AdamW(learning_rate=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.01)

Bases: Optimiser

AdamW optimiser.

This optimiser implements the AdamW algorithm: Adam with decoupled weight decay, in which the decay is applied directly to the weights instead of being added to the gradient.
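For reference, the standard AdamW update for a parameter θ with gradient g_t is written out below. The mapping to this class is inferred from the attribute names and is an assumption: η = learning_rate, (β1, β2) = betas, ε = eps, λ = weight_decay, t = timestep, with m and v held in momentum and square_momentum. Where exactly this implementation places ε or the decay term is not confirmed here.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1} \right)
\end{aligned}
```

The λθ term sits outside the adaptive ratio, which is what "decoupled" means: the decay is not rescaled by the second-moment estimate.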

learning_rate

The learning rate for the optimiser.

Type:

float

betas

The exponential decay rates (beta1, beta2) for the first and second moment estimates, respectively.

Type:

tuple

eps

A small constant for numerical stability.

Type:

float

weight_decay

The weight decay factor.

Type:

float

timestep

The current time step.

Type:

int

momentum

Store of first moment estimates (exponential moving averages of the gradient).

Type:

dict

square_momentum

Store of second moment estimates (exponential moving averages of the squared gradient).

Type:

dict

step()

Increment the time step.

This method should be called after each optimisation step.
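The time step matters because it drives the bias correction of the moment estimates. The sketch below, which assumes the time step starts at 1 and is shared by every tensor, tracks a scalar first moment under a constant gradient; `first_moment_estimates` and its `advance_timestep` flag are hypothetical names used only for illustration. With the time step advancing, the corrected estimate stays at the true gradient; with it stuck at 1 (as if step() were never called), the estimate overshoots.

```python
def first_moment_estimates(grads, beta1=0.9, advance_timestep=True):
    """Bias-corrected first-moment estimates for a sequence of scalar gradients.

    Hypothetical helper: if advance_timestep is False the time step stays at 1,
    mimicking a caller that never calls step().
    """
    m = 0.0        # zero-initialised, like an empty moment store
    timestep = 1   # assumed starting value
    estimates = []
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        estimates.append(m / (1 - beta1 ** timestep))  # bias correction
        if advance_timestep:
            timestep += 1  # the increment that step() performs
    return estimates


fresh = first_moment_estimates([2.0] * 5)
stale = first_moment_estimates([2.0] * 3, advance_timestep=False)
```

Every value in `fresh` is (up to rounding) 2.0, while `stale` grows to 3.8 and 5.42 on the second and third updates.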

update_weight(tensor)

Perform a weight update on a tensor using the AdamW algorithm.

Parameters:

tensor (Tensor) – The tensor to update.

Returns:

The updated tensor.

Return type:

Tensor
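A minimal sketch of one AdamW update, assuming a tensor can be treated as a flat list of floats and that the two moment stores are dictionaries keyed per tensor. `adamw_update`, the `state` layout and the `key` string are hypothetical and are not this class's API; the defaults mirror the constructor defaults documented above.

```python
from __future__ import annotations

import math


def adamw_update(
    weights: list[float],
    grads: list[float],
    state: dict,
    key: str,
    timestep: int,
    learning_rate: float = 0.001,
    betas: tuple = (0.9, 0.999),
    eps: float = 1e-6,
    weight_decay: float = 0.01,
) -> list[float]:
    """One AdamW step for a flat list of weights (illustrative sketch)."""
    beta1, beta2 = betas
    # Per-tensor moment stores, zero-initialised on first use (hypothetical layout).
    m = state["momentum"].setdefault(key, [0.0] * len(weights))
    v = state["square_momentum"].setdefault(key, [0.0] * len(weights))
    updated = []
    for i, (w, g) in enumerate(zip(weights, grads)):
        m[i] = beta1 * m[i] + (1 - beta1) * g      # first moment
        v[i] = beta2 * v[i] + (1 - beta2) * g * g  # second moment
        m_hat = m[i] / (1 - beta1 ** timestep)     # bias correction
        v_hat = v[i] / (1 - beta2 ** timestep)
        # Decoupled weight decay: lambda * w is added to the step, not to g.
        step = m_hat / (math.sqrt(v_hat) + eps) + weight_decay * w
        updated.append(w - learning_rate * step)
    return updated


state = {"momentum": {}, "square_momentum": {}}
weights = adamw_update([1.0, -2.0], [0.5, -0.5], state, key="layer1.weight", timestep=1)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so each weight moves by roughly learning_rate, plus the small decay term.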

class Optimiser

Bases: object

Base class for optimisers.
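Both subclasses documented here expose update_weight(tensor), which returns the updated tensor. The sketch below illustrates that subclassing pattern with stand-in types; `ToyTensor`, `ToyOptimiser` and `PlainSGD` are hypothetical, and whether Optimiser itself declares update_weight as an abstract method is an assumption.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class ToyTensor:
    """Hypothetical minimal tensor: flat values plus their gradient."""
    array: list[float]
    grad: list[float]


class ToyOptimiser(ABC):
    """Hypothetical base class: each subclass supplies one update rule."""

    @abstractmethod
    def update_weight(self, tensor: ToyTensor) -> ToyTensor:
        ...


class PlainSGD(ToyOptimiser):
    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate

    def update_weight(self, tensor: ToyTensor) -> ToyTensor:
        # Move each weight against its gradient.
        tensor.array = [w - self.learning_rate * g for w, g in zip(tensor.array, tensor.grad)]
        return tensor


opt = PlainSGD(learning_rate=0.1)
t = opt.update_weight(ToyTensor(array=[1.0, 2.0], grad=[1.0, -1.0]))
```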

class StochasticGradientDescent(learning_rate, weight_decay=None, momentum=None, logger=<Logger tricycle.optimisers (WARNING)>)

Bases: Optimiser

Stochastic Gradient Descent (SGD) optimiser.

This optimiser implements SGD with optional weight decay and momentum.
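One common formulation of SGD with L2 weight decay λ, momentum μ and learning rate η is shown below (it matches, for example, PyTorch's torch.optim.SGD without dampening or Nesterov). Whether this class orders or scales the terms the same way, for instance whether it folds η into the stored velocity b, is an assumption. When weight_decay is None the λ term is dropped, and when momentum is None the velocity reduces to b_t = g_t.

```latex
\begin{aligned}
g_t &= \nabla_\theta L(\theta_{t-1}) + \lambda\, \theta_{t-1} \\
b_t &= \mu\, b_{t-1} + g_t, \qquad b_0 = 0 \\
\theta_t &= \theta_{t-1} - \eta\, b_t
\end{aligned}
```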

Parameters:
  • learning_rate (float)

  • weight_decay (float | None)

  • momentum (float | None)

learning_rate

The learning rate for the optimiser.

Type:

float

weight_decay

The weight decay factor.

Type:

float | None

momentum

The momentum factor.

Type:

float | None

logger

The logger instance.

momentum_store

Store for momentum values.

Type:

dict

update_weight(tensor)

Perform a gradient update on a tensor.

This method updates the tensor’s weights using the calculated gradients, optionally applying weight decay and momentum.

Parameters:

tensor (Tensor) – The tensor to update.

Returns:

The updated tensor.

Return type:

Tensor
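A minimal sketch of that update, assuming a tensor can be treated as a flat list of floats and that momentum_store is a dictionary keyed per tensor. `sgd_update` and the `key` argument are hypothetical; the term ordering follows the common formulation where weight decay is folded into the gradient before momentum is applied, which this class may or may not use.

```python
from __future__ import annotations


def sgd_update(
    weights: list[float],
    grads: list[float],
    momentum_store: dict[str, list[float]],
    key: str,
    learning_rate: float,
    weight_decay: float | None = None,
    momentum: float | None = None,
) -> list[float]:
    """One SGD step with optional L2 weight decay and momentum (illustrative sketch)."""
    # Start from the raw gradient; optionally fold in the weight-decay term.
    direction = [
        g + weight_decay * w if weight_decay is not None else g
        for w, g in zip(weights, grads)
    ]
    if momentum is not None and momentum > 0:
        # Velocity from the previous call for this tensor (zeros on the first call).
        previous = momentum_store.get(key, [0.0] * len(weights))
        direction = [momentum * b + d for b, d in zip(previous, direction)]
        momentum_store[key] = direction
    return [w - learning_rate * d for w, d in zip(weights, direction)]


store: dict[str, list[float]] = {}
w = sgd_update([1.0], [0.5], store, "w", learning_rate=0.1, weight_decay=0.1, momentum=0.9)
w = sgd_update(w, [0.5], store, "w", learning_rate=0.1, weight_decay=0.1, momentum=0.9)
```

After the second call the stored velocity carries 0.9 of the first step forward, so the weight moves further than a plain gradient step would.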