attention

Module implementing multi-head attention operations.

This module implements the multi-head attention mechanism as described in “Attention Is All You Need” (Vaswani et al., 2017). It includes functions for building attention masks and the main Attention class for performing multi-head attention operations.

class Attention(embedding_dim, n_heads, context_window)[source]

Bases: Op

Multi-head attention operation.

This class implements the multi-head attention mechanism as described in “Attention Is All You Need” (Vaswani et al., 2017).

Parameters:
  • embedding_dim (int) – Dimension of the input embeddings.

  • n_heads (int) – Number of attention heads.

  • context_window (int) – Size of the context window.

embedding_dim

An integer representing the dimension of the input embeddings.

n_heads

An integer representing the number of attention heads.

context_window

An integer representing the size of the context window.

mask

A tensor representing the attention mask.

_grad

A tensor to store gradients during backpropagation.

backward(grad)[source]

Compute the gradient of the attention operation.

Parameters:

grad (Tensor) – A Tensor representing the upstream gradient.

Returns:

A Tensor representing the gradient with respect to the input.

forward(tensor)[source]

Apply the multi-head attention operation to the input tensor.

Parameters:

tensor (Tensor) – A Tensor of shape (batch_size, seq_len, embedding_dim * 3). The input should contain concatenated query, key, and value projections.

Returns:

A Tensor representing the output after applying multi-head attention.
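As a rough, illustrative sketch (not this class's actual implementation), the NumPy function below mirrors the computation described above: it splits a fused (batch_size, seq_len, embedding_dim * 3) input into query, key, and value, applies per-head scaled dot-product attention with a causal mask, and merges the heads back together. The helper name causal_multi_head_attention and the use of plain NumPy arrays in place of the library's Tensor type are assumptions made only for this example.

    import numpy as np

    def causal_multi_head_attention(x, n_heads):
        # Illustrative sketch only; not the Attention class's implementation.
        # x: (batch_size, seq_len, embedding_dim * 3) with fused Q, K, V.
        batch_size, seq_len, triple_dim = x.shape
        embedding_dim = triple_dim // 3
        head_dim = embedding_dim // n_heads

        # Split the fused projection into query, key, and value.
        q, k, v = np.split(x, 3, axis=-1)

        # Reshape each to (batch_size, n_heads, seq_len, head_dim).
        def split_heads(t):
            return t.reshape(batch_size, seq_len, n_heads, head_dim).transpose(0, 2, 1, 3)

        q, k, v = map(split_heads, (q, k, v))

        # Scaled dot-product scores; mask out future positions so each
        # token only attends to itself and earlier tokens.
        scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
        future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)

        # Softmax over the last axis.
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)

        out = weights @ v  # (batch_size, n_heads, seq_len, head_dim)

        # Merge heads back into (batch_size, seq_len, embedding_dim).
        return out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, embedding_dim)
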

from_gpu()[source]

Move this operation back to the CPU.

to_gpu(device)[source]

Move this operation to a GPU.

Parameters:

device (int) – An integer representing the GPU device number.

build_mask(context_window, n_heads)[source]

Build an attention mask to prevent attending to future tokens.

This function creates a boolean mask that can be used in multi-head attention mechanisms to implement causal (unidirectional) attention.

Parameters:
  • context_window (int) – An integer representing the size of the context window.

  • n_heads (int) – An integer representing the number of attention heads.

Returns:

A boolean tensor of shape (n_heads, context_window, context_window) representing the attention mask.

Return type:

Tensor
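
The NumPy sketch below illustrates the shape and pattern this function describes: a causal pattern over the context window repeated once per head. The convention that True marks blocked (future) positions is an assumption of this example, as is the use of a plain NumPy array in place of the library's Tensor type.

    import numpy as np

    def build_causal_mask(context_window, n_heads):
        # Illustrative sketch only; True marks future positions that a
        # token must not attend to (convention assumed for this example).
        single = np.triu(np.ones((context_window, context_window), dtype=bool), k=1)
        # Repeat the same pattern for every head ->
        # (n_heads, context_window, context_window).
        return np.tile(single, (n_heads, 1, 1))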