einsum
Einsum implementation for generalized matrix operations.
This module provides an implementation of the einsum operation, which expresses many matrix operations (such as transposes, matrix products, and sums over axes) in a single index notation. It allows for flexible manipulation of tensors using that notation.
- Example usage:
>>> a = Tensor([[1,2],[3,4]])
>>> Einsum("ij->ji")(a)
Tensor([[1. 3.]
 [2. 4.]], name=einsum ij->ji)
For more details on einsum operations, refer to the class and function docstrings.
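To make the index notation concrete, here is a minimal pure-Python reference sketch of what an explicit einsum computes: every index in the output selects a position, and every index that appears only in the inputs is summed over. `naive_einsum` is a hypothetical helper, not part of this module; it works on nested lists rather than `Tensor`, and assumes single-letter indices, an explicit `->` output, and no `...`.

```python
from itertools import product
from typing import Any


def naive_einsum(subscript: str, *operands: Any) -> Any:
    """Evaluate an explicit einsum subscript over nested Python lists.

    Hypothetical reference helper, not part of this module. Assumes
    single-letter indices, an explicit "->" output, no "...", and
    non-empty rectangular nested lists as operands.
    """
    lhs, out = subscript.replace(" ", "").split("->")
    inputs = lhs.split(",")
    if len(inputs) != len(operands):
        raise ValueError("number of operands does not match the subscript")

    # Bind every index letter to a size by walking each operand's nesting.
    sizes: dict[str, int] = {}
    for idx, op in zip(inputs, operands):
        node = op
        for letter in idx:
            if sizes.setdefault(letter, len(node)) != len(node):
                raise ValueError(f"inconsistent size for index {letter!r}")
            node = node[0]

    # Indices that appear in an input but not in the output are summed over.
    summed = [c for c in dict.fromkeys(lhs.replace(",", "")) if c not in out]

    def element(op: Any, idx: str, env: dict[str, int]) -> Any:
        # Follow one position per index letter down the nested lists.
        for letter in idx:
            op = op[env[letter]]
        return op

    def build(depth: int, env: dict[str, int]) -> Any:
        if depth == len(out):
            # All output indices are fixed: sum the product over summed indices.
            total = 0
            for combo in product(*(range(sizes[c]) for c in summed)):
                env.update(zip(summed, combo))
                term = 1
                for idx, op in zip(inputs, operands):
                    term *= element(op, idx, env)
                total += term
            return total
        letter = out[depth]
        return [build(depth + 1, {**env, letter: i}) for i in range(sizes[letter])]

    return build(0, {})


print(naive_einsum("ij->ji", [[1, 2], [3, 4]]))  # [[1, 3], [2, 4]]
print(naive_einsum("ij,jk->ik", [[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

The first call reproduces the transpose from the example above; the second is an ordinary matrix product, where `j` is summed because it does not appear after `->`.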
- class Einsum(subscript)
Bases: object
A class representing an einsum operation.
This class encapsulates the logic for performing einsum operations on tensors, including handling of batched operations and backward passes.
- Parameters:
subscript (Subscript)
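Before any arithmetic, an einsum must bind each index letter to a single size and derive the output shape; in plain subscript notation, a batched operation is simply an index (here `b`) that is shared by the inputs and kept in the output. The sketch below illustrates that bookkeeping under stated assumptions: `resolve_shape` is a hypothetical helper, not this module's API, and the docs above do not say whether `Einsum` handles batching exactly this way. It assumes single-letter indices and an explicit `->` output.

```python
def resolve_shape(subscript: str, *shapes: tuple[int, ...]) -> tuple[int, ...]:
    """Bind each index letter to one size and return the output shape.

    Hypothetical helper; assumes single-letter indices and an explicit
    "->" output with no ellipsis.
    """
    lhs, out = subscript.replace(" ", "").split("->")
    inputs = lhs.split(",")
    if len(inputs) != len(shapes):
        raise ValueError("number of shapes does not match the subscript")
    sizes: dict[str, int] = {}
    for idx, shape in zip(inputs, shapes):
        if len(idx) != len(shape):
            raise ValueError(f"{idx!r} does not match the rank of shape {shape}")
        for letter, n in zip(idx, shape):
            # The first occurrence fixes the size; later ones must agree.
            if sizes.setdefault(letter, n) != n:
                raise ValueError(f"index {letter!r} has sizes {sizes[letter]} and {n}")
    return tuple(sizes[letter] for letter in out)


# A batched matrix product: "b" is carried through, "j" is summed away.
print(resolve_shape("bij,bjk->bik", (8, 2, 3), (8, 3, 4)))  # (8, 2, 4)
```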
- class EinsumBackOp(idx, tensors, subscript)
Bases: object
The backward operation for an einsum operation.
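This class represents the backward pass of an einsum operation with respect to one input. That gradient is itself an einsum, obtained by swapping the input with the output: the output tensor (gradient) takes the input's place among the operands, and the input's indices become the new output indices.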
- idx
The index of the input tensor for which this backward operation is defined.
- Type:
int
- left_tensors
Tensors to the left of the current input in the original operation.
- Type:
Sequence[Tensor]
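The swap rule can be shown at the level of subscript strings. The sketch below builds the subscript for the gradient with respect to input `idx`: the operands to its left stay in place (compare `left_tensors`), the output takes its slot, and its indices become the result. `backward_subscript` is a hypothetical helper, not this class's API. It assumes every index of input `idx` also appears in the output or in another input, and that input `idx` has no repeated index; indices summed only within that input would need broadcasting, which this sketch does not handle.

```python
def backward_subscript(subscript: str, idx: int) -> str:
    """Subscript for the gradient with respect to input `idx` (hypothetical helper).

    The output takes the place of input `idx`, and that input's indices
    become the new output. Assumes each of its indices also appears in the
    output or another input, and that it has no repeated index.
    """
    lhs, out = subscript.replace(" ", "").split("->")
    inputs = lhs.split(",")
    target = inputs[idx]
    swapped = inputs[:idx] + [out] + inputs[idx + 1:]
    return ",".join(swapped) + "->" + target


# For C = A @ B ("ij,jk->ik"): dA = dC @ B^T and dB = A^T @ dC.
print(backward_subscript("ij,jk->ik", 0))  # ik,jk->ij
print(backward_subscript("ij,jk->ik", 1))  # ij,ik->jk
```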
- class Subscript(subscript)
Bases: object
An einsum subscript string and its parsed form.
This class parses and represents the subscript notation used in einsum operations.
- Parameters:
subscript (str)
- subscript
The original subscript string.
- Type:
str
- inputs
Parsed input indices.
- Type:
list[list[str]]
- output
Parsed output indices.
- Type:
list[str]
- classmethod from_split(indices, result)
Create a Subscript object from split indices and result.
- Parameters:
indices (list[list[str]]) – Parsed input indices.
result (list[str]) – Parsed output indices.
- Returns:
A new Subscript object.
- Return type:
Subscript
- inputs: list[list[str]]
- static join(indices, result)
Join parsed indices and result back into a subscript string.
- Parameters:
indices (list[list[str]]) – Parsed input indices.
result (list[str]) – Parsed output indices.
- Returns:
The joined subscript string.
- Return type:
str
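A minimal sketch of the joining step as a standalone function with the documented signature, assuming single-character indices that are concatenated without separators and an explicit `->` in the result. This is an illustration, not the module's source; `from_split` could plausibly be written as `cls(cls.join(indices, result))`, but the docs above do not confirm that.

```python
def join(indices: list[list[str]], result: list[str]) -> str:
    """Rebuild a subscript string from parsed indices (sketch).

    Assumes single-character indices and explicit "->" output notation.
    """
    inputs = ",".join("".join(idx) for idx in indices)
    return inputs + "->" + "".join(result)


print(join([["i", "j"], ["j", "k"]], ["i", "k"]))  # ij,jk->ik
```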
- output: list[str]
- split()
Parse the subscript string into a list of input indices and a list of output indices.
- Returns:
- A tuple containing two elements:
list[list[str]]: Parsed input indices.
list[str]: Parsed output indices.
- Return type:
tuple
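The parsing step can be sketched as a standalone function that returns the documented tuple shape. It assumes single-character indices and explicit `->` notation; a subscript without `->` simply yields an empty output here, so NumPy-style implicit mode is not modelled. This is an illustration, not the module's implementation.

```python
def split(subscript: str) -> tuple[list[list[str]], list[str]]:
    """Parse "ij,jk->ik" into ([["i", "j"], ["j", "k"]], ["i", "k"]) (sketch).

    Assumes single-character indices and explicit "->" notation.
    """
    lhs, _, rhs = subscript.replace(" ", "").partition("->")
    indices = [list(term) for term in lhs.split(",")]
    return indices, list(rhs)


print(split("ij,jk->ik"))  # ([['i', 'j'], ['j', 'k']], ['i', 'k'])
```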
- subscript: str
- property unique_input_indices: set[str]
Get the set of unique input indices.
- Returns:
Set of unique input indices.
- Return type:
set[str]
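As a sketch of what this property computes, the function below takes the union of the parsed `inputs`. One natural use of such a set, shown here as an assumption rather than documented behaviour of this module, is to find the contracted (summed) indices by subtracting the output indices.

```python
def unique_input_indices(inputs: list[list[str]]) -> set[str]:
    """Return every index that appears in at least one input (sketch)."""
    return set().union(*inputs)


inputs = [["i", "j"], ["j", "k"]]
output = ["i", "k"]
# Indices present in the inputs but absent from the output are summed over.
summed = unique_input_indices(inputs) - set(output)
print(sorted(unique_input_indices(inputs)))  # ['i', 'j', 'k']
print(sorted(summed))  # ['j']
```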