ops

Operations module for tensor manipulations.

This module contains various operations that can be applied to tensors, including repeat, split, reshape, and mean operations.

class Mean

Bases: Op

Operation to find the mean of a tensor.

backward(grad)

Backward function for the mean operation.

Parameters:

grad (Tensor) – The gradient tensor.

Returns:

The gradient for the input tensor.

Return type:

Tensor
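The gradient follows directly from the definition of the mean. Here is a short derivation. It assumes the mean is taken over all n elements of the input x, which is consistent with forward returning a single mean value. L denotes the downstream loss:

```latex
y = \frac{1}{n}\sum_{i=1}^{n} x_i
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x_i}
  = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial x_i}
  = \frac{1}{n}\,\frac{\partial L}{\partial y}
```

Every input element therefore receives the same share, grad / n, of the incoming gradient.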

forward(tensor)

Find the mean of a tensor.

Parameters:

tensor (Tensor) – The input tensor.

Returns:

A tensor containing the mean value.

Return type:

Tensor
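The following minimal sketch shows how forward and backward fit together. It uses plain NumPy arrays in place of the library's Tensor type. The class name MeanSketch is hypothetical, and the sketch assumes the mean covers every element of the input:

```python
import numpy as np


class MeanSketch:
    """Hypothetical stand-in for Mean, operating on NumPy arrays instead of Tensor."""

    def forward(self, x: np.ndarray) -> np.ndarray:
        # Remember the input shape and element count for the backward pass.
        self._shape = x.shape
        self._n = x.size
        # Mean over every element, returned as a 0-d array.
        return np.asarray(x.sum() / self._n)

    def backward(self, grad: np.ndarray) -> np.ndarray:
        # d(mean)/dx_i = 1/n, so each input element receives grad / n.
        return np.full(self._shape, grad / self._n)


op = MeanSketch()
x = np.array([[1.0, 2.0], [3.0, 4.0]])
out = op.forward(x)                   # 2.5
grad_in = op.backward(np.array(1.0))  # every entry is 0.25
```

The forward pass caches the input shape, so backward can return a gradient with the same shape as the original tensor.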

class Op

Bases: object

Base class for operations.

abstract forward(*args, **kwargs)

Abstract method for the forward pass of the operation.

Parameters:
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Raises:

NotImplementedError – This method should be implemented by subclasses.

Returns:

The result of the forward operation.

Return type:

Tensor
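The following minimal sketch illustrates the subclassing contract described above. It uses plain Python lists rather than the library's Tensor type. Negate is a hypothetical subclass, not part of this module:

```python
class Op:
    """Sketch of the base-class contract: subclasses must override forward."""

    def forward(self, *args, **kwargs):
        # The base class has no behavior of its own.
        raise NotImplementedError("forward must be implemented by subclasses")


class Negate(Op):
    """Hypothetical subclass (not part of this module) that negates each value."""

    def forward(self, values):
        return [-v for v in values]


print(Negate().forward([1, -2, 3]))  # prints [-1, 2, -3]

try:
    Op().forward()
except NotImplementedError as err:
    print(f"base class raised: {err}")
```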

class Repeat

Bases: Op

Operation to repeat a tensor along its final axis.

forward(tensor, repeats)

Repeat a tensor along its final axis.

This is done by multiplying the input by a tensor of ones that has the same shape as the desired output.

Parameters:
  • tensor (Tensor) – The input tensor to repeat.

  • repeats (int) – The number of times to repeat the tensor.

Returns:

The repeated tensor.

Return type:

Tensor
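The following NumPy sketch shows the multiply-by-ones idea. It rests on one assumption not confirmed by this page: "repeat along its final axis" means the output gains a new trailing axis of length repeats, so the output shape is the input shape plus (repeats,). The function name repeat_sketch is hypothetical:

```python
import numpy as np


def repeat_sketch(x: np.ndarray, repeats: int) -> np.ndarray:
    """Hypothetical NumPy version of Repeat.forward.

    Assumes the output shape is x.shape + (repeats,).
    """
    # Ones tensor with the shape of the desired output.
    ones = np.ones(x.shape + (repeats,), dtype=x.dtype)
    # Add a trailing axis of size 1 so x broadcasts against the ones tensor.
    return x[..., np.newaxis] * ones


x = np.array([1, 2, 3])
out = repeat_sketch(x, 2)
# out == [[1, 1], [2, 2], [3, 3]], shape (3, 2)
```

Under the same assumption, the matching gradient for the input would be the incoming gradient summed over the last axis (grad.sum(axis=-1)).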

class Reshape

Bases: Op

Operation to reshape a tensor.

back_fn(grad)

Backward function for the reshape operation.

Parameters:

grad (Tensor) – The gradient tensor.

Returns:

The gradient reshaped to the original shape.

Return type:

Tensor

forward(tensor, shape)

Reshape a tensor.

The new shape needs to have the same number of elements as the original, but can have any number of dimensions.

Parameters:
  • tensor (Tensor) – The input tensor to reshape.

  • shape (Sequence[int]) – The new shape for the tensor.

Returns:

The reshaped tensor.

Return type:

Tensor
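The following minimal sketch pairs forward with back_fn. It uses NumPy arrays in place of Tensor, and the class name ReshapeSketch is hypothetical. Because reshaping only re-indexes elements, the gradient is simply reshaped back to the input's original shape:

```python
import numpy as np


class ReshapeSketch:
    """Hypothetical NumPy stand-in for Reshape (forward and back_fn)."""

    def forward(self, x: np.ndarray, shape) -> np.ndarray:
        # Store the original shape so the gradient can be mapped back.
        self._original_shape = x.shape
        # NumPy raises ValueError if the element counts differ.
        return x.reshape(shape)

    def back_fn(self, grad: np.ndarray) -> np.ndarray:
        # Reshape does not change values, so the gradient is just reshaped back.
        return grad.reshape(self._original_shape)


op = ReshapeSketch()
x = np.arange(6)                 # shape (6,)
y = op.forward(x, (2, 3))        # shape (2, 3)
g = op.back_fn(np.ones((2, 3)))  # shape (6,)
```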

class Split

Bases: Op

Operation to split a tensor along an axis.

back_fn(grad, idx)

The backward operation for a split operation.

Produces a tensor of zeros with the same shape as the original input, except for the section produced by split idx, which contains grad.

Parameters:
  • grad (Tensor) – The gradient tensor.

  • idx (int) – The index of the split.

Returns:

The gradient for the input tensor.

Return type:

Tensor

Example

>>> result = split([1, 2, 3, 4], 2)
>>> result
[tensor([1, 2]), tensor([3, 4])]
>>> # set an arbitrary derivative for the first split
>>> result[0].grad = Tensor([1, 1])
>>> undo_split(result[0].grad)
[1, 1, 0, 0]

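The following runnable NumPy sketch builds the zero-padded gradient that back_fn describes. The function name split_back_fn_sketch and its extra parameters (n_splits, input_shape, axis) are hypothetical. The sketch assumes the input was split into n_splits equal parts along axis:

```python
import numpy as np


def split_back_fn_sketch(grad, idx, n_splits, input_shape, axis=-1):
    """Hypothetical NumPy version of Split.back_fn.

    Returns zeros shaped like the original input, with `grad` copied into
    the slice that split number `idx` came from.
    """
    out = np.zeros(input_shape, dtype=np.asarray(grad).dtype)
    size = input_shape[axis] // n_splits
    # Select the idx-th chunk along `axis`, leaving the other axes untouched.
    index = [slice(None)] * len(input_shape)
    index[axis] = slice(idx * size, (idx + 1) * size)
    out[tuple(index)] = grad
    return out


grad = split_back_fn_sketch(np.array([1, 1]), idx=0, n_splits=2, input_shape=(4,))
# grad == [1, 1, 0, 0], matching the example above
```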
forward(tensor, n_splits, axis=-1)

Split a tensor along an axis into n_splits partitions.

Parameters:
  • tensor (Tensor) – The input tensor to split.

  • n_splits (int) – The number of splits to make.

  • axis (int, optional) – The axis along which to split. Defaults to -1.

Returns:

A sequence of split tensors.

Return type:

Sequence[Tensor]
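For intuition, the following sketch expresses the forward pass with NumPy's numpy.split. It assumes equal-sized partitions, which numpy.split enforces by raising ValueError when the axis length is not divisible by n_splits. The name split_sketch is hypothetical and stands in for the library's Tensor-based operation:

```python
import numpy as np


def split_sketch(x: np.ndarray, n_splits: int, axis: int = -1) -> list:
    """Hypothetical NumPy version of Split.forward (equal-sized parts)."""
    return np.split(x, n_splits, axis=axis)


parts = split_sketch(np.array([1, 2, 3, 4]), 2)
# parts == [array([1, 2]), array([3, 4])]

# Splitting a (2, 3) array along the last axis gives three (2, 1) pieces.
columns = split_sketch(np.arange(6).reshape(2, 3), 3)
```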