Source code for tricycle.einsum

"""
Einsum implementation for generalized matrix operations.

This module provides an implementation of the einsum operation, which is a
generalization of many matrix operations. It allows for flexible manipulation
of tensors using index notation.

Example usage:
    >>> a = Tensor([[1,2],[3,4]])
    >>> Einsum("ij->ji")(a)
    Tensor([[1. 3.]
     [2. 4.]], name=einsum ij->ji)

For more details on einsum operations, refer to the class and function
docstrings.
"""

import itertools
import re
from typing import Sequence

from tricycle.tensor import Tensor, select_backend
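
# NOTE: `select_backend` is assumed (from how it is used in this module) to
# return a NumPy-compatible array module for the given tensors, e.g. numpy on
# the CPU or cupy on the GPU; it is bound to the conventional name `xp` below.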


class Subscript:
    """
    A string that defines an einsum operation.

    This class parses and represents the subscript notation used in einsum
    operations.

    Attributes:
        subscript (str): The original subscript string.
        inputs (list[list[str]]): Parsed input indices.
        output (list[str]): Parsed output indices.
    """

    subscript: str
    inputs: list[list[str]]
    output: list[str]
    _index_pattern: re.Pattern = re.compile(r"(?:[A-Za-z]|(?:\.{3}))")

    def __init__(self, subscript: str):
        """
        Initialize a Subscript object.

        Args:
            subscript (str): The einsum subscript string.
        """
        self.subscript = subscript
        self.inputs, self.output = self.split()

    def split(self) -> tuple[list[list[str]], list[str]]:
        """
        Parse a subscript string into a list of input indices and a result.

        Returns:
            tuple: A tuple containing two elements:
                - list[list[str]]: Parsed input indices.
                - list[str]: Parsed output indices.
        """
        indices, result = self.subscript.split("->")
        indices = indices.split(",")
        indices = [re.findall(self._index_pattern, idx) for idx in indices]
        result = re.findall(self._index_pattern, result)
        return indices, result

    @staticmethod
    def join(indices: list[list[str]], result: list[str]) -> str:
        """
        Join parsed indices and a result back into a subscript string.

        Args:
            indices (list[list[str]]): Parsed input indices.
            result (list[str]): Parsed output indices.

        Returns:
            str: The joined subscript string.
        """
        inputs_string = ",".join(["".join(idx) for idx in indices])
        outputs_string = "".join(result)
        return f"{inputs_string}->{outputs_string}"

    @classmethod
    def from_split(cls, indices: list[list[str]], result: list[str]):
        """
        Create a Subscript object from split indices and a result.

        Args:
            indices (list[list[str]]): Parsed input indices.
            result (list[str]): Parsed output indices.

        Returns:
            Subscript: A new Subscript object.
        """
        return cls(cls.join(indices, result))

    @property
    def unique_input_indices(self) -> set[str]:
        """
        Get the set of unique input indices.

        Returns:
            set[str]: Set of unique input indices.
        """
        all_inputs = itertools.chain(*self.inputs)
        return set(all_inputs)

    def __repr__(self):
        """
        Return a string representation of the Subscript object.

        Returns:
            str: The subscript string.
        """
        return self.subscript

    def __str__(self):
        """
        Return a string representation of the Subscript object.

        Returns:
            str: The subscript string.
        """
        return self.subscript
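
# Illustrative sketch of how a Subscript is parsed and rebuilt (doctest-style;
# the values below are what the methods above should produce, not captured
# library output):
#
#     >>> sub = Subscript("ij,jk->ik")
#     >>> sub.inputs
#     [['i', 'j'], ['j', 'k']]
#     >>> sub.output
#     ['i', 'k']
#     >>> sub.unique_input_indices == {"i", "j", "k"}
#     True
#     >>> Subscript.join([["i", "j"], ["j", "k"]], ["i", "k"])
#     'ij,jk->ik'
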

class EinsumBackOp:
    """
    The backward operation for an einsum operation.

    This class represents the backward pass of an einsum operation, which is
    built by swapping the indices and tensors for an input with the output.

    Attributes:
        idx (int): The index of the input tensor for which this backward
            operation is defined.
        tensors (Sequence[Tensor]): The input tensors of the original einsum
            operation.
        subscript (Subscript): The subscript of the original einsum operation.
        left_tensors (Sequence[Tensor]): Tensors to the left of the current
            input in the original operation.
        right_tensors (Sequence[Tensor]): Tensors to the right of the current
            input in the original operation.
        combined_subscript (Subscript): The subscript for the backward
            operation.
    """

    def __init__(
        self, idx: int, tensors: Sequence[Tensor], subscript: Subscript
    ):
        """
        Initialize an EinsumBackOp object.

        Args:
            idx (int): The index of the input tensor for which this backward
                operation is defined.
            tensors (Sequence[Tensor]): The input tensors of the original
                einsum operation.
            subscript (Subscript): The subscript of the original einsum
                operation.
        """
        self.idx = idx
        self.tensors = tensors
        self.subscript = subscript
        self.left_tensors, self.right_tensors = self._build_inputs()
        self.combined_subscript = self._build_subscript()

    def _build_inputs(self):
        """
        Build the left and right tensor sequences for the backward operation.

        Returns:
            tuple: A tuple containing two elements:
                - Sequence[Tensor]: Left tensors.
                - Sequence[Tensor]: Right tensors.
        """
        left_tensors = self.tensors[: self.idx]

        # Special case for the last index
        if self.idx < len(self.tensors) - 1:
            right_tensors = self.tensors[self.idx + 1 :]
        else:
            right_tensors = []

        return left_tensors, right_tensors

    def _build_subscript(self):
        """
        Build the subscript for the backward operation.

        Returns:
            Subscript: The subscript for the backward operation.
        """
        left_subscript = self.subscript.inputs[: self.idx]

        # Special case for the last index
        if self.idx < len(self.tensors) - 1:
            right_subscript = self.subscript.inputs[self.idx + 1 :]
        else:
            right_subscript = []

        combined_indices = [
            *left_subscript,
            self.subscript.output,
            *right_subscript,
        ]
        return Subscript.from_split(
            combined_indices, self.subscript.inputs[self.idx]
        )

    def __call__(self, tensor: Tensor):
        """
        Apply the backward operation for einsum.

        This is done by swapping the indices and tensors for an input with
        the output. E.g. "ij,jk->ik" with idx = 0 would become "ik,jk->ij".

        Args:
            tensor (Tensor): The gradient tensor from the previous layer.

        Returns:
            Tensor: The result of the backward einsum operation.
        """
        combined_tensors = [*self.left_tensors, tensor, *self.right_tensors]
        return Einsum(self.combined_subscript)(*combined_tensors)

    def __repr__(self):
        """
        Return a string representation of the EinsumBackOp object.

        Returns:
            str: A string representation of the object.
        """
        return f"EinsumBackOp({self.combined_subscript})"
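
# Illustrative sketch of the index swap EinsumBackOp performs (doctest-style;
# the subscripts shown are what `_build_subscript` above should produce, not
# captured library output). For the forward operation "ij,jk->ik", the
# backward operation for an input swaps that input's indices with the
# output's:
#
#     >>> a = Tensor([[1.0, 2.0], [3.0, 4.0]])
#     >>> b = Tensor([[5.0, 6.0], [7.0, 8.0]])
#     >>> subscript = Subscript("ij,jk->ik")
#     >>> EinsumBackOp(0, [a, b], subscript).combined_subscript
#     ik,jk->ij
#     >>> EinsumBackOp(1, [a, b], subscript).combined_subscript
#     ij,ik->jk
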

class Einsum:
    """
    A class representing an einsum operation.

    This class encapsulates the logic for performing einsum operations on
    tensors, including handling of batched operations and backward passes.

    Attributes:
        subscript (Subscript): The subscript defining the einsum operation.
    """

    subscript: Subscript

    def __init__(self, subscript: str | Subscript):
        """
        Initialize an Einsum object.

        Args:
            subscript (str | Subscript): The subscript defining the einsum
                operation. Can be a string or a Subscript object.
        """
        if isinstance(subscript, str):
            subscript = Subscript(subscript)
        self.subscript = subscript

    def _build_back_ops(
        self, tensors: Sequence[Tensor], subscript: Subscript
    ):
        """
        Figure out the backward operation for each input.

        Args:
            tensors (Sequence[Tensor]): The input tensors.
            subscript (Subscript): The subscript for the operation.

        Returns:
            list: A list of EinsumBackOp objects, one for each input tensor.
        """
        assert len(tensors) == len(subscript.inputs)

        # To avoid adding a bunch of special cases for batched operations,
        # we strip the batch index ("z") and build the backward subscripts
        # from the non-batched form
        subscript = Subscript(subscript.subscript.replace("z", ""))

        back_functions = []
        for idx in range(len(tensors)):
            back_op = EinsumBackOp(idx, tensors, subscript)
            back_functions.append(back_op)

        return back_functions

    def _handle_single_tensor(
        self, subscript: Subscript, tensors: Sequence[Tensor]
    ) -> tuple[Subscript, Sequence[Tensor]]:
        """
        Handle the case of a single input tensor.

        If there is only one tensor, we need to insert a matrix of ones to
        allow for expansion operations.

        Args:
            subscript (Subscript): The original subscript.
            tensors (Sequence[Tensor]): The input tensors.

        Returns:
            tuple: A tuple containing two elements:
                - Subscript: The modified subscript.
                - Sequence[Tensor]: The modified list of tensors.
        """
        xp = select_backend(*tensors)
        if len(tensors) != 1:
            return subscript, tensors

        [tensor] = tensors
        ones = Tensor(
            xp.ones(tensor.shape),
            is_batched=tensor.is_batched,
            requires_grad=False,
        )
        tensors = [tensor, ones]

        [index] = subscript.inputs
        output = subscript.output
        inputs = [index, index]
        subscript = Subscript.from_split(inputs, output)

        return subscript, tensors

    def _handle_batched(
        self, subscript: Subscript, tensors: Sequence[Tensor]
    ) -> tuple[Subscript, Sequence[Tensor], bool]:
        """
        Handle batched tensors in the einsum operation.

        If a tensor is labelled as being batched, add an extra dimension to
        its indices.

        Args:
            subscript (Subscript): The original subscript.
            tensors (Sequence[Tensor]): The input tensors.

        Returns:
            tuple: A tuple containing three elements:
                - Subscript: The modified subscript.
                - Sequence[Tensor]: The input tensors (unchanged).
                - bool: Whether the output should be batched.

        Raises:
            ValueError: If 'z' appears in the user-supplied subscript while
                batched tensors are present, because 'z' is reserved for the
                batch index.
        """
        inputs = []
        batch_output = False
        for idx, tensor in zip(subscript.inputs, tensors):
            if tensor.is_batched:
                inputs.append(["z"] + idx)
                batch_output = True
            else:
                inputs.append(idx)

        output = subscript.output
        if batch_output:
            if "z" in subscript.subscript:
                raise ValueError(
                    "`z` cannot be used in an einsum subscript on "
                    "non-batched tensors because "
                    "it is reserved for batched indices."
                )
            output = ["z"] + output

        subscript = Subscript.from_split(inputs, output)
        return subscript, tensors, batch_output

    def _replace_infinity(self, tensors: Sequence[Tensor]):
        """
        Replace infinity values in tensors with the maximum finite value for
        that datatype.

        Args:
            tensors (Sequence[Tensor]): The input tensors.

        Returns:
            list: A list of processed tensors with infinity values replaced.
        """
        xp = select_backend(*tensors)
        processed = []
        for tensor in tensors:
            if not xp.isinf(tensor.array).any():
                processed.append(tensor)
                continue

            new_tensor = Tensor(
                xp.nan_to_num(tensor.array),
                is_batched=tensor.is_batched,
            )
            new_tensor.args = tensor.args
            new_tensor.back_fns = tensor.back_fns
            new_tensor.name = tensor.name
            processed.append(new_tensor)

        return processed

    def __call__(self, *tensors: Tensor):
        """
        Perform the einsum operation on the input tensors.

        Args:
            *tensors (Tensor): The input tensors for the einsum operation.

        Returns:
            Tensor: The result of the einsum operation.
        """
        xp = select_backend(*tensors)
        subscript, tensors, batch_output = self._handle_batched(
            self.subscript, tensors
        )
        subscript, tensors = self._handle_single_tensor(subscript, tensors)

        tensor_data = [t.array for t in tensors]
        result = Tensor(xp.einsum(str(subscript), *tensor_data))
        if batch_output:
            result.is_batched = True

        result.args = tuple(tensors)
        result.back_fns = tuple(self._build_back_ops(tensors, subscript))
        result.name = f"einsum {self.subscript}"

        return result
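
# Illustrative usage sketch (doctest-style; the printed form assumes a NumPy
# backend and mirrors the repr shown in the module docstring, so the exact
# formatting may differ):
#
#     >>> a = Tensor([[1.0, 2.0], [3.0, 4.0]])
#     >>> b = Tensor([[5.0, 6.0], [7.0, 8.0]])
#     >>> Einsum("ij,jk->ik")(a, b)
#     Tensor([[19. 22.]
#      [43. 50.]], name=einsum ij,jk->ik)
#
# Tensors created with is_batched=True get the reserved index "z" prepended
# to their subscript, so the same operation on a batched first argument is
# computed as "zij,jk->zik" and the result is marked as batched.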