Layers
- class cellarium.ml.layers.GeneExpressionEmbedding(categorical_vocab_sizes: dict[str, int], continuous_vocab_sizes: dict[str, int], d_model: int, embeddings_initializer: dict[str, Any])[source]
Bases: Module
Gene embedding.
- Parameters:
categorical_vocab_sizes (dict[str, int]) – Categorical gene token vocabulary sizes.
continuous_vocab_sizes (dict[str, int]) – Continuous gene token vocabulary sizes.
d_model (int) – Dimensionality of the embeddings and hidden states.
embeddings_initializer (dict[str, Any]) – Initializer for the embeddings.
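The constructor above declares one vocabulary per token type. A common realization of such an interface is one embedding table per categorical token and one linear projection per continuous token, all summed into a d_model-sized vector. The following is a minimal sketch of that pattern in plain PyTorch, not the cellarium.ml implementation: the token names (gene_id, gene_value), the reading of a continuous "vocabulary size" as the feature dimension of a linear projection, and the summation of per-token embeddings are all assumptions.

```python
import torch
import torch.nn as nn


class ToyGeneEmbedding(nn.Module):
    """Illustrative only: one embedding table per categorical token type and one
    linear projection per continuous token type, summed into a d_model vector."""

    def __init__(
        self,
        categorical_vocab_sizes: dict[str, int],
        continuous_vocab_sizes: dict[str, int],
        d_model: int,
    ) -> None:
        super().__init__()
        self.categorical = nn.ModuleDict(
            {key: nn.Embedding(size, d_model) for key, size in categorical_vocab_sizes.items()}
        )
        # Assumption: a continuous token's "vocab size" is its input feature dimension.
        self.continuous = nn.ModuleDict(
            {key: nn.Linear(size, d_model) for key, size in continuous_vocab_sizes.items()}
        )

    def forward(
        self,
        categorical_tokens_nc: dict[str, torch.Tensor],
        continuous_tokens_ncm: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        out_ncd = sum(self.categorical[k](v) for k, v in categorical_tokens_nc.items())
        out_ncd = out_ncd + sum(self.continuous[k](v) for k, v in continuous_tokens_ncm.items())
        return out_ncd


# Hypothetical token names and sizes.
embed = ToyGeneEmbedding({"gene_id": 20_000}, {"gene_value": 1}, d_model=64)
h_ncd = embed({"gene_id": torch.randint(20_000, (2, 8))}, {"gene_value": torch.rand(2, 8, 1)})
print(h_ncd.shape)  # torch.Size([2, 8, 64])
```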
- class cellarium.ml.layers.MetadataEmbedding(categorical_vocab_sizes: dict[str, int], d_model: int, embeddings_initializer: dict[str, Any])[source]
Bases: Module
Metadata embedding.
- Parameters:
categorical_vocab_sizes (dict[str, int]) – Categorical metadata token vocabulary sizes.
d_model (int) – Dimensionality of the embeddings and hidden states.
embeddings_initializer (dict[str, Any]) – Initializer for the embeddings.
- class cellarium.ml.layers.MuLinear(in_features: int, out_features: int, bias: bool, layer: Literal['input', 'hidden', 'output'], optimizer: Literal['sgd', 'adam', 'adamw'], weight_init_std: float = 1.0, bias_init_std: float = 0.0, lr_scale: float = 1.0, base_width: int = 1)[source]
Bases: Module
Linear layer with a maximal update parametrization.
The maximal update parametrization for SGD is defined by:

|         | Input & Biases | Hidden       | Output       |
|---------|----------------|--------------|--------------|
| \(a\)   | \(-0.5\)       | \(0\)        | \(0.5\)      |
| \(b\)   | \(0.5\)        | \(0.5\)      | \(0.5\)      |
| \(c\)   | \(0\)          | \(0\)        | \(0\)        |
| \(d\)   | \(0\)          | \(0\)        | \(0\)        |
| \(n\)   | out_features   | in_features  | in_features  |
The maximal update parametrization for Adam and AdamW is defined by:

|         | Input & Biases | Hidden       | Output       |
|---------|----------------|--------------|--------------|
| \(a\)   | \(0\)          | \(1\)        | \(1\)        |
| \(b\)   | \(0\)          | \(-0.5\)     | \(0\)        |
| \(c\)   | \(0\)          | \(0\)        | \(0\)        |
| \(d\)   | \(1\)          | \(2\)        | \(1\)        |
| \(n\)   | out_features   | in_features  | in_features  |
Since in this implementation \(c\) always equals 0, regular PyTorch optimizers can be used.
References:
Feature Learning in Infinite-Width Neural Networks (Yang et al.).
Tensor Programs IVb: Adaptive Optimization in the ∞-Width Limit (Yang et al.).
- Parameters:
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
bias (bool) – If set to False, the layer will not learn an additive bias.
layer (Literal['input', 'hidden', 'output']) – Layer type.
optimizer (Literal['sgd', 'adam', 'adamw']) – Optimizer type.
weight_init_std (float) – The standard deviation of the weight initialization at base width.
bias_init_std (float) – The standard deviation of the bias initialization at base width.
lr_scale (float) – The learning rate scaling factor for the weight and the bias.
base_width (int) – The base width of the layer.
- property weight: Tensor
The weights of the module, of shape (out_features, in_features). The weight-specific learning rate and the initialization standard deviation are scaled with the width of the layer according to the tables above.
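Because \(c = 0\) in both tables, no per-layer learning-rate rescaling is needed and a stock torch.optim optimizer works as-is; only the widths, the layer role, and the optimizer type must be declared. Below is a minimal usage sketch with arbitrary widths and hyperparameters; it assumes only the constructor signature documented above and that the layer applies an affine map like nn.Linear.

```python
import torch
from cellarium.ml.layers import MuLinear

width, base_width = 256, 64
mlp = torch.nn.Sequential(
    # Layer roles follow the tables above: input, hidden, and output layers
    # get different width-scaling exponents.
    MuLinear(32, width, bias=True, layer="input", optimizer="adamw",
             weight_init_std=0.02, base_width=base_width),
    torch.nn.GELU(),
    MuLinear(width, width, bias=True, layer="hidden", optimizer="adamw",
             weight_init_std=0.02, base_width=base_width),
    torch.nn.GELU(),
    MuLinear(width, 10, bias=True, layer="output", optimizer="adamw",
             weight_init_std=0.02, base_width=base_width),
)
# Since c = 0, a regular optimizer can be used without per-layer learning rates.
optimizer = torch.optim.AdamW(mlp.parameters(), lr=1e-3)
loss = mlp(torch.randn(4, 32)).sum()
loss.backward()
optimizer.step()
```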
- class cellarium.ml.layers.MultiHeadAttention(d_model: int, use_bias: bool, n_heads: int, dropout_p: float, attention_logits_scale: float, attention_backend: Literal['flex', 'math', 'mem_efficient', 'torch'], attention_softmax_fp32: bool, Wqkv_initializer: dict[str, Any], Wo_initializer: dict[str, Any])[source]
Bases: Module
Multi-head attention.
- Parameters:
d_model (int) – Dimensionality of the embeddings and hidden states.
use_bias (bool) – Whether to use bias in the linear transformations.
n_heads (int) – Number of attention heads.
dropout_p (float) – Dropout probability.
attention_logits_scale (float) – Multiplier for the attention scores.
attention_backend (Literal['flex', 'math', 'mem_efficient', 'torch']) – Backend for the attention computation.
attention_softmax_fp32 (bool) – Whether to use float32 for the softmax computation when the torch backend is used.
Wqkv_initializer (dict[str, Any]) – Initializer for the query, key, and value linear transformations.
Wo_initializer (dict[str, Any]) – Initializer for the output linear transformation.
- static split_heads(X_nqd: Tensor, n_heads: int) Tensor [source]
Transposition for parallel computation of multiple attention heads.
- Parameters:
X_nqd (Tensor)
n_heads (int)
- Return type:
Tensor
- static merge_heads(X_nhqk: Tensor) Tensor [source]
Reverse of split_heads.
- Parameters:
X_nhqk (Tensor)
- Return type:
Tensor
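The dimension suffixes in the argument names spell out the shape convention: X_nqd is (batch, query, d_model) and X_nhqk is (batch, heads, query, head_dim). A typical implementation of this pair of transforms looks like the sketch below; it is shown only to make the shape convention concrete and is not necessarily the library's exact code.

```python
import torch


def split_heads(X_nqd: torch.Tensor, n_heads: int) -> torch.Tensor:
    # (n, q, d) -> (n, h, q, d // h)
    n, q, d = X_nqd.shape
    return X_nqd.reshape(n, q, n_heads, d // n_heads).transpose(1, 2)


def merge_heads(X_nhqk: torch.Tensor) -> torch.Tensor:
    # (n, h, q, k) -> (n, q, h * k)
    n, h, q, k = X_nhqk.shape
    return X_nhqk.transpose(1, 2).reshape(n, q, h * k)


x_nqd = torch.randn(2, 5, 16)
assert torch.equal(merge_heads(split_heads(x_nqd, n_heads=4)), x_nqd)  # round trip
```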
- forward(x_query_ncd: Tensor, x_key_ncd: Tensor, x_value_ncd: Tensor, attention_mask_ncc: Tensor | BlockMask) Tensor [source]
- Parameters:
x_query_ncd (Tensor) – Input query tensor of shape (n, c, d).
x_key_ncd (Tensor) – Input key tensor of shape (n, c, d).
x_value_ncd (Tensor) – Input value tensor of shape (n, c, d).
attention_mask_ncc (Tensor | BlockMask) – Attention mask tensor of shape (n, c, c).
- Returns:
The output hidden state tensor of shape (n, c, d).
- Return type:
Tensor
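For self-attention, the same (n, c, d) tensor is passed as query, key, and value, and attention_mask_ncc says which of the c-by-c position pairs may attend to each other. The sketch below reproduces the call pattern and tensor shapes with PyTorch's built-in scaled_dot_product_attention so the shapes are concrete; it bypasses the learned projections, is not the cellarium.ml code path, and the 'True = may attend' mask convention is an assumption.

```python
import torch
import torch.nn.functional as F

n, c, d, n_heads = 2, 6, 32, 4
x_ncd = torch.randn(n, c, d)
# Boolean mask of shape (n, c, c); assumed convention: True = may attend.
attention_mask_ncc = torch.ones(n, c, c, dtype=torch.bool).tril()

# Shape-equivalent self-attention with identity projections (illustration only).
x_nhck = x_ncd.reshape(n, c, n_heads, d // n_heads).transpose(1, 2)
out_nhck = F.scaled_dot_product_attention(
    x_nhck, x_nhck, x_nhck,
    attn_mask=attention_mask_ncc.unsqueeze(1),  # broadcast the mask over heads
)
out_ncd = out_nhck.transpose(1, 2).reshape(n, c, d)
print(out_ncd.shape)  # torch.Size([2, 6, 32])
```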
- class cellarium.ml.layers.MultiHeadReadout(categorical_vocab_sizes: dict[str, int], d_model: int, use_bias: bool, output_logits_scale: float, heads_initializer: dict[str, Any])[source]
Bases: Module
Multi-head readout.
- Parameters:
categorical_vocab_sizes (dict[str, int]) – Categorical token vocabulary sizes.
d_model (int) – Dimensionality of the embeddings and hidden states.
use_bias (bool) – Whether to use bias in the linear transformations.
output_logits_scale (float) – Multiplier for the output logits.
heads_initializer (dict[str, Any]) – Initializer for the output linear transformations.
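Conceptually this amounts to one linear head per categorical vocabulary, mapping the d_model hidden state to that vocabulary's logits and multiplying them by output_logits_scale. A hedged sketch of that pattern follows; the vocabulary names are hypothetical and this is not the library's implementation.

```python
import torch
import torch.nn as nn


class ToyMultiHeadReadout(nn.Module):
    """Illustrative only: one linear readout head per categorical vocabulary."""

    def __init__(self, categorical_vocab_sizes: dict[str, int], d_model: int,
                 use_bias: bool, output_logits_scale: float) -> None:
        super().__init__()
        self.heads = nn.ModuleDict(
            {key: nn.Linear(d_model, size, bias=use_bias)
             for key, size in categorical_vocab_sizes.items()}
        )
        self.output_logits_scale = output_logits_scale

    def forward(self, hidden_state_ncd: torch.Tensor) -> dict[str, torch.Tensor]:
        return {key: self.output_logits_scale * head(hidden_state_ncd)
                for key, head in self.heads.items()}


# Hypothetical vocabularies.
readout = ToyMultiHeadReadout({"gene_value": 512, "cell_type": 100},
                              d_model=64, use_bias=False, output_logits_scale=1.0)
logits = readout(torch.randn(2, 8, 64))
print(logits["gene_value"].shape)  # torch.Size([2, 8, 512])
```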
- class cellarium.ml.layers.NormAdd(norm_shape: int, dropout_p: float, use_bias: bool)[source]
Bases: Module
Pre-norm layer where the layer normalization is applied before the sublayer.
- Parameters:
norm_shape (int) – The shape of the layer normalization.
dropout_p (float) – Dropout probability.
use_bias (bool) – Whether to use bias in the layer normalization.
- forward(hidden_state_ncd: Tensor, sublayer: Callable[[Tensor], Tensor]) Tensor [source]
- Parameters:
hidden_state_ncd (Tensor) – Hidden state tensor of shape (n, c, d).
sublayer (Callable[[Tensor], Tensor]) – Sublayer function.
- Returns:
The output hidden state tensor of shape (n, c, d).
- Return type:
Tensor
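The forward signature (a hidden state plus a sublayer callable) matches the standard pre-norm residual pattern: normalize, apply the sublayer, apply dropout, then add the residual. The sketch below illustrates that pattern in plain PyTorch; it is an illustration of the idea, not the library's code.

```python
import torch
import torch.nn as nn


class ToyNormAdd(nn.Module):
    """Illustrative pre-norm residual wrapper: x + dropout(sublayer(layer_norm(x)))."""

    def __init__(self, norm_shape: int, dropout_p: float, use_bias: bool) -> None:
        super().__init__()
        # The LayerNorm `bias` keyword requires PyTorch >= 2.1.
        self.norm = nn.LayerNorm(norm_shape, bias=use_bias)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, hidden_state_ncd, sublayer):
        return hidden_state_ncd + self.dropout(sublayer(self.norm(hidden_state_ncd)))


block = ToyNormAdd(norm_shape=64, dropout_p=0.1, use_bias=False)
out_ncd = block(torch.randn(2, 8, 64), sublayer=nn.Linear(64, 64))
print(out_ncd.shape)  # torch.Size([2, 8, 64])
```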
- class cellarium.ml.layers.PositionWiseFFN(d_ffn: int, d_model: int, use_bias: bool, dense1_initializer: dict[str, Any], dense2_initializer: dict[str, Any])[source]
Bases: Module
The positionwise feed-forward network.
- Parameters:
d_ffn (int) – Dimensionality of the inner feed-forward layers.
d_model (int) – Dimensionality of the embeddings and hidden states.
use_bias (bool) – Whether to use bias in the linear transformations.
dense1_initializer (dict[str, Any]) – Initializer for the first dense layer.
dense2_initializer (dict[str, Any]) – Initializer for the second dense layer.
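The two initializer arguments point at the usual two-layer structure: expand from d_model to d_ffn, apply a nonlinearity, and project back to d_model, applied independently at every position. A hedged sketch follows; the choice of GELU as the activation is an assumption, since the activation is not documented here.

```python
import torch
import torch.nn as nn


class ToyPositionWiseFFN(nn.Module):
    """Illustrative only: d_model -> d_ffn -> d_model applied at each position."""

    def __init__(self, d_ffn: int, d_model: int, use_bias: bool) -> None:
        super().__init__()
        self.dense1 = nn.Linear(d_model, d_ffn, bias=use_bias)
        self.activation = nn.GELU()  # assumption: the actual activation is not documented here
        self.dense2 = nn.Linear(d_ffn, d_model, bias=use_bias)

    def forward(self, hidden_state_ncd: torch.Tensor) -> torch.Tensor:
        return self.dense2(self.activation(self.dense1(hidden_state_ncd)))


ffn = ToyPositionWiseFFN(d_ffn=256, d_model=64, use_bias=True)
print(ffn(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
```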
- class cellarium.ml.layers.Transformer(d_model: int, d_ffn: int, use_bias: bool, n_heads: int, n_blocks: int, dropout_p: float, attention_logits_scale: float, attention_backend: Literal['flex', 'math', 'mem_efficient', 'torch'], attention_softmax_fp32: bool, Wqkv_initializer: dict[str, Any], Wo_initializer: dict[str, Any], dense1_initializer: dict[str, Any], dense2_initializer: dict[str, Any])[source]
Bases: Module
Transformer model.
- Parameters:
d_model (int) – Dimensionality of the embeddings and hidden states.
d_ffn (int) – Dimensionality of the inner feed-forward layers.
use_bias (bool) – Whether to use bias in the linear transformations and norm-add layers.
n_heads (int) – Number of attention heads.
n_blocks (int) – Number of transformer blocks.
dropout_p (float) – Dropout probability.
attention_logits_scale (float) – Multiplier for the attention scores.
attention_backend (Literal['flex', 'math', 'mem_efficient', 'torch']) – Backend for the attention computation.
attention_softmax_fp32 (bool) – Whether to use float32 for the softmax computation when the torch backend is used.
Wqkv_initializer (dict[str, Any]) – Initializer for the query, key, and value linear transformations.
Wo_initializer (dict[str, Any]) – Initializer for the output linear transformation.
dense1_initializer (dict[str, Any]) – Initializer for the first dense layer.
dense2_initializer (dict[str, Any]) – Initializer for the second dense layer.
- forward(hidden_state_ncd: Tensor, attention_mask_ncc: Tensor | BlockMask) Tensor [source]
- Parameters:
hidden_state_ncd (Tensor) – Hidden state tensor of shape (n, c, d).
attention_mask_ncc (Tensor | BlockMask) – Attention mask tensor of shape (n, c, c).
- Returns:
The output hidden state tensor of shape (n, c, d).
- Return type:
Tensor
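The attention mask is built once per batch and shared by all n_blocks blocks. A common way to produce a dense (n, c, c) boolean mask from per-row token counts is sketched below; the token counts are hypothetical, and the convention that True means 'may attend' is an assumption not stated on this page.

```python
import torch

n, c = 2, 6
lengths_n = torch.tensor([6, 4])  # hypothetical number of valid tokens per row

# Key-padding mask: True where the key position holds a real token.
valid_nc = torch.arange(c) < lengths_n[:, None]  # (n, c)

# Dense attention mask of shape (n, c, c): every query may attend to key j
# only if key j is a valid token (assumed convention: True = may attend).
attention_mask_ncc = valid_nc[:, None, :].expand(n, c, c)
print(attention_mask_ncc.shape)  # torch.Size([2, 6, 6])
```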
- class cellarium.ml.layers.TransformerBlock(d_model: int, d_ffn: int, use_bias: bool, n_heads: int, dropout_p: float, attention_logits_scale: float, attention_backend: Literal['flex', 'math', 'mem_efficient', 'torch'], attention_softmax_fp32: bool, Wqkv_initializer: dict[str, Any], Wo_initializer: dict[str, Any], dense1_initializer: dict[str, Any], dense2_initializer: dict[str, Any])[source]
Bases: Module
Transformer block.
- Parameters:
d_model (int) – Dimensionality of the embeddings and hidden states.
d_ffn (int) – Dimensionality of the inner feed-forward layers.
use_bias (bool) – Whether to use bias in the linear transformations and norm-add layers.
n_heads (int) – Number of attention heads.
dropout_p (float) – Dropout probability.
attention_logits_scale (float) – Multiplier for the attention scores.
attention_backend (Literal['flex', 'math', 'mem_efficient', 'torch']) – Backend for the attention computation.
attention_softmax_fp32 (bool) – Whether to use float32 for the softmax computation when the torch backend is used.
Wqkv_initializer (dict[str, Any]) – Initializer for the query, key, and value linear transformations.
Wo_initializer (dict[str, Any]) – Initializer for the output linear transformation.
dense1_initializer (dict[str, Any]) – Initializer for the first dense layer.
dense2_initializer (dict[str, Any]) – Initializer for the second dense layer.
- forward(hidden_state_ncd: Tensor, attention_mask_ncc: Tensor | BlockMask) Tensor [source]
- Parameters:
hidden_state_ncd (Tensor) – Hidden state tensor of shape (n, c, d).
attention_mask_ncc (Tensor | BlockMask) – Attention mask tensor of shape (n, c, c).
- Returns:
The output hidden state tensor of shape (n, c, d).
- Return type:
Tensor
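Putting the preceding pieces together, a block of this kind is a pre-norm attention sublayer followed by a pre-norm feed-forward sublayer, each wrapped in a NormAdd-style residual. The self-contained sketch below shows that composition using PyTorch's stock nn.MultiheadAttention; it illustrates the architecture only and is not the cellarium.ml forward pass.

```python
import torch
import torch.nn as nn


class ToyTransformerBlock(nn.Module):
    """Illustrative pre-norm block: attention sublayer, then feed-forward sublayer,
    each applied as x + dropout(sublayer(layer_norm(x))). Not the cellarium.ml code."""

    def __init__(self, d_model: int, d_ffn: int, n_heads: int, dropout_p: float) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout_p, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ffn), nn.GELU(), nn.Linear(d_ffn, d_model))
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, hidden_state_ncd: torch.Tensor, attention_mask_ncc: torch.Tensor) -> torch.Tensor:
        x_ncd = self.norm1(hidden_state_ncd)
        # nn.MultiheadAttention expects a per-head boolean mask where True means "do NOT attend",
        # so the (n, c, c) mask (True = may attend) is inverted and repeated over heads.
        attn_mask = (~attention_mask_ncc).repeat_interleave(self.n_heads, dim=0)
        attn_out_ncd, _ = self.attn(x_ncd, x_ncd, x_ncd, attn_mask=attn_mask, need_weights=False)
        hidden_state_ncd = hidden_state_ncd + self.dropout(attn_out_ncd)
        return hidden_state_ncd + self.dropout(self.ffn(self.norm2(hidden_state_ncd)))


block = ToyTransformerBlock(d_model=64, d_ffn=256, n_heads=4, dropout_p=0.0)
attention_mask_ncc = torch.ones(2, 8, 8, dtype=torch.bool)  # all positions may attend
print(block(torch.randn(2, 8, 64), attention_mask_ncc).shape)  # torch.Size([2, 8, 64])
```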