# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Literal
import torch
from torch import nn
from torch.nn.attention.flex_attention import BlockMask
from cellarium.ml.layers.attention import MultiHeadAttention
from cellarium.ml.layers.ffn import PositionWiseFFN
from cellarium.ml.layers.normadd import NormAdd
class TransformerBlock(nn.Module):
"""
Transformer block.
Args:
d_model:
Dimensionality of the embeddings and hidden states.
d_ffn:
Dimensionality of the inner feed-forward layers.
use_bias:
Whether to use bias in the linear transformations and norm-add layers.
n_heads:
Number of attention heads.
dropout_p:
Dropout probability.
attention_logits_scale:
Multiplier for the attention scores.
attention_backend:
Backend for the attention computation.
attention_softmax_fp32:
Whether to use FP32 for the softmax computation when ``torch`` backend is used.
Wqkv_initializer:
Initializer for the query, key, and value linear transformations.
Wo_initializer:
Initializer for the output linear transformation.
dense1_initializer:
Initializer for the first dense layer of the position-wise feed-forward network.
dense2_initializer:
Initializer for the second dense layer of the position-wise feed-forward network.
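Example:
    A minimal usage sketch. The hyperparameter values below are illustrative, and the
    initializer dicts are placeholders whose expected keys are determined by
    ``MultiHeadAttention`` and ``PositionWiseFFN``::

        init = {"name": "trunc_normal_", "std": 0.02}  # placeholder initializer spec
        block = TransformerBlock(
            d_model=64,
            d_ffn=256,
            use_bias=True,
            n_heads=4,
            dropout_p=0.1,
            attention_logits_scale=1.0,
            attention_backend="torch",
            attention_softmax_fp32=True,
            Wqkv_initializer=init,
            Wo_initializer=init,
            dense1_initializer=init,
            dense2_initializer=init,
        )
        hidden_ncd = torch.randn(2, 8, 64)  # (n, c, d)
        mask_ncc = torch.ones(2, 8, 8, dtype=torch.bool)  # all-True mask; semantics per ``MultiHeadAttention``
        out_ncd = block(hidden_ncd, mask_ncc)  # (2, 8, 64)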
"""
def __init__(
self,
d_model: int,
d_ffn: int,
use_bias: bool,
n_heads: int,
dropout_p: float,
attention_logits_scale: float,
attention_backend: Literal["flex", "math", "mem_efficient", "torch"],
attention_softmax_fp32: bool,
Wqkv_initializer: dict[str, Any],
Wo_initializer: dict[str, Any],
dense1_initializer: dict[str, Any],
dense2_initializer: dict[str, Any],
) -> None:
super().__init__()
if d_model % n_heads != 0:
raise ValueError(f"`d_model` ({d_model}) must be divisible by `n_heads` ({n_heads})")
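# Multi-head self-attention sublayer; normadd1 combines its output with the
# residual stream (layer normalization, dropout, residual add).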
self.attention = MultiHeadAttention(
d_model,
use_bias,
n_heads,
dropout_p,
attention_logits_scale,
attention_backend,
attention_softmax_fp32,
Wqkv_initializer,
Wo_initializer,
)
self.normadd1 = NormAdd(d_model, dropout_p, use_bias)
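# Position-wise feed-forward sublayer, wrapped by normadd2 in the same way.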
self.ffn = PositionWiseFFN(
d_ffn,
d_model,
use_bias,
dense1_initializer,
dense2_initializer,
)
self.normadd2 = NormAdd(d_model, dropout_p, use_bias)
def forward(
self,
hidden_state_ncd: torch.Tensor,
attention_mask_ncc: torch.Tensor | BlockMask,
) -> torch.Tensor:
"""
Args:
hidden_state_ncd:
Hidden state tensor of shape ``(n, c, d)``.
attention_mask_ncc:
Attention mask tensor of shape ``(n, c, c)``.
Returns:
The output hidden state tensor of shape ``(n, c, d)``.
"""
hidden_state_ncd = self.normadd1(hidden_state_ncd, lambda X: self.attention(X, X, X, attention_mask_ncc))
return self.normadd2(hidden_state_ncd, lambda Y: self.ffn(Y))  # output shape: (n, c, d)
class Transformer(nn.Module):
"""
Transformer model.
Args:
d_model:
Dimensionality of the embeddings and hidden states.
d_ffn:
Dimensionality of the inner feed-forward layers.
use_bias:
Whether to use bias in the linear transformations and norm-add layers.
n_heads:
Number of attention heads.
n_blocks:
Number of transformer blocks.
dropout_p:
Dropout probability.
attention_logits_scale:
Multiplier for the attention scores.
attention_backend:
Backend for the attention computation.
attention_softmax_fp32:
Whether to use FP32 for the softmax computation when ``torch`` backend is used.
Wqkv_initializer:
Initializer for the query, key, and value linear transformations.
Wo_initializer:
Initializer for the output linear transformation.
dense1_initializer:
Initializer for the first dense layer of the position-wise feed-forward network.
dense2_initializer:
Initializer for the second dense layer of the position-wise feed-forward network.
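Example:
    A minimal usage sketch. The values below are illustrative, and the initializer
    dicts are placeholders forwarded to each ``TransformerBlock``::

        init = {"name": "trunc_normal_", "std": 0.02}  # placeholder initializer spec
        model = Transformer(
            d_model=64,
            d_ffn=256,
            use_bias=True,
            n_heads=4,
            n_blocks=2,
            dropout_p=0.1,
            attention_logits_scale=1.0,
            attention_backend="torch",
            attention_softmax_fp32=True,
            Wqkv_initializer=init,
            Wo_initializer=init,
            dense1_initializer=init,
            dense2_initializer=init,
        )
        hidden_ncd = torch.randn(2, 8, 64)  # (n, c, d)
        mask_ncc = torch.ones(2, 8, 8, dtype=torch.bool)  # all-True mask; semantics per ``MultiHeadAttention``
        out_ncd = model(hidden_ncd, mask_ncc)  # (2, 8, 64)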
"""
def __init__(
self,
d_model: int,
d_ffn: int,
use_bias: bool,
n_heads: int,
n_blocks: int,
dropout_p: float,
attention_logits_scale: float,
attention_backend: Literal["flex", "math", "mem_efficient", "torch"],
attention_softmax_fp32: bool,
Wqkv_initializer: dict[str, Any],
Wo_initializer: dict[str, Any],
dense1_initializer: dict[str, Any],
dense2_initializer: dict[str, Any],
) -> None:
super().__init__()
self.n_blocks = n_blocks
self.blocks = nn.ModuleList(
[
TransformerBlock(
d_model,
d_ffn,
use_bias,
n_heads,
dropout_p,
attention_logits_scale,
attention_backend,
attention_softmax_fp32,
Wqkv_initializer,
Wo_initializer,
dense1_initializer,
dense2_initializer,
)
for _ in range(n_blocks)
]
)
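# Final layer norm applied to the output of the last block (see forward).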
self.ln = nn.LayerNorm(d_model, bias=use_bias)
self._reset_parameters()
def _reset_parameters(self) -> None:
self.ln.reset_parameters()
def forward(
self,
hidden_state_ncd: torch.Tensor,
attention_mask_ncc: torch.Tensor | BlockMask,
) -> torch.Tensor:
"""
Args:
hidden_state_ncd:
Hidden state tensor of shape ``(n, c, d)``.
attention_mask_ncc:
Attention mask tensor of shape ``(n, c, c)``.
Returns:
The output hidden state tensor of shape ``(n, c, d)``.
"""
for block in self.blocks:
hidden_state_ncd = block(hidden_state_ncd, attention_mask_ncc)
return self.ln(hidden_state_ncd)