"""Operations module for tensor manipulations.
This module contains various operations that can be applied to tensors,
including repeat, split, reshape, and mean operations.
"""
from abc import abstractmethod
from typing import Sequence
from numpy.typing import ArrayLike
from tricycle.context import TRICYCLE_CONTEXT
from tricycle.einsum import Einsum, Subscript
from tricycle.tensor import Tensor
class Op:
"""Base class for operations."""
_out: ArrayLike | None = None
def __call__(self, *args, **kwargs) -> Tensor:
"""Call the forward method of the operation.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
Tensor: The result of the forward operation.
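
        Example:
            Ops are callable; calling an instance dispatches straight to
            ``forward``, so ``Repeat()(tensor, 2)`` is equivalent to
            ``Repeat().forward(tensor, 2)`` (``Repeat`` is defined later in
            this module).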
"""
return self.forward(*args, **kwargs)
@abstractmethod
def forward(self, *args, **kwargs) -> Tensor:
"""Abstract method for the forward pass of the operation.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Raises:
NotImplementedError: This method should be implemented by subclasses.
Returns:
Tensor: The result of the forward operation.
"""
raise NotImplementedError()
class Repeat(Op):
"""Operation to repeat a tensor along its final axis."""
    def forward(self, tensor: Tensor, repeats: int) -> Tensor:
"""Repeat a tensor along its final axis.
This is done by multiplying with a ones tensor the same shape as the
desired output.
Args:
tensor (Tensor): The input tensor to repeat.
repeats (int): The number of times to repeat the tensor.
Returns:
Tensor: The repeated tensor.
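
        Example:
            An illustrative sketch (assumes ``Tensor`` can be built from a
            plain list); each element is repeated along a new final axis:

            >>> repeated = Repeat()(Tensor([1, 2, 3]), repeats=2)
            >>> repeated.shape
            (3, 2)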
"""
xp = tensor.xp
subscript = Subscript("...,...a->...a")
new_shape = tensor.shape + (repeats,)
ones = Tensor(
xp.ones(new_shape),
is_batched=tensor.is_batched,
requires_grad=False,
)
return Einsum(subscript)(tensor, ones)
class Split(Op):
"""Operation to split a tensor along an axis."""
    _in_shape: Sequence[int]
    _axis: int
    _n_splits: int
    _grad: list[ArrayLike | None]
def back_fn(self, grad: Tensor, idx: int) -> Tensor:
"""The backwards operation for a split operation.
Produces a tensor of zeros the same shape as the input
except in the section that was split.
Args:
grad (Tensor): The gradient tensor.
idx (int): The index of the split.
Returns:
Tensor: The gradient for the input tensor.
Example:
>>> result = split([1,2,3,4], 2)
>>> result
[tensor([1, 2]), tensor([3, 4])]
# set an arbitrary derivative for first split
>>> result[0].grad = Tensor([1,1])
>>> undo_split(result[0].grad)
[1, 1, 0, 0]
"""
xp = grad.xp
self._grad[idx] = xp.zeros(self._in_shape)
# TODO: this loop is really slow and should be replaced
indices = []
for i in range(self._grad[idx].ndim):
if i == self._axis % self._grad[idx].ndim:
step = self._in_shape[i] // self._n_splits
start = step * idx
end = step * (idx + 1)
indices.append(slice(start, end))
else:
indices.append(slice(None))
self._grad[idx][tuple(indices)] = grad.array
result = Tensor(self._grad[idx])
result.is_batched = grad.is_batched
return result
def forward(
self, tensor: Tensor, n_splits: int, axis: int = -1
) -> Sequence[Tensor]:
"""Split a tensor along an axis into n_splits partitions.
Args:
tensor (Tensor): The input tensor to split.
n_splits (int): The number of splits to make.
axis (int, optional): The axis along which to split. Defaults to -1.
Returns:
Sequence[Tensor]: A sequence of split tensors.
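
        Example:
            An illustrative sketch (assumes ``Tensor`` can be built from a
            plain list and that the axis length divides evenly by
            ``n_splits``):

            >>> first, second = Split()(Tensor([1, 2, 3, 4]), n_splits=2)
            >>> first.shape
            (2,)
            >>> second.shape
            (2,)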
"""
xp = tensor.xp
assert isinstance(n_splits, int)
self._out = xp.split(tensor.array, n_splits, axis=axis)
self._in_shape = tensor.shape
self._axis = axis
self._n_splits = n_splits
self._grad = [None] * n_splits
# TODO: this loop is really slow and should be replaced
results = []
for idx, result in enumerate(self._out):
# the back_fn depends on index so we need to
# dynamically create this function
def back_fn(grad, idx=idx):
return self.back_fn(grad, idx=idx)
result = Tensor(result)
result.back_fns = (back_fn,)
result.args = (tensor,)
result.is_batched = tensor.is_batched
results.append(result)
return results
class Reshape(Op):
"""Operation to reshape a tensor."""
_original_shape: Sequence[int]
def back_fn(self, grad: Tensor) -> Tensor:
"""Backward function for the reshape operation.
Args:
grad (Tensor): The gradient tensor.
Returns:
Tensor: The gradient reshaped to the original shape.
"""
xp = grad.xp
self._grad = xp.reshape(grad.array, self._original_shape)
return Tensor(array=self._grad, is_batched=grad.is_batched)
def forward(self, tensor: Tensor, shape: Sequence[int]) -> Tensor:
"""Reshape a tensor.
The new shape needs to have the same number of elements
as the original, but can have any number of dimensions.
Args:
tensor (Tensor): The input tensor to reshape.
shape (Sequence[int]): The new shape for the tensor.
Returns:
Tensor: The reshaped tensor.
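
        Example:
            An illustrative sketch (assumes ``Tensor`` can be built from a
            plain list; the element count must match the new shape):

            >>> reshaped = Reshape()(Tensor([1, 2, 3, 4, 5, 6]), shape=(2, 3))
            >>> reshaped.shape
            (2, 3)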
"""
xp = tensor.xp
        # if the tensor is batched, the requested shape applies to each
        # batch element, so prepend the batch dimension before reshaping
        if tensor.is_batched:
            shape = [tensor.shape[0]] + list(shape)
self._out = xp.reshape(tensor.array, shape)
self._original_shape = tensor.shape
return Tensor(
array=self._out,
args=(tensor,),
back_fns=(self.back_fn,),
name="reshape",
is_batched=tensor.is_batched,
)
class Mean(Op):
"""Operation to find the mean of a tensor."""
def backward(self, grad: Tensor) -> Tensor:
"""Backward function for the mean operation.
Args:
grad (Tensor): The gradient tensor.
Returns:
Tensor: The gradient for the input tensor.
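
        Example:
            Each input element contributes ``1 / n_elements`` to the mean, so
            the upstream gradient is scaled by the stored divisor and
            broadcast to the input shape. An illustrative sketch (assumes
            ``Tensor`` accepts a scalar for the upstream gradient):

            >>> op = Mean()
            >>> _ = op.forward(Tensor([2.0, 4.0, 6.0, 8.0]))
            >>> op.backward(Tensor(1.0)).array
            array([0.25, 0.25, 0.25, 0.25])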
"""
xp = grad.xp
result = xp.full(self._in_shape, self.divisor)
out = grad.array * result
return Tensor(out, is_batched=self._is_batched)
def forward(self, tensor: Tensor) -> Tensor:
"""Find the mean of a tensor.
Args:
tensor (Tensor): The input tensor.
Returns:
Tensor: A tensor containing the mean value.
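
        Example:
            An illustrative sketch (assumes ``Tensor`` can be built from a
            plain list):

            >>> float(Mean()(Tensor([1.0, 2.0, 3.0])).array)
            2.0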
"""
xp = tensor.xp
self._is_batched = tensor.is_batched
self._in_shape = tensor.shape
# we can overflow here with large arrays so we'll use full precision
if TRICYCLE_CONTEXT.use_mixed_precision:
tensor.array = tensor.array.astype(xp.float32)
self.divisor = 1 / xp.prod(tensor.shape) if tensor.shape else 1
out = tensor.array.sum() * self.divisor
if TRICYCLE_CONTEXT.use_mixed_precision:
out = out.astype(xp.float16)
return Tensor(
out, name="mean", back_fns=(self.backward,), args=(tensor,)
)