einsum
Einsum implementation for generalized matrix operations.
This module implements the einsum operation, which generalizes many matrix operations, such as transposition, matrix multiplication, and summation along an axis. Operations are written in index notation, which allows flexible manipulation of tensors.
- Example usage:
    >>> a = Tensor([[1,2],[3,4]])
    >>> Einsum("ij->ji")(a)
    Tensor([[1. 3.] [2. 4.]], name=einsum ij->ji)
For more details on einsum operations, refer to the class and function docstrings.
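To make the index notation concrete, here is a minimal reference sketch of what an einsum computes, written for plain nested lists rather than this module's `Tensor`. The helper names (`shape_of`, `get`, `naive_einsum`) are hypothetical and are not part of this module. The sketch assumes single-letter indices, an explicit `->`, and rectangular, non-empty inputs.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return shape


def get(nested, position):
    """Index a nested list with a sequence of integer positions."""
    for p in position:
        nested = nested[p]
    return nested


def naive_einsum(subscript, *operands):
    """Reference einsum on nested lists (hypothetical, explicit '->' form only)."""
    lhs, output = subscript.split("->")
    inputs = lhs.split(",")

    # Record the size of every index letter from the operand shapes.
    sizes = {}
    for letters, operand in zip(inputs, operands):
        for letter, size in zip(letters, shape_of(operand)):
            sizes[letter] = size

    # Letters that are missing from the output are summed over.
    summed = sorted(set("".join(inputs)) - set(output))

    def entry(out_pos):
        """Compute one output entry by summing products of input entries."""
        bound = dict(zip(output, out_pos))
        total = 0
        for sum_pos in product(*(range(sizes[s]) for s in summed)):
            bound.update(zip(summed, sum_pos))
            term = 1
            for letters, operand in zip(inputs, operands):
                term *= get(operand, [bound[letter] for letter in letters])
            total += term
        return total

    def build(prefix, remaining):
        """Recursively build the nested output list, one axis at a time."""
        if not remaining:
            return entry(prefix)
        return [
            build(prefix + [i], remaining[1:])
            for i in range(sizes[remaining[0]])
        ]

    return build([], list(output))


a = [[1, 2], [3, 4]]
print(naive_einsum("ij->ji", a))        # [[1, 3], [2, 4]]  (transpose)
print(naive_einsum("ij,jk->ik", a, a))  # [[7, 10], [15, 22]]  (matrix product)
print(naive_einsum("ii->", a))          # 5  (trace)
```

The same subscript string covers transposition, contraction, and reduction, which is why a single operation class can stand in for many matrix routines.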
- class Einsum(subscript)
  - Bases: object
  - A class representing an einsum operation.
  - This class encapsulates the logic for performing einsum operations on tensors, including handling of batched operations and backward passes.
  - Parameters:
    - subscript (Subscript)
 
- class EinsumBackOp(idx, tensors, subscript)
  - Bases: object
  - The backward operation for an einsum operation.
  - This class represents the backward pass of an einsum operation. The backward pass for an input is itself an einsum, obtained by swapping that input's indices and tensor with those of the output.
  - idx
    - The index of the input tensor for which this backward operation is defined.
    - Type: int
  - left_tensors
    - Tensors to the left of the current input in the original operation.
    - Type: Sequence[Tensor]
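The index swap described above can be sketched at the level of subscript strings. The function below is a hypothetical illustration, not this module's code. It assumes single-letter indices, and it assumes every index of the chosen input also appears in the output or in another input (a summed-away index would need extra broadcasting that this sketch does not handle).

```python
def backward_subscript(subscript, idx):
    """Build the subscript for the gradient of input `idx` (hypothetical helper)."""
    lhs, output = subscript.split("->")
    inputs = lhs.split(",")
    # The incoming gradient has the shape of the forward output, so it takes
    # the place of input `idx`; that input's indices become the new output.
    swapped = inputs[:idx] + [output] + inputs[idx + 1:]
    return ",".join(swapped) + "->" + inputs[idx]


# Forward pass: C = A @ B, written as "ij,jk->ik".
print(backward_subscript("ij,jk->ik", 0))  # ik,jk->ij  (gradient for A)
print(backward_subscript("ij,jk->ik", 1))  # ij,ik->jk  (gradient for B)
print(backward_subscript("ij->ji", 0))     # ji->ij
```

For the matrix product, the first result reads as "contract the output gradient with B over k", which matches the usual rule that the gradient for A is the output gradient times B transposed.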
 
 
- class Subscript(subscript)
  - Bases: object
  - A string that defines an einsum operation.
  - This class parses and represents the subscript notation used in einsum operations.
  - Parameters:
    - subscript (str)
  - subscript
    - The original subscript string.
    - Type: str
  - inputs
    - Parsed input indices.
    - Type: list[list[str]]
  - output
    - Parsed output indices.
    - Type: list[str]
 
  - classmethod from_split(indices, result)
    - Create a Subscript object from split indices and result.
    - Parameters:
      - indices (list[list[str]]) – Parsed input indices.
      - result (list[str]) – Parsed output indices.
    - Returns: A new Subscript object.
    - Return type: Subscript
  - inputs: list[list[str]]
  - static join(indices, result)
    - Join parsed indices and result back into a subscript string.
    - Parameters:
      - indices (list[list[str]]) – Parsed input indices.
      - result (list[str]) – Parsed output indices.
    - Returns: The joined subscript string.
    - Return type: str
 
  - output: list[str]
  - split()
    - Parse the subscript string into a list of input indices and a result.
    - Returns: A tuple containing two elements:
      - list[list[str]]: Parsed input indices.
      - list[str]: Parsed output indices.
    - Return type: tuple
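As a rough illustration of the split and join behavior documented here, the standalone functions below mirror the documented input and output types. They are hypothetical sketches, not the methods of `Subscript` themselves, and they assume one character per index and an explicit `->` in the string.

```python
def split_subscript(subscript):
    """Parse 'ij,jk->ik' into (input indices, output indices). Hypothetical."""
    lhs, output = subscript.split("->")
    # Each comma-separated term on the left describes one input tensor.
    inputs = [list(term.strip()) for term in lhs.split(",")]
    return inputs, list(output.strip())


def join_subscript(indices, result):
    """Inverse of split_subscript: rebuild the subscript string. Hypothetical."""
    return ",".join("".join(term) for term in indices) + "->" + "".join(result)


inputs, output = split_subscript("ij,jk->ik")
print(inputs)                          # [['i', 'j'], ['j', 'k']]
print(output)                          # ['i', 'k']
print(join_subscript(inputs, output))  # ij,jk->ik
```

Working on the parsed lists rather than the raw string makes it easy to reorder or replace index groups before joining them back into a subscript.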
 
  - subscript: str
  - property unique_input_indices: set[str]
    - Get the set of unique input indices.
    - Returns: Set of unique input indices.
    - Return type: set[str]
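One plausible use of the set of unique input indices is to find which indices an operation sums over. The sketch below uses a hypothetical standalone function that takes parsed input indices in the `list[list[str]]` form documented above; it is not the property's actual implementation.

```python
def unique_input_indices(inputs):
    """Collect every distinct index used by any input (hypothetical helper)."""
    return {index for term in inputs for index in term}


inputs = [["i", "j"], ["j", "k"]]
output = ["i", "k"]

all_indices = unique_input_indices(inputs)
# Indices present in the inputs but absent from the output are summed over.
contracted = all_indices - set(output)

print(sorted(all_indices))  # ['i', 'j', 'k']
print(sorted(contracted))   # ['j']
```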