einsum

Einsum implementation for generalized matrix operations.

This module implements the einsum operation, which generalizes many matrix operations such as transposition and matrix multiplication. Tensors are manipulated through index notation: a subscript string names the indices of each input and of the output.

Example usage:
>>> a = Tensor([[1,2],[3,4]])
>>> Einsum("ij->ji")(a)
Tensor([[1. 3.]
 [2. 4.]], name=einsum ij->ji)

For more details on einsum operations, refer to the class and function docstrings.

class Einsum(subscript)

Bases: object

A class representing an einsum operation.

This class encapsulates the logic for performing einsum operations on tensors, including handling of batched operations and backward passes.

Parameters:

subscript (Subscript)

subscript

The subscript defining the einsum operation.

Type:

Subscript

subscript: Subscript

class EinsumBackOp(idx, tensors, subscript)

Bases: object

The backward operation for an einsum operation.

This class represents the backward pass of an einsum operation. The gradient for one input is computed as another einsum in which that input's indices and tensor trade places with the output's.

Parameters:
idx

The index of the input tensor for which this backward operation is defined.

Type:

int

tensors

The input tensors of the original einsum operation.

Type:

Sequence[Tensor]

subscript

The subscript of the original einsum operation.

Type:

Subscript

left_tensors

Tensors to the left of the current input in the original operation.

Type:

Sequence[Tensor]

right_tensors

Tensors to the right of the current input in the original operation.

Type:

Sequence[Tensor]

combined_subscript

The subscript for the backward operation.

Type:

Subscript
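
The index swap described above can be illustrated on plain subscript strings. This is a minimal sketch, not the module's implementation: `backward_subscript` is a hypothetical helper, and the operand order (left inputs, then the upstream gradient, then right inputs) is an assumption based on the left_tensors and right_tensors attributes. It also assumes every index of the chosen input appears in the output or in another input, so reductions such as "ij->i" are not covered.

```python
def backward_subscript(subscript: str, idx: int) -> str:
    """Build the subscript for the gradient of input `idx` (illustrative only)."""
    lhs, output = subscript.split("->")
    inputs = lhs.split(",")
    # Assumed operand order: left inputs, upstream gradient, right inputs.
    left, right = inputs[:idx], inputs[idx + 1:]
    operands = left + [output] + right
    # The indices of the chosen input become the output of the backward einsum.
    return ",".join(operands) + "->" + inputs[idx]


print(backward_subscript("ij,jk->ik", 0))  # ik,jk->ij
print(backward_subscript("ij,jk->ik", 1))  # ij,ik->jk
```

For matrix multiplication this recovers the familiar gradients: the gradient of the first operand contracts the upstream gradient with the second operand over k, and the gradient of the second operand contracts the first operand with the upstream gradient over i.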

class Subscript(subscript)

Bases: object

A string that defines an einsum operation.

This class parses and represents the subscript notation used in einsum operations.

Parameters:

subscript (str)

subscript

The original subscript string.

Type:

str

inputs

Parsed input indices.

Type:

list[list[str]]

output

Parsed output indices.

Type:

list[str]

classmethod from_split(indices, result)

Create a Subscript object from split indices and result.

Parameters:
  • indices (list[list[str]]) – Parsed input indices.

  • result (list[str]) – Parsed output indices.

Returns:

A new Subscript object.

Return type:

Subscript
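
One way such an alternate constructor could work is to join the parsed parts back into a string and pass it to the regular constructor. The sketch below uses a hypothetical stand-in class, `MiniSubscript`, and assumes explicit "->" notation with single-letter indices; the real Subscript class may be structured differently.

```python
from dataclasses import dataclass, field


@dataclass
class MiniSubscript:
    """Hypothetical stand-in for Subscript, for illustration only."""

    subscript: str
    inputs: list[list[str]] = field(init=False)
    output: list[str] = field(init=False)

    def __post_init__(self) -> None:
        # Parse "ij,jk->ik" into per-input index lists and an output list.
        lhs, rhs = self.subscript.split("->")
        self.inputs = [list(term) for term in lhs.split(",")]
        self.output = list(rhs)

    @classmethod
    def from_split(cls, indices, result):
        # Rebuild the string form, then reuse the normal constructor.
        text = ",".join("".join(term) for term in indices) + "->" + "".join(result)
        return cls(text)


s = MiniSubscript.from_split([["i", "j"], ["j", "k"]], ["i", "k"])
print(s.subscript)  # ij,jk->ik
```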

inputs: list[list[str]]

static join(indices, result)

Join parsed indices and result back into a subscript string.

Parameters:
  • indices (list[list[str]]) – Parsed input indices.

  • result (list[str]) – Parsed output indices.

Returns:

The joined subscript string.

Return type:

str
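
As an illustration of the joining step, the following sketch concatenates each input's indices, separates inputs with commas, and appends the output after "->". The function name `join_subscript` is hypothetical, and single-letter indices are assumed.

```python
def join_subscript(indices: list[list[str]], result: list[str]) -> str:
    """Join parsed indices into a subscript string (illustrative only)."""
    inputs = ",".join("".join(term) for term in indices)
    return inputs + "->" + "".join(result)


print(join_subscript([["i", "j"], ["j", "k"]], ["i", "k"]))  # ij,jk->ik
print(join_subscript([["i", "j"]], ["j", "i"]))              # ij->ji
```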

output: list[str]

split()

Parse the subscript string into a list of input indices and a list of output indices.

Returns:

A tuple containing two elements:
  • list[list[str]]: Parsed input indices.

  • list[str]: Parsed output indices.

Return type:

tuple
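
The parsing can be sketched in a few lines. This is an assumption-laden illustration rather than the module's code: `split_subscript` is a hypothetical free function, and it handles only explicit "->" notation with single-letter indices (no ellipsis, no implicit output).

```python
def split_subscript(subscript: str) -> tuple[list[list[str]], list[str]]:
    """Parse a subscript string into input and output indices (illustrative only)."""
    # Drop spaces, then separate the input side from the output side.
    lhs, rhs = subscript.replace(" ", "").split("->")
    # One list of index letters per comma-separated input.
    inputs = [list(term) for term in lhs.split(",")]
    return inputs, list(rhs)


print(split_subscript("ij,jk->ik"))  # ([['i', 'j'], ['j', 'k']], ['i', 'k'])
```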

subscript: str

property unique_input_indices: set[str]

Get the set of unique input indices.

Returns:

Set of unique input indices.

Return type:

set[str]
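
A set of this kind can be computed by flattening the parsed inputs. The sketch below uses a hypothetical free function operating on the same list[list[str]] shape as the inputs attribute; it is not taken from the module.

```python
def unique_input_indices(inputs: list[list[str]]) -> set[str]:
    """Collect every distinct index used by any input (illustrative only)."""
    return {index for term in inputs for index in term}


# For "ij,jk->ik" the shared index j is counted once.
print(sorted(unique_input_indices([["i", "j"], ["j", "k"]])))  # ['i', 'j', 'k']
```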