Source code for tricycle.tensor

"""
The core of Tricycle is the Tensor object, which is implemented in this file.

A Tensor is a wrapper around a numpy/cupy array that adds automatic
differentiation.

The autodiff algorithm itself can be found in `Tensor.backward`.

This file also contains a few other helpful functions like `batch` which
converts tensors to batched tensors.
"""

import logging
import numbers
import uuid
from typing import TYPE_CHECKING, List, Optional, Sequence, Union

import numpy as np
from numpy.typing import ArrayLike

from tricycle import GPU_ENABLED
from tricycle.context import TRICYCLE_CONTEXT
from tricycle.exceptions import GPUDisabledException
from tricycle.weakset import WeakSet

if TYPE_CHECKING:
    from tricycle.ops import Op

logger = logging.getLogger(__name__)

DEFAULT_DTYPE = np.float32


[docs] class Tensor: """ An N-dimensional grid of numbers. This is implemented as a subclass of a standard numpy array. Attributes: _id (int): Unique identifier for the tensor. _parents (set[Tensor] | None): Parent tensors in the computation graph. array (ArrayLike): The underlying numpy/cupy array. args (tuple[Tensor, ...] | None): Arguments used to create this tensor. back_fns (tuple[Op, ...] | None): Backward functions for gradient computation. grad (Optional[Tensor]): Gradient of this tensor. name (Optional[str]): Name of the tensor. requires_grad (bool): Whether this tensor requires gradient computation. is_batched (bool): Whether this tensor is batched. """ def __init__( self, array: ArrayLike, requires_grad: bool = True, is_batched: bool = False, args: tuple["Tensor", ...] | None = None, back_fns: tuple["Op", ...] | None = None, dtype: np.typing.DTypeLike = None, name: str | None = None, _id: int | None = None, ): """ Initializes a new Tensor object. Args: array (ArrayLike): The underlying numpy/cupy array. requires_grad (bool, optional): Whether this tensor requires gradient computation. Defaults to True. is_batched (bool, optional): Whether this tensor is batched. Defaults to False. args (tuple[Tensor, ...] | None, optional): Arguments used to create this tensor. Defaults to None. back_fns (tuple[Op, ...] | None, optional): Backward functions for gradient computation. Defaults to None. dtype (np.typing.DTypeLike, optional): Data type of the tensor. Defaults to None. name (str | None, optional): Name of the tensor. Defaults to None. _id (int | None, optional): Unique identifier for the tensor. Defaults to None. """ if isinstance(array, Tensor): self = array return self._id = _id or uuid.uuid4().int self._parents = None if GPU_ENABLED: import cupy if isinstance(array, (np.ndarray, cupy.ndarray)): self.array = array else: self.array = np.array(array) else: self.array = np.array(array) if dtype is None: if TRICYCLE_CONTEXT.use_mixed_precision: dtype = np.float16 else: dtype = DEFAULT_DTYPE self.array = self.array.astype(dtype) self.grad = None self.requires_grad = requires_grad self.is_batched = is_batched self.args = args self.back_fns = back_fns self.name = name def _attach_parents(self): """ Traverses through the graph, labelling each tensor with the tensors that are direct parents to it in the graph. This is done to enable traversal through the graph later in topological order. """ stack: list["Tensor"] = [self] while stack: node = stack.pop() if not node.args: continue for arg in node.args: if not arg.requires_grad: continue if arg._parents is None: # if we use a set, we get a circular reference # which can't be garbage collected, leading to a memory # leak so we need to do a weakref to avoid the circular # reference arg._parents = WeakSet() # if a node has a parent we haven't visited yet, store it if node not in arg._parents: stack.append(arg) arg._parents.add(node) def _calculate_gradients(self, clip: float | None = None): """ Calculates gradients for the computation graph. This method implements the backpropagation algorithm, traversing the graph from the output to the inputs and applying the chain rule to compute gradients. Args: clip (float | None, optional): Maximum absolute value for gradient clipping. Defaults to None. """ self.grad = Tensor( self.xp.ones(self.array.shape, dtype=self.dtype), requires_grad=False, is_batched=self.is_batched, ) stack: list["Tensor"] = [self] while stack: node = stack.pop() # if we have reached an input, we're done along this path if node.args is None or node.back_fns is None: continue for arg, back_fns in zip(node.args, node.back_fns): # if we reach a tensor that does not need gradient computation # (e.g a constant) then we're done along this path if not arg.requires_grad: continue if arg._parents is None: raise ValueError( "arg.parents is None. Parents must be attached", "before calculating gradients. Did you forget to ", "call _attach_parents?", ) # already visited along this edge, dont do it again if node not in arg._parents: continue arg._parents.remove(node) try: # actuall calculate gradient for this node grad = back_fns(node.grad) # gradient clipping # TODO: allow clipping by norm instead of just by value if clip is not None: grad.array = grad.xp.clip(grad.array, -clip, clip) # add current gradient to any gradients we have already # calculated for this node if arg.grad is None: arg.grad = grad else: arg.grad.array += grad.array except Exception as e: raise e # only move to a new node if we have been to all of its parents if len(arg._parents) == 0: # get rid of the weakref once we're done with a node so we # can pickle the model. Weakrefs can't be pickled arg._parents = None stack.append(arg)
[docs] def backward(self, clip: float | None = None): """ Performs a backward pass through the graph, calculating the gradient for each parameter. Args: clip (float | None, optional): Maximum absolute value for gradient clipping. Defaults to None. """ self._attach_parents() self._calculate_gradients(clip=clip)
def __hash__(self) -> int: return self._id def __add__(self, other: Union[float, "Tensor"]) -> "Tensor": """ Implements addition for Tensor objects. Args: other (Union[float, Tensor]): The value to add to this tensor. Returns: Tensor: The result of the addition. Raises: NotImplementedError: If addition is not supported between the given types. """ if isinstance(other, numbers.Number): from tricycle.unary import UnaryAdd return UnaryAdd()(self, other) elif isinstance(other, Tensor): from tricycle.binary import BinaryAdd return BinaryAdd()(self, other) else: raise NotImplementedError( f"Cannot add {type(self)} and {type(other)}" ) def __radd__(self, other): return self + other def __iadd__(self, other): return self + other def __sub__(self, other): """ Implements subtraction for Tensor objects. Args: other (Union[float, Tensor]): The value to subtract from this tensor. Returns: Tensor: The result of the subtraction. Raises: NotImplementedError: If subtraction is not supported between the given types. """ if isinstance(other, self.xp.ndarray) and not isinstance( other, Tensor ): other = Tensor(other) if self.xp.isscalar(other): from tricycle.unary import UnarySubtract return UnarySubtract()(self, other) elif isinstance(other, Tensor): from tricycle.binary import BinarySubtract return BinarySubtract()(self, other) else: raise NotImplementedError( f"Cannot sub {type(self)} and {type(other)}" ) def __rsub__(self, other): return -(self - other) def __isub__(self, other): return self.__sub__(other) def __mul__(self, other): """ Implements multiplication for Tensor objects. Args: other (Union[float, Tensor]): The value to multiply with this tensor. Returns: Tensor: The result of the multiplication. Raises: NotImplementedError: If multiplication is not supported between the given types. """ if isinstance(other, self.xp.ndarray) and not isinstance( other, Tensor ): other = Tensor(other) if self.xp.isscalar(other) or other.shape == (): from tricycle.unary import UnaryMultiply return UnaryMultiply()(self, other) elif isinstance(other, Tensor): from tricycle.binary import BinaryMultiply return BinaryMultiply()(self, other) else: raise NotImplementedError( f"Cannot mul {type(self)} and {type(other)}" ) def __rmul__(self, other): return self * other def __imul__(self, other): return self * other def __neg__(self): return self * -1 def __truediv__(self, other): """ Implements true division for Tensor objects. Args: other (Union[float, Tensor]): The value to divide this tensor by. Returns: Tensor: The result of the division. Raises: NotImplementedError: If division is not supported between the given types. """ if self.xp.isscalar(other): from tricycle.unary import UnaryMultiply return UnaryMultiply()(self, 1 / other) elif isinstance(other, Tensor): from tricycle.binary import BinaryDivide return BinaryDivide()(self, other) else: raise NotImplementedError( f"Cannot divide {type(self)} and {type(other)}" ) def __rtruediv__(self, other): if self.xp.isscalar(other): from tricycle.unary import UnaryDivide return UnaryDivide()(other, self) elif isinstance(other, Tensor): from tricycle.binary import BinaryDivide return BinaryDivide()(other, self) def __itruediv__(self, other): return self / other def __pow__(self, other) -> "Tensor": """ Implements exponentiation for Tensor objects. Args: other (Union[float, Tensor]): The exponent. Returns: Tensor: The result of the exponentiation. Raises: NotImplementedError: If exponentiation is not supported between the given types. """ if isinstance(other, self.xp.ndarray) and not isinstance( other, Tensor ): other = Tensor(other) if self.xp.isscalar(other): from tricycle.unary import UnaryPower return UnaryPower()(self, other) elif isinstance(other, Tensor): raise NotImplementedError( "Cannot power two tensors of shape: " f"{self.shape}, {other.shape}" ) else: raise NotImplementedError( f"Cannot power {type(self)} and {type(other)}" ) def __lt__(self, other): if isinstance(other, Tensor): return Tensor(self.array < other.array) return Tensor(self.array < other) def __le__(self, other): if isinstance(other, Tensor): return Tensor(self.array <= other.array) return Tensor(self.array <= other) def __eq__(self, other): if isinstance(other, Tensor): if other._id == self._id: return Tensor(True) return Tensor(self.xp.array_equal(self.array == other.array)) return Tensor(self.array == other) def __ne__(self, other): if isinstance(other, Tensor): return Tensor(self.array != other.array) return Tensor(self.array != other) def __gt__(self, other): if isinstance(other, Tensor): return Tensor(self.array > other.array) return Tensor(self.array > other) def __ge__(self, other): if isinstance(other, Tensor): return Tensor(self.array >= other.array) return Tensor(self.array >= other) def __repr__(self): name = f", name={self.name}" if self.name is not None else "" return f"Tensor({self.array.__str__()}{name})" def __getitem__(self, idx): return Tensor(self.array[idx], requires_grad=self.requires_grad) def __setitem__(self, idx, value): self.array[idx] = value @property def xp(self): """ Returns the appropriate array library (numpy or cupy) for the tensor. Returns: module: The array library (numpy or cupy). """ return select_backend(self.array)
[docs] def einsum(self, subscript: str) -> "Tensor": """ Performs an einsum operation on the tensor. Args: subscript (str): The einsum subscript string. Returns: Tensor: The result of the einsum operation. """ from tricycle.einsum import Einsum return Einsum(subscript)(self)
[docs] def repeat(self, n_repeats: int) -> "Tensor": """ Repeats the tensor. Args: n_repeats (int): The number of times to repeat the tensor. Returns: Tensor: The repeated tensor. """ from tricycle.ops import Repeat return Repeat()(self, n_repeats)
@property def shape(self) -> Sequence[int]: """ Returns the shape of the tensor. Returns: Sequence[int]: The shape of the tensor. """ return self.array.shape @property def ndim(self) -> int: """ Returns the number of dimensions of the tensor. Returns: int: The number of dimensions. """ return self.array.ndim @property def dtype(self) -> np.dtype: """ Returns the data type of the tensor. Returns: np.dtype: The data type of the tensor. """ return self.array.dtype
[docs] def reshape(self, shape: Sequence[int]) -> "Tensor": """ Reshapes the tensor to the given shape. Args: shape (Sequence[int]): The new shape for the tensor. Returns: Tensor: The reshaped tensor. """ from tricycle.ops import Reshape return Reshape()(self, shape)
[docs] def split(self, n_splits: int, axis: int = -1) -> List["Tensor"]: """ Splits the tensor into multiple sub-tensors. Args: n_splits (int): The number of splits to perform. axis (int, optional): The axis along which to split. Defaults to -1. Returns: List[Tensor]: A list of split tensors. """ from tricycle.ops import Split return Split()(self, n_splits=n_splits, axis=axis)
[docs] def mean(self) -> "Tensor": """ Computes the mean of all elements in the tensor. Returns: Tensor: A new tensor containing the mean value. """ from tricycle.ops import Mean return Mean()(self)
[docs] def sum(self) -> "Tensor": """ Computes the sum of all elements in the tensor. Returns: Tensor: A new tensor containing the sum. """ from tricycle.unary import UnarySum return UnarySum()(self)
[docs] def close_to( self, other: Union["Tensor", ArrayLike, float, int], equal_nan=False, rtol=1e-4, **kwargs, ) -> bool: """ Checks if this tensor is close to another tensor or value within a tolerance. Args: other (Union[Tensor, ArrayLike, float, int]): The tensor or value to compare against. equal_nan (bool, optional): Whether to consider NaN values as equal. Defaults to False. rtol (float, optional): The relative tolerance parameter. Defaults to 1e-4. **kwargs: Additional keyword arguments to pass to numpy.allclose or cupy.allclose. Returns: bool: True if the tensors are close, False otherwise. """ if not isinstance(other, Tensor): return self.xp.allclose( self.array, self.xp.array(other), equal_nan=equal_nan, rtol=rtol, **kwargs, ) return self.xp.allclose( self.array, other.array, equal_nan=equal_nan, rtol=rtol, **kwargs )
[docs] def to_batched(self): """ Treats this tensor as a batch of tensors. Returns: Tensor: A new batched tensor. """ from tricycle.unary import Batch return Batch()(self)
[docs] def from_batched(self): """ Treats a batched tensor as a normal, non-batched, tensor. Returns: Tensor: A new non-batched tensor. """ from tricycle.unary import Unbatch return Unbatch()(self)
@property def on_gpu(self): """ Checks if the tensor is currently on the GPU. Returns: bool: True if the tensor is on the GPU, False otherwise. """ if not GPU_ENABLED: return False import cupy return isinstance(self.array, cupy.ndarray)
[docs] def to_gpu(self, device: int = 0): """ Moves this tensor to the GPU, if cupy is enabled. Args: device (int, optional): The GPU device number. Defaults to 0. Returns: Tensor: The tensor moved to the GPU. Raises: GPUDisabledException: If CuPY is not enabled. """ if not GPU_ENABLED: raise GPUDisabledException( "Cannot move tensor to GPU because CuPY is not enabled" ) import cupy cupy.cuda.Device(device).use() self.array = cupy.asarray(self.array) return self
[docs] def from_gpu(self): """ Moves this tensor from the GPU to CPU. Returns: Tensor: The tensor moved to the CPU. Raises: GPUDisabledException: If CuPY is not enabled. """ if not GPU_ENABLED: raise GPUDisabledException( "Cannot move tensor from GPU because CuPY is not enabled" ) import cupy self.array = cupy.asnumpy(self.array) return self
[docs] def zero_grad(self): """ Removes any gradients or references to other tensors. Returns: Tensor: The tensor with gradients and references cleared. """ self.grad = None self.args = None self.back_fns = None return self
[docs] def numpy(self): """ Returns the underlying array as a numpy array. Returns: np.ndarray: The tensor data as a numpy array. """ if not GPU_ENABLED: return self.array import cupy return cupy.asnumpy(self.array) if self.on_gpu else self.array
[docs] def select_backend(*tensors: Tensor | np.ndarray | ArrayLike): """ Given some tensors, if any of them are on the GPU, return the cupy backend. Otherwise default to the numpy backend. Args: *tensors: Variable number of tensors or arrays to check. Returns: module: The appropriate backend module (numpy or cupy). """ if not GPU_ENABLED: return np import cupy return cupy.get_array_module(*tensors)