Source code for tricycle.ops

"""Operations module for tensor manipulations.

This module contains various operations that can be applied to tensors,
including repeat, split, reshape, and mean operations.
"""

from abc import abstractmethod
from typing import Sequence

from numpy.typing import ArrayLike

from tricycle.context import TRICYCLE_CONTEXT
from tricycle.einsum import Einsum, Subscript
from tricycle.tensor import Tensor


class Op:
    """Common parent of every operation defined in this module.

    An instance is callable: invoking it forwards all positional and
    keyword arguments to :meth:`forward` and hands back whatever that
    method returns.
    """

    # Placeholder for an operation's raw output array; ``None`` until a
    # subclass chooses to assign it.
    _out: ArrayLike | None = None

    # NOTE: ``Op`` does not use ``ABCMeta``, so ``@abstractmethod`` does not
    # prevent instantiation here; the ``NotImplementedError`` raised in the
    # body is what signals a missing override at call time.
    @abstractmethod
    def forward(self, *args, **kwargs) -> Tensor:
        """Run the forward pass; subclasses must override this.

        Args:
            *args: Positional inputs of the concrete operation.
            **kwargs: Keyword inputs of the concrete operation.

        Raises:
            NotImplementedError: Always, in this base implementation.

        Returns:
            Tensor: The result of the forward pass (in subclasses).
        """
        raise NotImplementedError()

    def __call__(self, *args, **kwargs) -> Tensor:
        """Apply the operation by delegating to :meth:`forward`.

        Args:
            *args: Positional arguments passed through unchanged.
            **kwargs: Keyword arguments passed through unchanged.

        Returns:
            Tensor: Whatever :meth:`forward` returns.
        """
        outcome = self.forward(*args, **kwargs)
        return outcome
class Repeat(Op):
    """Operation that repeats a tensor along a new trailing axis."""

    def forward(self, tensor: Tensor, repeats: int):
        """Repeat ``tensor`` ``repeats`` times along a new final axis.

        The repetition is expressed as an einsum between the input and a
        tensor of ones that already has the output shape
        (``tensor.shape + (repeats,)``).

        Args:
            tensor (Tensor): The tensor to repeat.
            repeats (int): Length of the trailing axis that is added.

        Returns:
            Tensor: The result of the einsum, i.e. the repeated tensor.
        """
        backend = tensor.xp
        pattern = Subscript("...,...a->...a")

        # The ones tensor only supplies the target shape, so it is built
        # with requires_grad=False.
        target_shape = tensor.shape + (repeats,)
        broadcaster = Tensor(
            backend.ones(target_shape),
            is_batched=tensor.is_batched,
            requires_grad=False,
        )

        expand = Einsum(pattern)
        return expand(tensor, broadcaster)
class Split(Op):
    """Operation that cuts a tensor into equal pieces along one axis."""

    # Declared but never assigned anywhere in this class.
    _indices: tuple[int]
    # Axis passed to the most recent ``forward`` call (may be negative).
    _axis: int
    # Number of pieces requested in the most recent ``forward`` call.
    _n_splits: int
    # One gradient buffer per piece, filled lazily by ``back_fn``.
    _grad: list[ArrayLike]

    def back_fn(self, grad: Tensor, idx: int) -> Tensor:
        """Backward pass for a single piece of the split.

        Builds an array of zeros with the shape of the original input and
        writes ``grad`` into the region that piece ``idx`` was taken from.

        Args:
            grad (Tensor): Gradient flowing into piece ``idx``.
            idx (int): Position of the piece in the list returned by
                ``forward``.

        Returns:
            Tensor: Gradient with respect to the un-split input.

        Example:
            >>> result = split([1,2,3,4], 2)
            >>> result
            [tensor([1, 2]), tensor([3, 4])]
            # set an arbitrary derivative for first split
            >>> result[0].grad = Tensor([1,1])
            >>> undo_split(result[0].grad)
            [1, 1, 0, 0]
        """
        backend = grad.xp
        padded = backend.zeros(self._in_shape)
        self._grad[idx] = padded

        # Select everything on every axis except the split axis, where only
        # the window belonging to piece ``idx`` is selected.
        n_dims = padded.ndim
        selector = [slice(None)] * n_dims
        if n_dims:
            split_axis = self._axis % n_dims  # normalise negative axes
            width = self._in_shape[split_axis] // self._n_splits
            selector[split_axis] = slice(width * idx, width * (idx + 1))

        padded[tuple(selector)] = grad.array

        out = Tensor(padded)
        out.is_batched = grad.is_batched
        return out

    def _back_fn_for(self, idx: int):
        """Return a backward callable bound to piece ``idx``."""

        # ``idx`` is bound as a default so that every piece keeps its own
        # index.
        def back_fn(grad, idx=idx):
            return self.back_fn(grad, idx=idx)

        return back_fn

    def _wrap_piece(self, chunk, idx: int, source: Tensor) -> Tensor:
        """Wrap one raw chunk in a Tensor linked back to ``source``."""
        piece = Tensor(chunk)
        piece.back_fns = (self._back_fn_for(idx),)
        piece.args = (source,)
        piece.is_batched = source.is_batched
        return piece

    def forward(
        self, tensor: Tensor, n_splits: int, axis: int = -1
    ) -> Sequence[Tensor]:
        """Split ``tensor`` into ``n_splits`` pieces along ``axis``.

        Args:
            tensor (Tensor): The tensor to split.
            n_splits (int): How many pieces to produce.
            axis (int, optional): Axis to split along. Defaults to -1.

        Returns:
            Sequence[Tensor]: The pieces, in order along ``axis``.
        """
        backend = tensor.xp

        assert isinstance(n_splits, int)

        # Remember everything ``back_fn`` needs to rebuild the gradient.
        self._out = backend.split(tensor.array, n_splits, axis=axis)
        self._in_shape = tensor.shape
        self._axis = axis
        self._n_splits = n_splits
        self._grad = [None] * n_splits

        # TODO: this loop is really slow and should be replaced
        return [
            self._wrap_piece(chunk, idx, tensor)
            for idx, chunk in enumerate(self._out)
        ]
class Reshape(Op):
    """Operation that gives a tensor a new shape."""

    # Shape of the input seen by the most recent ``forward`` call; used to
    # undo the reshape in ``back_fn``.
    _original_shape: Sequence[int]

    def back_fn(self, grad: Tensor) -> Tensor:
        """Backward pass: reshape the gradient back to the input shape.

        Args:
            grad (Tensor): Gradient with the reshaped (output) shape.

        Returns:
            Tensor: The same gradient with the original input shape.
        """
        restored = grad.xp.reshape(grad.array, self._original_shape)
        self._grad = restored
        return Tensor(array=restored, is_batched=grad.is_batched)

    def forward(self, tensor: Tensor, shape: Sequence[int]) -> Tensor:
        """Reshape ``tensor`` to ``shape``.

        The new shape must hold the same number of elements as the
        original but may have any number of dimensions. For a batched
        tensor, ``shape`` describes a single item: the leading (batch)
        dimension is kept and prepended automatically.

        Args:
            tensor (Tensor): The tensor to reshape.
            shape (Sequence[int]): The requested shape.

        Returns:
            Tensor: The reshaped tensor.
        """
        backend = tensor.xp

        # Batched tensors keep their first dimension untouched.
        target = [tensor.shape[0]] + list(shape) if tensor.is_batched else shape

        self._out = backend.reshape(tensor.array, target)
        self._original_shape = tensor.shape

        return Tensor(
            array=self._out,
            args=(tensor,),
            back_fns=(self.back_fn,),
            name="reshape",
            is_batched=tensor.is_batched,
        )
class Mean(Op):
    """Operation that averages every element of a tensor."""

    def backward(self, grad: Tensor) -> Tensor:
        """Backward pass for the mean.

        Multiplies the incoming gradient by an array of the input's shape
        in which every entry equals the divisor stored by ``forward``.

        Args:
            grad (Tensor): Gradient of the mean value.

        Returns:
            Tensor: Gradient with respect to the input tensor.
        """
        backend = grad.xp
        per_element = backend.full(self._in_shape, self.divisor)
        return Tensor(grad.array * per_element, is_batched=self._is_batched)

    def forward(self, tensor: Tensor) -> Tensor:
        """Compute the mean over all elements of ``tensor``.

        Args:
            tensor (Tensor): The tensor to average.

        Returns:
            Tensor: A tensor holding the mean value.
        """
        backend = tensor.xp
        mixed_precision = TRICYCLE_CONTEXT.use_mixed_precision

        # State needed later by ``backward``.
        self._is_batched = tensor.is_batched
        self._in_shape = tensor.shape

        # we can overflow here with large arrays so we'll use full precision
        # NOTE(review): this rebinds the caller's ``tensor.array`` to a
        # float32 copy, so the input tensor itself is modified.
        if mixed_precision:
            tensor.array = tensor.array.astype(backend.float32)

        # An empty shape falls back to a divisor of 1.
        if tensor.shape:
            self.divisor = 1 / backend.prod(tensor.shape)
        else:
            self.divisor = 1

        mean_value = tensor.array.sum() * self.divisor

        # Return to half precision once the reduction is done.
        if mixed_precision:
            mean_value = mean_value.astype(backend.float16)

        return Tensor(
            mean_value,
            name="mean",
            back_fns=(self.backward,),
            args=(tensor,),
        )