utils
Utility functions and classes for the Tricycle project.
This module contains various utility functions and classes used throughout the Tricycle project, including dataset handling, mixed precision training, tensor shape matching, and performance logging.
- class Dataset
Bases: object
Abstract base class for datasets.
This class defines the interface for dataset objects used in the project. Subclasses should implement the __len__ and __getitem__ methods.
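A minimal sketch of a subclass that fulfils this interface is shown below. Because the snippet must run on its own, it defines a stand-in Dataset base class whose methods raise NotImplementedError; this stand-in and the SquaresDataset subclass are hypothetical illustrations, not Tricycle's source.

```python
class Dataset:
    """Stand-in mirroring the documented interface (hypothetical, not Tricycle's code)."""

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, index: int):
        raise NotImplementedError


class SquaresDataset(Dataset):
    """Hypothetical dataset yielding (x, x**2) pairs for x in range(n)."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        # Number of examples in the dataset.
        return self.n

    def __getitem__(self, index: int):
        # Raising IndexError for out-of-range indices lets Python's
        # sequence protocol stop iteration cleanly.
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index ** 2
```

Since __getitem__ raises IndexError past the end, such a dataset can also be iterated directly, e.g. `list(SquaresDataset(4))` gives `[(0, 0), (1, 1), (2, 4), (3, 9)]`.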
- class UseMixedPrecision(initial_loss_scale_factor=128)
Bases: object
Context manager for enabling mixed precision training.
Mixed precision is enabled on entering the context and disabled again on exit.
- Parameters:
initial_loss_scale_factor (int) – The initial loss scale factor for mixed precision training. Defaults to 128.
- log_memory_and_time(stage, path=PosixPath('memory.log'))
Logs the current GPU memory usage and timestamp to a file.
- Parameters:
stage (str) – A string describing the current stage of execution.
path (Path) – The path to the log file. Defaults to “memory.log”.
- Raises:
GPUDisabledException – If the GPU is not enabled.
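A minimal sketch of this kind of logger follows. It assumes a comma-separated line format of stage, used bytes and Unix timestamp, and it takes a hypothetical used_bytes callable in place of a real GPU query so that it runs without a GPU; passing None stands in for "GPU not enabled". The stand-in GPUDisabledException only mirrors the name in the docs. The log format and the used_bytes hook are assumptions, not Tricycle's implementation.

```python
import time
from pathlib import Path
from typing import Callable, Optional


class GPUDisabledException(Exception):
    """Stand-in for the exception named in the documentation."""


def log_memory_and_time_sketch(
    stage: str,
    path: Path = Path("memory.log"),
    used_bytes: Optional[Callable[[], int]] = None,  # hypothetical GPU query hook
) -> None:
    # Treat a missing memory query as "GPU not enabled".
    if used_bytes is None:
        raise GPUDisabledException("GPU is not enabled")
    # Assumed format: stage,bytes_in_use,unix_timestamp
    line = f"{stage},{used_bytes()},{time.time()}\n"
    # Append so successive stages accumulate in one file.
    with open(path, "a") as f:
        f.write(line)
```

If the project runs on CuPy, one option for the hook is `cupy.get_default_memory_pool().used_bytes`, which reports bytes held by CuPy's default memory pool; whether Tricycle measures memory this way is not stated here.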
- optimal_n_tokens(model, config)
Estimates the compute-optimal number of tokens to train on using Chinchilla scaling.
- Parameters:
- Returns:
A tuple containing the optimal number of tokens and steps.
- Return type:
tuple
- Reference:
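The arithmetic behind such an estimate can be sketched as follows, assuming the commonly quoted Chinchilla rule of thumb of roughly 20 training tokens per model parameter and assuming each optimizer step consumes batch_size × context_window × gradient_accumulation_steps tokens. The constant, the HypotheticalConfig fields, and taking a parameter count instead of a model are all assumptions for illustration; Tricycle's optimal_n_tokens may use a different fit or config fields.

```python
import math
from dataclasses import dataclass

# Assumed rule of thumb from the Chinchilla results: ~20 tokens per parameter.
TOKENS_PER_PARAMETER = 20


@dataclass
class HypotheticalConfig:
    batch_size: int
    context_window: int
    gradient_accumulation_steps: int = 1


def optimal_n_tokens_sketch(n_parameters: int, config: HypotheticalConfig) -> tuple[int, int]:
    """Return (n_tokens, n_steps) under the assumptions stated above."""
    n_tokens = TOKENS_PER_PARAMETER * n_parameters
    tokens_per_step = (
        config.batch_size * config.context_window * config.gradient_accumulation_steps
    )
    # Round up so the full token budget is covered.
    n_steps = math.ceil(n_tokens / tokens_per_step)
    return n_tokens, n_steps
```

For a 1M-parameter model with batch_size=4 and context_window=256, this gives 20,000,000 tokens and 1,024 tokens per step, so 19,532 steps after rounding up.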
- r_squared(actual_values, predicted_values)
Calculates the R-squared metric (coefficient of determination).
- Parameters:
actual_values – The actual values.
predicted_values – The predicted values.
- Returns:
The R-squared value.
- Return type:
float
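For reference, the standard definition is R² = 1 − SS_res / SS_tot, where SS_res = Σ(yᵢ − ŷᵢ)² is the residual sum of squares and SS_tot = Σ(yᵢ − ȳ)² is the total sum of squares around the mean of the actual values. A minimal NumPy sketch of that definition follows; r_squared_sketch is a hypothetical name, and Tricycle's own function may differ in details such as input types or edge-case handling.

```python
import numpy as np


def r_squared_sketch(actual_values, predicted_values) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    actual = np.asarray(actual_values, dtype=float)
    predicted = np.asarray(predicted_values, dtype=float)
    # Residual sum of squares: error left after the predictions.
    ss_res = np.sum((actual - predicted) ** 2)
    # Total sum of squares: spread of the actual values around their mean.
    ss_tot = np.sum((actual - actual.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)
```

Perfect predictions give 1.0, always predicting the mean gives 0.0, and worse-than-mean predictions give negative values. If all actual values are equal, SS_tot is zero and the ratio is undefined, so callers may want to guard that case.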