ops
Operations module for tensor manipulations.
This module contains various operations that can be applied to tensors, including repeat, split, reshape, and mean operations.
- class Mean
Bases: Op
Operation to find the mean of a tensor.
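The entry does not say which axes `Mean` reduces over or how it computes its gradient. The following is a minimal sketch, assuming NumPy arrays stand in for the module's tensors and that the op averages over every element. The forward pass returns the mean. Because each of the n inputs contributes 1/n to the result, the backward pass spreads the incoming gradient evenly. `MeanSketch` and its `backward` method are hypothetical names, not this module's API.

```python
import numpy as np


class MeanSketch:
    """Hypothetical stand-in for Mean; averages over every element."""

    def forward(self, x: np.ndarray) -> np.ndarray:
        # Remember the input shape so backward can rebuild it.
        self.input_shape = x.shape
        return np.asarray(x.mean())

    def backward(self, grad: np.ndarray) -> np.ndarray:
        # d(mean)/dx_i = 1/n for every element, so scale and broadcast.
        n = int(np.prod(self.input_shape))
        return np.full(self.input_shape, float(grad) / n)


op = MeanSketch()
out = op.forward(np.array([[1.0, 2.0], [3.0, 4.0]]))
grad_in = op.backward(np.array(1.0))
```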
- class Op
Bases: object
Base class for operations.
- abstract forward(*args, **kwargs)
Abstract method for the forward pass of the operation.
- Parameters:
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Raises:
NotImplementedError – If a subclass does not override this method.
- Returns:
The result of the forward operation.
- Return type:
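The entry marks `forward` as abstract, lists `object` as the only base, and says the method raises `NotImplementedError`. The following is a minimal sketch of that contract. Whether the module uses `abc.abstractmethod` is an assumption. Without `ABCMeta`, the decorator only flags intent, so the explicit raise is what stops an un-overridden `forward` from running. `Negate` is a hypothetical subclass included only for illustration.

```python
from abc import abstractmethod


class Op:  # implicitly derives from object, matching "Bases: object"
    """Sketch of the documented base class."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        # Without ABCMeta, @abstractmethod only marks intent; this raise is
        # what actually stops an un-overridden forward from running.
        raise NotImplementedError("forward must be implemented by subclasses")


class Negate(Op):
    """Hypothetical subclass used only to illustrate overriding forward."""

    def forward(self, x):
        return -x


try:
    Op().forward(1)
    base_raised = False
except NotImplementedError:
    base_raised = True

negated = Negate().forward(3)
```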
- class Repeat
Bases: Op
Operation to repeat a tensor along its final axis.
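The entry does not specify whether `Repeat` tiles the whole tensor or duplicates each element, and it does not name the repeat count. The following is a minimal NumPy sketch. It assumes the op tiles the tensor `n` times along the last axis, where `n` is a hypothetical parameter. Every copy reads from the same input, so the backward pass must sum the gradient slices that belong to each copy. The function names are hypothetical.

```python
import numpy as np


def repeat_forward(x: np.ndarray, n: int) -> np.ndarray:
    # Tile n copies of x end to end along the final axis: (..., k) -> (..., n*k).
    return np.concatenate([x] * n, axis=-1)


def repeat_backward(grad: np.ndarray, n: int) -> np.ndarray:
    # Split the final axis back into n copies of length k and sum over the copies.
    k = grad.shape[-1] // n
    return grad.reshape(*grad.shape[:-1], n, k).sum(axis=-2)


x = np.array([[1.0, 2.0], [3.0, 4.0]])
y = repeat_forward(x, 3)
g = repeat_backward(np.ones_like(y), 3)
```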
- class Reshape
Bases: Op
Operation to reshape a tensor.
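A reshape only reinterprets the layout of the same elements. Its gradient is therefore the incoming gradient reshaped back to the input's shape. The following is a minimal NumPy sketch of that idea, assuming NumPy arrays stand in for the module's tensors. The function names are hypothetical and are not this module's API.

```python
import numpy as np


def reshape_forward(x: np.ndarray, new_shape: tuple) -> tuple:
    # Return the reshaped tensor plus the original shape needed for backward.
    return x.reshape(new_shape), x.shape


def reshape_backward(grad: np.ndarray, input_shape: tuple) -> np.ndarray:
    # Each output element is exactly one input element, so just undo the reshape.
    return grad.reshape(input_shape)


x = np.arange(6.0)
y, saved_shape = reshape_forward(x, (2, 3))
g = reshape_backward(np.ones((2, 3)), saved_shape)
```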
- class Split
Bases: Op
Operation to split a tensor along an axis.
- back_fn(grad, idx)
The backward pass for a split operation.
Produces a tensor of zeros with the same shape as the input, except for the section covered by the given split, which holds the values of grad.
- Parameters:
grad (Tensor) – The gradient tensor.
idx (int) – The index of the split.
- Returns:
The gradient for the input tensor.
- Return type:
Example
>>> result = split([1, 2, 3, 4], 2)
>>> result
[tensor([1, 2]), tensor([3, 4])]
>>> # set an arbitrary derivative for the first split
>>> result[0].grad = Tensor([1, 1])
>>> undo_split(result[0].grad)
[1, 1, 0, 0]
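The following is a minimal NumPy sketch of the behaviour described for `back_fn`. It assumes the split divides the input into equal chunks along axis 0 and that `idx` selects the chunk the gradient belongs to. `split_forward` and `split_back_fn` are hypothetical names. Unlike the documented `back_fn(grad, idx)`, this version also takes the input shape and chunk count as arguments, because the sketch has no object to store them on. It reproduces the docstring example: the gradient `[1, 1]` for the first of two chunks becomes `[1, 1, 0, 0]`.

```python
import numpy as np


def split_forward(x: np.ndarray, sections: int) -> list:
    # Cut x into `sections` equal chunks along axis 0.
    return np.split(x, sections, axis=0)


def split_back_fn(grad: np.ndarray, idx: int, input_shape: tuple,
                  sections: int) -> np.ndarray:
    # Start from zeros shaped like the original input...
    full = np.zeros(input_shape, dtype=grad.dtype)
    # ...and copy grad into the slice that chunk `idx` was taken from.
    size = input_shape[0] // sections
    full[idx * size:(idx + 1) * size] = grad
    return full


x = np.array([1, 2, 3, 4])
parts = split_forward(x, 2)
g = split_back_fn(np.array([1, 1]), 0, x.shape, 2)
```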