Source code for tricycle.binary

"""
Binary operations for the Tricycle framework.

This module contains binary operations that can be applied to tensors of the same shape.
These operations include element-wise addition, subtraction, multiplication, division,
and comparison operations like maximum and minimum.

Note:
    In Tricycle, binary operations are only allowed on tensors of the same shape
    to simplify gradient computations.
"""

from numpy.typing import ArrayLike

from tricycle.ops import Einsum, Op
from tricycle.tensor import Tensor, select_backend
from tricycle.unary import UnaryDivide, nothing
from tricycle.utils import shapes_match


class BinaryAdd(Op):
    """Element-wise addition of two tensors.

    This class implements the forward pass for element-wise addition of two
    tensors.

    Attributes:
        _out: The output of the forward pass.
    """

    def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor:
        """Add two tensors element-wise.

        Args:
            tensor_1: First input tensor.
            tensor_2: Second input tensor.

        Returns:
            A Tensor representing the element-wise sum of the input tensors.

        Raises:
            AssertionError: If the shapes of the input tensors do not match.
        """
        xp = select_backend(tensor_1.array, tensor_2.array)

        assert shapes_match(tensor_1, tensor_2)

        self._out = xp.add(tensor_1.array, tensor_2.array)

        return Tensor(
            self._out,
            args=(tensor_1, tensor_2),
            back_fns=(nothing, nothing),
            name="badd",
            is_batched=tensor_1.is_batched or tensor_2.is_batched,
        )

class BinarySubtract(Op):
    """Element-wise subtraction of two tensors.

    This class implements the forward and backward passes for element-wise
    subtraction of two tensors.

    Attributes:
        _grad: The gradient of the backward pass for the second tensor.
        _out: The output of the forward pass.
    """

    def back_fn_2(self, grad: Tensor) -> Tensor:
        """Compute the gradient for the second tensor in the subtraction.

        Args:
            grad: The gradient tensor.

        Returns:
            A Tensor representing the gradient for the second tensor.
        """
        self._grad = -grad.array
        return Tensor(array=self._grad, is_batched=grad.is_batched)

    def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor:
        """Subtract one tensor from another element-wise.

        The two tensors must have the same shape.

        Args:
            tensor_1: First input tensor.
            tensor_2: Second input tensor to be subtracted from the first.

        Returns:
            A Tensor representing the element-wise difference of the input
            tensors.

        Raises:
            AssertionError: If the shapes of the input tensors do not match.
        """
        xp = select_backend(tensor_1.array, tensor_2.array)

        assert shapes_match(tensor_1, tensor_2)

        self._out = xp.subtract(tensor_1.array, tensor_2.array)

        return Tensor(
            self._out,
            args=(tensor_1, tensor_2),
            back_fns=(nothing, self.back_fn_2),
            name="bsub",
            is_batched=tensor_1.is_batched or tensor_2.is_batched,
        )

class BinaryMultiply(Op):
    """Element-wise multiplication of two tensors.

    This class implements the forward pass for element-wise multiplication of
    two tensors.
    """

    def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor:
        """Multiply the elements of two tensors together, element-wise.

        The two tensors must have the same shape.

        Args:
            tensor_1: First input tensor.
            tensor_2: Second input tensor.

        Returns:
            A Tensor representing the element-wise product of the input
            tensors.

        Raises:
            AssertionError: If the shapes of the input tensors do not match.
        """
        assert shapes_match(tensor_1, tensor_2)

        result = Einsum("...,...->...")(tensor_1, tensor_2)
        result.name = "bmul"
        return result

class BinaryDivide(Op):
    """Element-wise division of two tensors.

    This class implements the forward pass for element-wise division of two
    tensors.

    TODO: we should probably fuse these into a single op
    """

    def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor:
        """Divide the elements of one tensor by another, element-wise.

        The two tensors must have the same shape.

        Args:
            tensor_1: First input tensor (numerator).
            tensor_2: Second input tensor (denominator).

        Returns:
            A Tensor representing the element-wise division of the input
            tensors.

        Raises:
            AssertionError: If the shapes of the input tensors do not match.
        """
        mul = BinaryMultiply()
        div = UnaryDivide()

        return mul(tensor_1, div(1, tensor_2))

class BinaryMax(Op):
    """Element-wise maximum of two tensors.

    This class implements the forward and backward passes for element-wise
    maximum of two tensors.

    Attributes:
        _is_bigger_1: Boolean array indicating where the first tensor is
            strictly larger.
        _is_bigger_2: Boolean array indicating where the second tensor is
            larger or equal.
        _out: The output of the forward pass.
        _grad_1: The gradient for the first tensor.
        _grad_2: The gradient for the second tensor.
    """

    def back_fn_1(self, grad: Tensor) -> Tensor:
        """Compute the gradient for the first tensor in the maximum operation.

        Args:
            grad: The gradient tensor.

        Returns:
            A Tensor representing the gradient for the first tensor.
        """
        self._grad_1 = grad.array * self._is_bigger_1
        return Tensor(array=self._grad_1, is_batched=grad.is_batched)

    def back_fn_2(self, grad: Tensor) -> Tensor:
        """Compute the gradient for the second tensor in the maximum operation.

        Args:
            grad: The gradient tensor.

        Returns:
            A Tensor representing the gradient for the second tensor.
        """
        self._grad_2 = grad.array * self._is_bigger_2
        return Tensor(array=self._grad_2, is_batched=grad.is_batched)

    def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor:
        """Compare two tensors element-wise, returning the maximum of each
        pair of elements.

        The two tensors must have the same shape. Where elements are equal,
        the shared value is returned and the gradient is routed to the second
        tensor.

        Args:
            tensor_1: First input tensor.
            tensor_2: Second input tensor.

        Returns:
            A Tensor representing the element-wise maximum of the input
            tensors.

        Raises:
            AssertionError: If the shapes of the input tensors do not match.
        """
        xp = select_backend(tensor_1.array, tensor_2.array)

        assert shapes_match(tensor_1, tensor_2)

        self._out = xp.maximum(tensor_1.array, tensor_2.array)
        self._is_bigger_1 = tensor_1.array > tensor_2.array
        self._is_bigger_2 = tensor_1.array <= tensor_2.array

        return Tensor(
            self._out,
            args=(tensor_1, tensor_2),
            back_fns=(self.back_fn_1, self.back_fn_2),
            name="bmax",
            is_batched=tensor_1.is_batched or tensor_2.is_batched,
        )

class BinaryMin(Op):
    """Element-wise minimum of two tensors.

    This class implements the forward and backward passes for element-wise
    minimum of two tensors.

    Attributes:
        _is_smaller_1: Boolean array indicating where the first tensor is
            strictly smaller.
        _is_smaller_2: Boolean array indicating where the second tensor is
            smaller or equal.
        _out: The output of the forward pass.
        _grad_1: The gradient for the first tensor.
        _grad_2: The gradient for the second tensor.
    """

    def back_fn_1(self, grad: Tensor) -> Tensor:
        """Compute the gradient for the first tensor in the minimum operation.

        Args:
            grad: The gradient tensor.

        Returns:
            A Tensor representing the gradient for the first tensor.
        """
        self._grad_1 = grad.array * self._is_smaller_1
        return Tensor(array=self._grad_1, is_batched=grad.is_batched)

    def back_fn_2(self, grad: Tensor) -> Tensor:
        """Compute the gradient for the second tensor in the minimum operation.

        Args:
            grad: The gradient tensor.

        Returns:
            A Tensor representing the gradient for the second tensor.
        """
        self._grad_2 = grad.array * self._is_smaller_2
        return Tensor(array=self._grad_2, is_batched=grad.is_batched)

    def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor:
        """Compare two tensors element-wise, returning the minimum of each
        pair of elements.

        The two tensors must have the same shape. Where elements are equal,
        the shared value is returned and the gradient is routed to the second
        tensor.

        Args:
            tensor_1: First input tensor.
            tensor_2: Second input tensor.

        Returns:
            A Tensor representing the element-wise minimum of the input
            tensors.

        Raises:
            AssertionError: If the shapes of the input tensors do not match.
        """
        xp = select_backend(tensor_1.array, tensor_2.array)

        assert shapes_match(tensor_1, tensor_2)

        self._out = xp.minimum(tensor_1.array, tensor_2.array)
        self._is_smaller_1 = tensor_1.array < tensor_2.array
        self._is_smaller_2 = tensor_1.array >= tensor_2.array

        return Tensor(
            self._out,
            args=(tensor_1, tensor_2),
            back_fns=(self.back_fn_1, self.back_fn_2),
            name="bmin",
            is_batched=tensor_1.is_batched or tensor_2.is_batched,
        )