utils

Utility functions and classes for the Tricycle project.

This module contains various utility functions and classes used throughout the Tricycle project, including dataset handling, mixed precision training, tensor shape matching, and performance logging.

class Dataset

Bases: object

Abstract base class for datasets.

This class defines the interface for dataset objects used in the project. Subclasses should implement the __len__ and __getitem__ methods.
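As an illustration of that interface, here is a minimal sketch of a concrete subclass. So that it runs on its own, it defines a stand-in `Dataset` base class. In a real project you would subclass Tricycle's class instead, presumably imported from this `utils` module. `InMemoryPairDataset` is a hypothetical name used only for this example.

```python
from typing import Any, List, Sequence, Tuple


class Dataset:
    """Stand-in for Tricycle's abstract Dataset base class (assumption)."""

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, index: int) -> Any:
        raise NotImplementedError


class InMemoryPairDataset(Dataset):
    """Hypothetical dataset of (input, target) pairs held in Python lists."""

    def __init__(self, inputs: Sequence[Any], targets: Sequence[Any]):
        if len(inputs) != len(targets):
            raise ValueError("inputs and targets must have the same length")
        self.inputs: List[Any] = list(inputs)
        self.targets: List[Any] = list(targets)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Out-of-range indices raise IndexError, which also lets a plain
        # `for` loop iterate over the dataset via the sequence protocol.
        return self.inputs[index], self.targets[index]


dataset = InMemoryPairDataset(inputs=[[0, 1], [1, 0], [1, 1]], targets=[1, 1, 0])
print(len(dataset))  # 3
print(dataset[1])    # ([1, 0], 1)
```

Because `__getitem__` raises `IndexError` past the end, `for x, y in dataset:` works even without an explicit `__iter__`.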

class UseMixedPrecision(initial_loss_scale_factor=128)

Bases: object

Context manager for enabling mixed precision training.

Mixed precision is enabled when the context is entered and disabled again when it is exited.

Parameters:

initial_loss_scale_factor (int) – The initial loss scale factor for mixed precision training. Defaults to 128.
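The sketch below shows one way such a context manager can work; it is not Tricycle's implementation. It assumes mixed precision is controlled by module-level state (`MixedPrecisionState`, a hypothetical name) and, since the parameter is called the *initial* loss scale factor, that the factor is adjusted during training as in dynamic loss scaling: the loss is multiplied by the scale before the backward pass, gradients are divided by it afterwards, and the scale is halved when an overflow produces non-finite gradients. `UseMixedPrecisionSketch` and `unscale_gradients` are hypothetical.

```python
import math
from typing import List, Optional


class MixedPrecisionState:
    """Hypothetical global flags that training code could consult."""

    enabled: bool = False
    loss_scale: float = 1.0


class UseMixedPrecisionSketch:
    """Sketch of a context manager that switches mixed precision on and off."""

    def __init__(self, initial_loss_scale_factor: int = 128):
        self.initial_loss_scale_factor = initial_loss_scale_factor

    def __enter__(self) -> "UseMixedPrecisionSketch":
        # Save the previous settings so they can be restored on exit
        self._previous = (MixedPrecisionState.enabled, MixedPrecisionState.loss_scale)
        MixedPrecisionState.enabled = True
        MixedPrecisionState.loss_scale = float(self.initial_loss_scale_factor)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        # Restore settings even if the block raised; do not suppress the error
        MixedPrecisionState.enabled, MixedPrecisionState.loss_scale = self._previous
        return False


def unscale_gradients(scaled_grads: List[float]) -> Optional[List[float]]:
    """Undo loss scaling; on overflow, halve the scale and signal a skipped step."""
    scale = MixedPrecisionState.loss_scale
    if not all(math.isfinite(g) for g in scaled_grads):
        MixedPrecisionState.loss_scale = scale / 2
        return None  # caller should skip this optimiser step
    return [g / scale for g in scaled_grads]


with UseMixedPrecisionSketch(initial_loss_scale_factor=128):
    # The loss would be multiplied by the scale before the backward pass
    scaled_loss = 0.5 * MixedPrecisionState.loss_scale  # 64.0
    grads = unscale_gradients([256.0, -64.0])           # [2.0, -0.5]
```

Saving and restoring the previous state in `__enter__`/`__exit__` keeps the global flags correct even if the training code inside the block raises.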

log_memory_and_time(stage, path=PosixPath('memory.log'))

Logs the current GPU memory usage and timestamp to a file.

Parameters:
  • stage (str) – A string describing the current stage of execution.

  • path (Path) – The path to the log file. Defaults to “memory.log”.

Raises:

GPUDisabledException – If the GPU is not enabled.
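A minimal sketch of stage-tagged memory logging follows. It is not Tricycle's function: the `stage,timestamp,bytes` line format is a hypothetical choice, `GPUDisabledException` is a local stand-in, and the memory reading is injected as a zero-argument callable so the example runs without a GPU. On a CUDA machine one option would be CuPy's `cupy.get_default_memory_pool().used_bytes`; whether Tricycle measures memory this way is an assumption.

```python
import time
from pathlib import Path
from typing import Callable, Optional


class GPUDisabledException(Exception):
    """Stand-in for Tricycle's exception of the same name (assumption)."""


def log_memory_and_time_sketch(
    stage: str,
    used_bytes: Optional[Callable[[], int]],
    path: Path = Path("memory.log"),
) -> None:
    """Append one 'stage,timestamp,bytes' line to `path` (hypothetical format)."""
    if used_bytes is None:
        # No way to query GPU memory, so mirror the documented failure mode
        raise GPUDisabledException("GPU is not enabled")
    line = f"{stage},{time.time()},{used_bytes()}\n"
    # Append so that successive stages accumulate in one file
    with path.open("a") as log_file:
        log_file.write(line)
```

Calling it before and after each stage (for example `"forward"` and `"backward"`) yields a file that can be loaded later to see where memory grows.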

optimal_n_tokens(model, config)

Estimates the compute-optimal number of tokens to train on using Chinchilla scaling.

Parameters:
  • model (GPT) – The GPT model.

  • config (GPTConfig) – The GPT configuration.

Returns:

A tuple containing the optimal number of tokens and steps.

Return type:

tuple

Reference:

https://arxiv.org/abs/2404.10102
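For intuition about the returned pair, the sketch below uses the widely quoted Chinchilla rule of thumb of roughly 20 training tokens per model parameter and converts the token budget into optimiser steps, assuming each step consumes `batch_size * context_window` tokens. This is a swapped-in heuristic, not necessarily the fit from the referenced paper that `optimal_n_tokens` uses, and the function and parameter names are hypothetical (the real function reads these values from a `GPT` model and `GPTConfig`).

```python
import math
from typing import Tuple

# Rule-of-thumb ratio from the Chinchilla work; an assumption, not
# necessarily the constant Tricycle's optimal_n_tokens uses.
TOKENS_PER_PARAMETER = 20


def chinchilla_tokens_and_steps(
    n_parameters: int, batch_size: int, context_window: int
) -> Tuple[int, int]:
    """Hypothetical sketch: compute-optimal tokens and optimiser steps."""
    n_tokens = TOKENS_PER_PARAMETER * n_parameters
    tokens_per_step = batch_size * context_window
    # Round up so the full token budget is covered
    n_steps = math.ceil(n_tokens / tokens_per_step)
    return n_tokens, n_steps


tokens, steps = chinchilla_tokens_and_steps(
    n_parameters=10_000_000, batch_size=32, context_window=256
)
print(tokens, steps)  # 200000000 24415
```

With gradient accumulation, `tokens_per_step` would also be multiplied by the number of accumulation steps.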

r_squared(actual_values, predicted_values)

Calculates the R-squared metric (coefficient of determination) of the predicted values against the actual values.

Parameters:
  • actual_values – The actual values.

  • predicted_values – The predicted values.

Returns:

The R-squared value.

Return type:

float
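The standard definition is R² = 1 − SS_res / SS_tot, where SS_res is the sum of squared prediction errors and SS_tot is the sum of squared deviations of the actual values from their mean. A pure-Python sketch of that formula follows. It is a reference illustration, not Tricycle's code, which may operate on arrays and handle edge cases differently; `r_squared_reference` is a hypothetical name.

```python
from typing import Iterable


def r_squared_reference(
    actual_values: Iterable[float], predicted_values: Iterable[float]
) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    actual = list(actual_values)
    predicted = list(predicted_values)
    if len(actual) != len(predicted) or not actual:
        raise ValueError("need two non-empty sequences of equal length")
    mean_actual = sum(actual) / len(actual)
    # Residual sum of squares: error left over after the predictions
    ss_res = sum((a - p) ** 2 for a, p in zip(actual, predicted))
    # Total sum of squares: spread of the actual values around their mean
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)
    return 1.0 - ss_res / ss_tot


print(r_squared_reference([1, 2, 3, 4], [1.1, 1.9, 3.2, 3.8]))  # approximately 0.98
```

A perfect fit gives 1.0, always predicting the mean gives 0.0, and worse-than-mean predictions go negative. If all actual values are equal, SS_tot is zero and this sketch raises `ZeroDivisionError`.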

shapes_match(tensor_1, tensor_2)

Checks if the shapes of two tensors match for binary operations.

Parameters:
  • tensor_1 (Tensor) – The first tensor to compare.

  • tensor_2 (Tensor) – The second tensor to compare.

Returns:

True if the shapes match; mismatched shapes raise ValueError instead of returning False.

Return type:

bool

Raises:

ValueError – If the shapes do not match.
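The contract described above (return True on a match, raise ValueError otherwise) can be sketched on plain shape tuples. This assumes "match" means exact equality with no broadcasting; the real function takes `Tensor` objects and may apply additional rules. `check_shapes_match` is a hypothetical name.

```python
from typing import Tuple


def check_shapes_match(shape_1: Tuple[int, ...], shape_2: Tuple[int, ...]) -> bool:
    """Return True if two shapes are identical; otherwise raise ValueError."""
    if shape_1 != shape_2:
        # Fail loudly so the bad operands are reported at the point of use
        raise ValueError(f"Shapes {shape_1} and {shape_2} do not match")
    return True


print(check_shapes_match((2, 3), (2, 3)))  # True
```

Raising rather than returning False means callers cannot silently ignore a mismatch before an elementwise operation.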

smooth(iterable, factor)

Applies exponential smoothing to an iterable.

Parameters:
  • iterable (Iterable) – The input iterable to smooth.

  • factor (float) – The smoothing factor.

Yields:

float – The smoothed values.
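A generator-based sketch of exponential smoothing follows. It assumes `factor` weights the previous smoothed value, s_t = factor · s_{t−1} + (1 − factor) · x_t, and that the first output is seeded with the first input. Tricycle's `smooth` may use the opposite weighting or a different starting value (for example zero), so treat both choices as assumptions. `smooth_sketch` is a hypothetical name.

```python
from typing import Iterable, Iterator, Optional


def smooth_sketch(iterable: Iterable[float], factor: float) -> Iterator[float]:
    """Exponentially smooth a stream: s_t = factor * s_{t-1} + (1 - factor) * x_t."""
    smoothed: Optional[float] = None
    for value in iterable:
        if smoothed is None:
            # Seed with the first value so early outputs are not biased towards 0
            smoothed = float(value)
        else:
            smoothed = factor * smoothed + (1 - factor) * value
        yield smoothed


print(list(smooth_sketch([1, 2, 3], factor=0.5)))  # [1.0, 1.5, 2.25]
```

Under this convention a factor near 1 gives a smoother, slower-reacting curve and a factor of 0 returns the input unchanged. Because it is a generator, it also works on unbounded streams such as a running training loss.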