Layers

class cellarium.ml.layers.GeneExpressionEmbedding(categorical_vocab_sizes: dict[str, int], continuous_vocab_sizes: dict[str, int], d_model: int, embeddings_initializer: dict[str, Any])[source]

Bases: Module

Gene embedding.

Parameters:
  • categorical_vocab_sizes (dict[str, int]) – Categorical gene token vocabulary sizes.

  • continuous_vocab_sizes (dict[str, int]) – Continuous gene token vocabulary sizes.

  • d_model (int) – Dimensionality of the embeddings and hidden states.

  • embeddings_initializer (dict[str, Any]) – Initializer for the embeddings.

forward(gene_tokens_nc: dict[str, Tensor]) → Tensor[source]
Parameters:

gene_tokens_nc (dict[str, Tensor]) – Dictionary of gene token tensors of shape (n, c).

Returns:

The gene embedding tensor of shape (n, c, d).

Return type:

Tensor
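
Example: a minimal sketch of the idea in plain PyTorch, not the actual implementation. It embeds each categorical gene token with an embedding table and each continuous token with a per-token linear map, then sums the results into a single (n, c, d) tensor; the token keys "gene_id" and "gene_value" and the summation are illustrative assumptions.

    import torch
    import torch.nn as nn

    # Illustrative sketch only: the real GeneExpressionEmbedding also takes an
    # embeddings_initializer; combining the per-token embeddings by a plain sum
    # is an assumption made for this example.
    class ToyGeneEmbedding(nn.Module):
        def __init__(self, categorical_vocab_sizes: dict[str, int],
                     continuous_vocab_sizes: dict[str, int], d_model: int) -> None:
            super().__init__()
            self.categorical = nn.ModuleDict(
                {key: nn.Embedding(size, d_model) for key, size in categorical_vocab_sizes.items()}
            )
            self.continuous = nn.ModuleDict(
                {key: nn.Linear(size, d_model, bias=False) for key, size in continuous_vocab_sizes.items()}
            )

        def forward(self, gene_tokens_nc: dict[str, torch.Tensor]) -> torch.Tensor:
            parts = [self.categorical[key](gene_tokens_nc[key]) for key in self.categorical]
            parts += [
                self.continuous[key](gene_tokens_nc[key].unsqueeze(-1).float()) for key in self.continuous
            ]
            return torch.stack(parts, dim=0).sum(dim=0)  # (n, c, d)

    emb = ToyGeneEmbedding({"gene_id": 100}, {"gene_value": 1}, d_model=16)
    gene_tokens_nc = {"gene_id": torch.randint(0, 100, (2, 8)), "gene_value": torch.rand(2, 8)}
    print(emb(gene_tokens_nc).shape)  # torch.Size([2, 8, 16])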

class cellarium.ml.layers.MetadataEmbedding(categorical_vocab_sizes: dict[str, int], d_model: int, embeddings_initializer: dict[str, Any])[source]

Bases: Module

Metadata embedding.

Parameters:
  • categorical_vocab_sizes (dict[str, int]) – Categorical metadata token vocabulary sizes.

  • d_model (int) – Dimensionality of the embeddings and hidden states.

  • embeddings_initializer (dict[str, Any]) – Initializer for the embeddings.

forward(metadata_tokens_n: dict[str, Tensor]) → Tensor[source]
Parameters:

metadata_tokens_n (dict[str, Tensor]) – Dictionary of metadata token tensors of shape (n,).

Returns:

The metadata embedding tensor of shape (n, m, d).

Return type:

Tensor
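
Example: a minimal sketch in plain PyTorch (not the actual implementation) showing how m categorical metadata fields, each given as an (n,) token tensor, can be embedded and stacked into an (n, m, d) tensor; the field names "cell_type" and "tissue" are hypothetical.

    import torch
    import torch.nn as nn

    # Illustrative sketch only: one embedding table per metadata field,
    # stacked along a new metadata axis of size m.
    class ToyMetadataEmbedding(nn.Module):
        def __init__(self, categorical_vocab_sizes: dict[str, int], d_model: int) -> None:
            super().__init__()
            self.embeddings = nn.ModuleDict(
                {key: nn.Embedding(size, d_model) for key, size in categorical_vocab_sizes.items()}
            )

        def forward(self, metadata_tokens_n: dict[str, torch.Tensor]) -> torch.Tensor:
            return torch.stack(
                [self.embeddings[key](metadata_tokens_n[key]) for key in self.embeddings], dim=1
            )  # (n, m, d)

    emb = ToyMetadataEmbedding({"cell_type": 50, "tissue": 20}, d_model=16)
    tokens = {"cell_type": torch.randint(0, 50, (4,)), "tissue": torch.randint(0, 20, (4,))}
    print(emb(tokens).shape)  # torch.Size([4, 2, 16])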

class cellarium.ml.layers.MuLinear(in_features: int, out_features: int, bias: bool, layer: Literal['input', 'hidden', 'output'], optimizer: Literal['sgd', 'adam', 'adamw'], weight_init_std: float = 1.0, bias_init_std: float = 0.0, lr_scale: float = 1.0, base_width: int = 1)[source]

Bases: Module

Linear layer with a maximal update parametrization.

The maximal update parametrization for SGD is defined by:

         Input & Biases   Hidden          Output
\(a\)    \(-0.5\)         \(0\)           \(0.5\)
\(b\)    \(0.5\)          \(0.5\)         \(0.5\)
\(c\)    \(0\)            \(0\)           \(0\)
\(d\)    \(0\)            \(0\)           \(0\)
\(n\)    out_features     in_features     in_features

The maximal update parametrization for Adam and AdamW is defined by:

         Input & Biases   Hidden          Output
\(a\)    \(0\)            \(1\)           \(1\)
\(b\)    \(0\)            \(-0.5\)        \(0\)
\(c\)    \(0\)            \(0\)           \(0\)
\(d\)    \(1\)            \(2\)           \(1\)
\(n\)    out_features     in_features     in_features

Since in this implementation \(c\) always equals 0, regular PyTorch optimizers can be used.

References:

  1. Feature Learning in Infinite-Width Neural Networks (Yang et al.).

  2. Tensor Programs IVb: Adaptive Optimization in the ∞-Width Limit (Yang et al.).

Parameters:
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • bias (bool) – If set to False, the layer will not learn an additive bias.

  • layer (Literal['input', 'hidden', 'output']) – Layer type.

  • optimizer (Literal['sgd', 'adam', 'adamw']) – Optimizer type.

  • weight_init_std (float) – The standard deviation of the weight initialization at base width.

  • bias_init_std (float) – The standard deviation of the bias initialization at base width.

  • lr_scale (float) – The learning rate scaling factor for the weight and the bias.

  • base_width (int) – The base width of the layer.

property weight: Tensor

The weights of the module of shape (out_features, in_features). The weight-specific learning rate and the initialization standard deviation are scaled with the width of the layer according to the table above.

property bias: Tensor | None

The bias of the module of shape (out_features). If bias is True, the bias-specific learning rate and the initialization standard deviation are scaled with the width of the layer according to the table above.
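
Example: a usage sketch built directly from the constructor signature above. The widths, activation, and learning rate are arbitrary illustrative choices; because \(c = 0\), a stock PyTorch optimizer is used unchanged.

    import torch
    import torch.nn as nn
    from cellarium.ml.layers import MuLinear

    # A small MLP in maximal update parametrization (hyperparameters are illustrative).
    width, base_width = 512, 128
    mlp = nn.Sequential(
        MuLinear(64, width, bias=True, layer="input", optimizer="adamw", base_width=base_width),
        nn.GELU(),
        MuLinear(width, width, bias=True, layer="hidden", optimizer="adamw", base_width=base_width),
        nn.GELU(),
        MuLinear(width, 10, bias=True, layer="output", optimizer="adamw", base_width=base_width),
    )

    # Since c = 0, the width-dependent learning-rate and initialization scaling
    # lives inside the layer, and a regular optimizer can be used as-is.
    optimizer = torch.optim.AdamW(mlp.parameters(), lr=1e-3)

    x = torch.randn(8, 64)
    loss = mlp(x).square().mean()
    loss.backward()
    optimizer.step()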

class cellarium.ml.layers.MultiHeadAttention(d_model: int, use_bias: bool, n_heads: int, dropout_p: float, attention_logits_scale: float, attention_backend: Literal['flex', 'math', 'mem_efficient', 'torch'], attention_softmax_fp32: bool, Wqkv_initializer: dict[str, Any], Wo_initializer: dict[str, Any])[source]

Bases: Module

Multi-head attention.

Parameters:
  • d_model (int) – Dimensionality of the embeddings and hidden states.

  • use_bias (bool) – Whether to use bias in the linear transformations.

  • n_heads (int) – Number of attention heads.

  • dropout_p (float) – Dropout probability.

  • attention_logits_scale (float) – Multiplier for the attention scores.

  • attention_backend (Literal['flex', 'math', 'mem_efficient', 'torch']) – Backend for the attention computation.

  • attention_softmax_fp32 (bool) – Whether to use float32 for softmax computation when torch backend is used.

  • Wqkv_initializer (dict[str, Any]) – Initializer for the query, key, and value linear transformations.

  • Wo_initializer (dict[str, Any]) – Initializer for the output linear transformation.

static split_heads(X_nqd: Tensor, n_heads: int) → Tensor[source]

Transposition for parallel computation of multiple attention heads.

Parameters:
  • X_nqd (Tensor)

  • n_heads (int)

Return type:

Tensor

static merge_heads(X_nhqk: Tensor) → Tensor[source]

Reverse of split_heads.

Parameters:

X_nhqk (Tensor)

Return type:

Tensor

forward(x_query_ncd: Tensor, x_key_ncd: Tensor, x_value_ncd: Tensor, attention_mask_ncc: Tensor | BlockMask) → Tensor[source]
Parameters:
  • x_query_ncd (Tensor) – Input query tensor of shape (n, c, d).

  • x_key_ncd (Tensor) – Input key tensor of shape (n, c, d).

  • x_value_ncd (Tensor) – Input value tensor of shape (n, c, d).

  • attention_mask_ncc (Tensor | BlockMask) – Attention mask tensor of shape (n, c, c).

Returns:

The output hidden state tensor of shape (n, c, d).

Return type:

Tensor
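
Example: a conceptual sketch of the computation using torch.nn.functional.scaled_dot_product_attention, not the layer's implementation. The learned query/key/value and output projections (Wqkv, Wo) are omitted, and the boolean mask convention (True marks positions that may be attended) follows the PyTorch function; treating the layer's mask the same way is an assumption.

    import torch
    import torch.nn.functional as F

    # Conceptual sketch only: split (n, c, d) into heads, attend with an
    # (n, c, c) boolean mask, and merge the heads back. Learned projections omitted.
    n, c, d, n_heads = 2, 8, 16, 4
    q = k = v = torch.randn(n, c, d)
    attention_mask_ncc = torch.ones(n, c, c, dtype=torch.bool)  # True = may attend

    def split_heads(x_nqd: torch.Tensor, n_heads: int) -> torch.Tensor:
        n, q, d = x_nqd.shape
        return x_nqd.view(n, q, n_heads, d // n_heads).transpose(1, 2)  # (n, h, q, d/h)

    def merge_heads(x_nhqk: torch.Tensor) -> torch.Tensor:
        n, h, q, k = x_nhqk.shape
        return x_nhqk.transpose(1, 2).reshape(n, q, h * k)  # (n, q, d)

    out = F.scaled_dot_product_attention(
        split_heads(q, n_heads),
        split_heads(k, n_heads),
        split_heads(v, n_heads),
        attn_mask=attention_mask_ncc.unsqueeze(1),  # broadcast over heads
    )
    print(merge_heads(out).shape)  # torch.Size([2, 8, 16])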

class cellarium.ml.layers.MultiHeadReadout(categorical_vocab_sizes: dict[str, int], d_model: int, use_bias: bool, output_logits_scale: float, heads_initializer: dict[str, Any])[source]

Bases: Module

Multi-head readout.

Parameters:
  • categorical_vocab_sizes (dict[str, int]) – Categorical token vocabulary sizes.

  • d_model (int) – Dimensionality of the embeddings and hidden states.

  • use_bias (bool) – Whether to use bias in the linear transformations.

  • output_logits_scale (float) – Multiplier for the output logits.

  • heads_initializer (dict[str, Any]) – Initializer for the output linear transformations.

forward(hidden_state_ncd: Tensor) → dict[str, Tensor][source]
Parameters:

hidden_state_ncd (Tensor) – Hidden state tensor of shape (n, c, d).

Returns:

Dictionary of output logits tensors of shape (n, c, vocab_size).

Return type:

dict[str, Tensor]
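
Example: a minimal sketch in plain PyTorch (not the actual implementation). One linear head per categorical token type maps the (n, c, d) hidden state to (n, c, vocab_size) logits, scaled by output_logits_scale; the key "gene_value" is hypothetical.

    import torch
    import torch.nn as nn

    # Illustrative sketch only.
    class ToyMultiHeadReadout(nn.Module):
        def __init__(self, categorical_vocab_sizes: dict[str, int], d_model: int,
                     use_bias: bool, output_logits_scale: float) -> None:
            super().__init__()
            self.heads = nn.ModuleDict(
                {key: nn.Linear(d_model, size, bias=use_bias) for key, size in categorical_vocab_sizes.items()}
            )
            self.output_logits_scale = output_logits_scale

        def forward(self, hidden_state_ncd: torch.Tensor) -> dict[str, torch.Tensor]:
            return {key: self.output_logits_scale * head(hidden_state_ncd) for key, head in self.heads.items()}

    readout = ToyMultiHeadReadout({"gene_value": 512}, d_model=16, use_bias=True, output_logits_scale=1.0)
    logits = readout(torch.randn(2, 8, 16))
    print(logits["gene_value"].shape)  # torch.Size([2, 8, 512])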

class cellarium.ml.layers.NormAdd(norm_shape: int, dropout_p: float, use_bias: bool)[source]

Bases: Module

Pre-norm layer where the layer normalization is applied before the sublayer.

Parameters:
  • norm_shape (int) – The shape of the layer normalization.

  • dropout_p (float) – Dropout probability.

  • use_bias (bool) – Whether to use bias in the layer normalization.

forward(hidden_state_ncd: Tensor, sublayer: Callable[[Tensor], Tensor]) → Tensor[source]
Parameters:
  • hidden_state_ncd (Tensor) – Hidden state tensor of shape (n, c, d).

  • sublayer (Callable[[Tensor], Tensor]) – Sublayer function.

Returns:

The output hidden state tensor of shape (n, c, d).

Return type:

Tensor
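
Example: a usage sketch built from the constructor and forward signatures above. Conceptually the layer computes a pre-norm residual, roughly x + dropout(sublayer(layer_norm(x))); that exact formula is an inference from the description and the name NormAdd, not a quote from the implementation.

    import torch
    import torch.nn as nn
    from cellarium.ml.layers import NormAdd

    d_model = 16
    norm_add = NormAdd(norm_shape=d_model, dropout_p=0.1, use_bias=True)

    sublayer = nn.Linear(d_model, d_model)  # any Callable[[Tensor], Tensor]
    hidden_state_ncd = torch.randn(2, 8, d_model)
    out = norm_add(hidden_state_ncd, sublayer)
    print(out.shape)  # torch.Size([2, 8, 16])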

class cellarium.ml.layers.PositionWiseFFN(d_ffn: int, d_model: int, use_bias: bool, dense1_initializer: dict[str, Any], dense2_initializer: dict[str, Any])[source]

Bases: Module

The positionwise feed-forward network.

Parameters:
  • d_ffn (int) – Dimensionality of the inner feed-forward layers.

  • d_model (int) – Dimensionality of the embeddings and hidden states.

  • use_bias (bool) – Whether to use bias in the linear transformations.

  • dense1_initializer (dict[str, Any]) – Initializer for the first dense layer.

  • dense2_initializer (dict[str, Any]) – Initializer for the second dense layer.

forward(hidden_state_ncd: Tensor) → Tensor[source]
Parameters:

hidden_state_ncd (Tensor) – Hidden state tensor of shape (n, c, d).

Returns:

The output hidden state tensor of shape (n, c, d).

Return type:

Tensor
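
Example: a conceptual sketch in plain PyTorch (not the actual implementation). The same two dense layers are applied independently at every position of the (n, c, d) hidden state; the GELU activation is an assumption.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Illustrative sketch only; the activation between the two dense layers is assumed.
    class ToyPositionWiseFFN(nn.Module):
        def __init__(self, d_ffn: int, d_model: int, use_bias: bool) -> None:
            super().__init__()
            self.dense1 = nn.Linear(d_model, d_ffn, bias=use_bias)
            self.dense2 = nn.Linear(d_ffn, d_model, bias=use_bias)

        def forward(self, hidden_state_ncd: torch.Tensor) -> torch.Tensor:
            return self.dense2(F.gelu(self.dense1(hidden_state_ncd)))

    ffn = ToyPositionWiseFFN(d_ffn=64, d_model=16, use_bias=True)
    print(ffn(torch.randn(2, 8, 16)).shape)  # torch.Size([2, 8, 16])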

class cellarium.ml.layers.Transformer(d_model: int, d_ffn: int, use_bias: bool, n_heads: int, n_blocks: int, dropout_p: float, attention_logits_scale: float, attention_backend: Literal['flex', 'math', 'mem_efficient', 'torch'], attention_softmax_fp32: bool, Wqkv_initializer: dict[str, Any], Wo_initializer: dict[str, Any], dense1_initializer: dict[str, Any], dense2_initializer: dict[str, Any])[source]

Bases: Module

Transformer model.

Parameters:
  • d_model (int) – Dimensionality of the embeddings and hidden states.

  • d_ffn (int) – Dimensionality of the inner feed-forward layers.

  • use_bias (bool) – Whether to use bias in the linear transformations and norm-add layers.

  • n_heads (int) – Number of attention heads.

  • n_blocks (int) – Number of transformer blocks.

  • dropout_p (float) – Dropout probability.

  • attention_logits_scale (float) – Multiplier for the attention scores.

  • attention_backend (Literal['flex', 'math', 'mem_efficient', 'torch']) – Backend for the attention computation.

  • attention_softmax_fp32 (bool) – Whether to use FP32 for the softmax computation when torch backend is used.

  • Wqkv_initializer (dict[str, Any]) – Initializer for the query, key, and value linear transformations.

  • Wo_initializer (dict[str, Any]) – Initializer for the output linear transformation.

  • dense1_initializer (dict[str, Any]) – Initializer for the first dense layer.

  • dense2_initializer (dict[str, Any]) – Initializer for the second dense layer.

forward(hidden_state_ncd: Tensor, attention_mask_ncc: Tensor | BlockMask) → Tensor[source]
Parameters:
  • hidden_state_ncd (Tensor) – Hidden state tensor of shape (n, c, d).

  • attention_mask_ncc (Tensor | BlockMask) – Attention mask tensor of shape (n, c, c).

Returns:

The output hidden state tensor of shape (n, c, d).

Return type:

Tensor
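
Example: a conceptual sketch (not the actual implementation) of how the model composes its blocks: n_blocks transformer blocks are applied in sequence, each receiving the same (n, c, c) attention mask (see the TransformerBlock entry below).

    import torch
    import torch.nn as nn

    # Illustrative sketch only: sequential application of blocks that share the
    # forward(hidden_state_ncd, attention_mask_ncc) signature.
    class ToyTransformer(nn.Module):
        def __init__(self, blocks: list[nn.Module]) -> None:
            super().__init__()
            self.blocks = nn.ModuleList(blocks)

        def forward(self, hidden_state_ncd: torch.Tensor, attention_mask_ncc: torch.Tensor) -> torch.Tensor:
            for block in self.blocks:
                hidden_state_ncd = block(hidden_state_ncd, attention_mask_ncc)
            return hidden_state_ncd

    # Trivial pass-through block, only to make the sketch runnable end to end.
    class _PassThroughBlock(nn.Module):
        def forward(self, hidden_state_ncd, attention_mask_ncc):
            return hidden_state_ncd

    model = ToyTransformer([_PassThroughBlock() for _ in range(4)])
    mask = torch.ones(2, 8, 8, dtype=torch.bool)
    print(model(torch.randn(2, 8, 16), mask).shape)  # torch.Size([2, 8, 16])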

class cellarium.ml.layers.TransformerBlock(d_model: int, d_ffn: int, use_bias: bool, n_heads: int, dropout_p: float, attention_logits_scale: float, attention_backend: Literal['flex', 'math', 'mem_efficient', 'torch'], attention_softmax_fp32: bool, Wqkv_initializer: dict[str, Any], Wo_initializer: dict[str, Any], dense1_initializer: dict[str, Any], dense2_initializer: dict[str, Any])[source]

Bases: Module

Transformer block.

Parameters:
  • d_model (int) – Dimensionality of the embeddings and hidden states.

  • d_ffn (int) – Dimensionality of the inner feed-forward layers.

  • use_bias (bool) – Whether to use bias in the linear transformations and norm-add layers.

  • n_heads (int) – Number of attention heads.

  • dropout_p (float) – Dropout probability.

  • attention_logits_scale (float) – Multiplier for the attention scores.

  • attention_backend (Literal['flex', 'math', 'mem_efficient', 'torch']) – Backend for the attention computation.

  • attention_softmax_fp32 (bool) – Whether to use FP32 for the softmax computation when torch backend is used.

  • Wqkv_initializer (dict[str, Any]) – Initializer for the query, key, and value linear transformations.

  • Wo_initializer (dict[str, Any]) – Initializer for the output linear transformation.

  • dense1_initializer (dict[str, Any]) – Initializer for the first dense layer.

  • dense2_initializer (dict[str, Any]) – Initializer for the second dense layer.

forward(hidden_state_ncd: Tensor, attention_mask_ncc: Tensor | BlockMask) → Tensor[source]
Parameters:
  • hidden_state_ncd (Tensor) – Hidden state tensor of shape (n, c, d).

  • attention_mask_ncc (Tensor | BlockMask) – Attention mask tensor of shape (n, c, c).

Returns:

The output hidden state tensor of shape (n, c, d).

Return type:

Tensor
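
Example: a conceptual sketch (not the actual implementation) of a pre-norm transformer block: a self-attention branch and a position-wise FFN branch, each wrapped in a NormAdd-style residual. It substitutes torch.nn.MultiheadAttention for the library's MultiHeadAttention, assumes a GELU activation, and assumes a boolean (n, c, c) mask where True marks positions that may be attended.

    import torch
    import torch.nn as nn

    # Illustrative sketch only.
    class ToyTransformerBlock(nn.Module):
        def __init__(self, d_model: int, d_ffn: int, n_heads: int, dropout_p: float) -> None:
            super().__init__()
            self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout_p, batch_first=True)
            self.ffn = nn.Sequential(nn.Linear(d_model, d_ffn), nn.GELU(), nn.Linear(d_ffn, d_model))
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout_p)

        def forward(self, hidden_state_ncd: torch.Tensor, attention_mask_ncc: torch.Tensor) -> torch.Tensor:
            # Pre-norm self-attention branch. nn.MultiheadAttention expects True
            # to mark *disallowed* positions, so the mask is inverted and expanded per head.
            h = self.norm1(hidden_state_ncd)
            invalid = ~attention_mask_ncc.repeat_interleave(self.attn.num_heads, dim=0)
            attn_out, _ = self.attn(h, h, h, attn_mask=invalid)
            hidden_state_ncd = hidden_state_ncd + self.dropout(attn_out)
            # Pre-norm position-wise FFN branch.
            h = self.norm2(hidden_state_ncd)
            return hidden_state_ncd + self.dropout(self.ffn(h))

    block = ToyTransformerBlock(d_model=16, d_ffn=64, n_heads=4, dropout_p=0.0)
    attention_mask_ncc = torch.ones(2, 8, 8, dtype=torch.bool)  # True = may attend
    print(block(torch.randn(2, 8, 16), attention_mask_ncc).shape)  # torch.Size([2, 8, 16])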