# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Literal

import torch
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import BlockMask, flex_attention

from cellarium.ml.utilities.layers import create_initializer

try:
    # use_cs returns True if the active device is a CSX device.
    from cerebras.pytorch.backend import use_cs
except ImportError:

    def use_cs() -> bool:
        return False
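
# flex_attention must be compiled to run as a fused kernel; dynamic=False specializes
# the compiled kernel to static input shapes (it recompiles if the shapes change).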
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.

    Args:
        d_model:
            Dimensionality of the embeddings and hidden states.
        use_bias:
            Whether to use bias in the linear transformations.
        n_heads:
            Number of attention heads.
        dropout_p:
            Dropout probability.
        attention_logits_scale:
            Multiplier for the attention logits.
        attention_backend:
            Backend for the attention computation. One of ``"flex"``, ``"math"``,
            ``"mem_efficient"``, or ``"torch"``.
        attention_softmax_fp32:
            Whether to use float32 for the softmax computation when the ``torch`` backend is used.
        Wqkv_initializer:
            Initializer for the query, key, and value linear transformations.
        Wo_initializer:
            Initializer for the output linear transformation.
    """

    backend_map = {
        "math": SDPBackend.MATH,
        "flash": SDPBackend.FLASH_ATTENTION,
        "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
    }
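
    # Maps backend names to torch SDPA kernels selected via sdpa_kernel() in forward();
    # the "flex" and "torch" backends are handled separately and do not use this map.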

    def __init__(
        self,
        d_model: int,
        use_bias: bool,
        n_heads: int,
        dropout_p: float,
        attention_logits_scale: float,
        attention_backend: Literal["flex", "math", "mem_efficient", "torch"],
        attention_softmax_fp32: bool,
        Wqkv_initializer: dict[str, Any],
        Wo_initializer: dict[str, Any],
    ) -> None:
        super().__init__()
        self.Wq = nn.Linear(d_model, d_model, bias=use_bias)
        self.Wk = nn.Linear(d_model, d_model, bias=use_bias)
        self.Wv = nn.Linear(d_model, d_model, bias=use_bias)
        self.Wo = nn.Linear(d_model, d_model, bias=use_bias)
        self.n_heads = n_heads
        self.dropout_p = dropout_p
        self.attention_logits_scale = attention_logits_scale
        self.attention_backend = attention_backend
        self.attention_softmax_fp32 = attention_softmax_fp32
        self.Wqkv_initializer = Wqkv_initializer
        self.Wo_initializer = Wo_initializer
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        for module in [self.Wq, self.Wk, self.Wv]:
            create_initializer(self.Wqkv_initializer)(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        create_initializer(self.Wo_initializer)(self.Wo.weight)
        if self.Wo.bias is not None:
            nn.init.zeros_(self.Wo.bias)
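
    # Note: the initializer dicts are forwarded to create_initializer; assuming it
    # dispatches on a "name" key to a torch.nn.init function, a typical value would be
    # {"name": "trunc_normal_", "std": 0.02}.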

    @staticmethod
    def split_heads(X_nqd: torch.Tensor, n_heads: int) -> torch.Tensor:
        """Reshape ``(n, q, d)`` to ``(n, h, q, k)``, where ``d = h * k``, for parallel computation of multiple attention heads."""
        X_nqhk = X_nqd.reshape(X_nqd.shape[0], X_nqd.shape[1], n_heads, -1)
        X_nhqk = X_nqhk.permute(0, 2, 1, 3)
        return X_nhqk

    @staticmethod
    def merge_heads(X_nhqk: torch.Tensor) -> torch.Tensor:
        """Reverse of ``split_heads``: reshape ``(n, h, q, k)`` back to ``(n, q, d)``."""
        X_nqhk = X_nhqk.permute(0, 2, 1, 3)
        X_nqd = X_nqhk.reshape(X_nqhk.shape[0], X_nqhk.shape[1], -1)
        return X_nqd
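
    # Shape illustration: with d_model = 64 and n_heads = 4, split_heads maps
    # (n, c, 64) -> (n, 4, c, 16) and merge_heads maps it back to (n, c, 64).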

    def forward(
        self,
        x_query_ncd: torch.Tensor,
        x_key_ncd: torch.Tensor,
        x_value_ncd: torch.Tensor,
        attention_mask_ncc: torch.Tensor | BlockMask,
    ) -> torch.Tensor:
        """
        Args:
            x_query_ncd:
                Input query tensor of shape ``(n, c, d)``.
            x_key_ncd:
                Input key tensor of shape ``(n, c, d)``.
            x_value_ncd:
                Input value tensor of shape ``(n, c, d)``.
            attention_mask_ncc:
                Boolean attention mask tensor of shape ``(n, c, c)``, or a ``BlockMask``
                when the ``flex`` backend is used.

        Returns:
            The output hidden state tensor of shape ``(n, c, d)``.
        """
        n_heads = self.n_heads
        query_ncd = self.Wq(x_query_ncd)
        key_ncd = self.Wk(x_key_ncd)
        value_ncd = self.Wv(x_value_ncd)
        # d = h * k: split the model dimension into h heads of size k
        query_nhck = self.split_heads(query_ncd, n_heads)
        key_nhck = self.split_heads(key_ncd, n_heads)
        value_nhck = self.split_heads(value_ncd, n_heads)
        # Per the muP parametrization, attention logits are scaled by 1/d_head
        # (times the configurable attention_logits_scale) instead of the usual 1/sqrt(d_head).
        scale_factor = self.attention_logits_scale / query_nhck.shape[-1]

        if (self.attention_backend == "torch") or use_cs():
            assert isinstance(attention_mask_ncc, torch.Tensor)
            key_nhck = key_nhck * torch.tensor(scale_factor, dtype=key_nhck.dtype)
            attention_logits_nhcc = torch.matmul(query_nhck, key_nhck.transpose(-1, -2))
            neg_inf = torch.tensor(float("-inf"), dtype=torch.float32)
            attention_bias_ncc = torch.where(attention_mask_ncc, 0, neg_inf).to(attention_logits_nhcc.dtype)
            attention_logits_nhcc += attention_bias_ncc.unsqueeze(1).expand(attention_logits_nhcc.shape)
            if self.attention_softmax_fp32 and attention_logits_nhcc.dtype != torch.float32:
                attention_weights_nhcc = nn.functional.softmax(attention_logits_nhcc.float(), dim=-1).to(
                    attention_logits_nhcc.dtype
                )
            else:
                attention_weights_nhcc = nn.functional.softmax(attention_logits_nhcc, dim=-1)
            attention_weights_nhcc = nn.functional.dropout(
                attention_weights_nhcc, self.dropout_p, training=self.training
            )
            output_nhck = torch.matmul(attention_weights_nhcc, value_nhck)
        elif self.attention_backend == "flex":
            assert isinstance(attention_mask_ncc, BlockMask)
            if self.dropout_p > 0.0:
                raise NotImplementedError("Dropout is not yet supported for flex_attention")
            output_nhck = compiled_flex_attention(  # type: ignore[assignment]
                query_nhck,
                key_nhck,
                value_nhck,
                block_mask=attention_mask_ncc,
                scale=scale_factor,
            )
        else:
            assert isinstance(attention_mask_ncc, torch.Tensor)
            with sdpa_kernel(self.backend_map[self.attention_backend]):
                output_nhck = nn.functional.scaled_dot_product_attention(
                    query_nhck,
                    key_nhck,
                    value_nhck,
                    attn_mask=attention_mask_ncc.unsqueeze(1),
                    dropout_p=self.dropout_p if self.training else 0.0,
                    scale=scale_factor,
                )

        output_ncd = self.merge_heads(output_nhck)
        return self.Wo(output_ncd)  # _ncd
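

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): build a small
# MultiHeadAttention and run self-attention with the "torch" backend and a
# boolean mask. The initializer dicts assume that create_initializer accepts a
# dict with a "name" key naming a torch.nn.init function plus its keyword
# arguments; adjust them to whatever your create_initializer expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n, c, d, h = 2, 16, 64, 4  # batch size, context length, model dim, heads

    mha = MultiHeadAttention(
        d_model=d,
        use_bias=True,
        n_heads=h,
        dropout_p=0.0,
        attention_logits_scale=1.0,
        attention_backend="torch",
        attention_softmax_fp32=True,
        Wqkv_initializer={"name": "trunc_normal_", "std": 0.02},  # assumed dict format
        Wo_initializer={"name": "trunc_normal_", "std": 0.02},  # assumed dict format
    )

    x_ncd = torch.randn(n, c, d)
    # Boolean mask of shape (n, c, c); True means "query may attend to key".
    attention_mask_ncc = torch.ones(n, c, c, dtype=torch.bool)

    y_ncd = mha(x_ncd, x_ncd, x_ncd, attention_mask_ncc)
    print(y_ncd.shape)  # torch.Size([2, 16, 64])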