Source code for cellarium.ml.models.mu_linear

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal

import torch
import torch.nn as nn


[docs] @dataclass class abcdParameter: """ An *abcd*-parametrization describes the scaling of a parameter :math:`W` with width :math:`n`. The parameter is initialized with a standard deviation :math:`\\sigma` and a parameter-specific learning rate scaling factor :math:`\\alpha` at base width :math:`n_0`. The scaling of the parameterization with width :math:`n` is described by a set of numbers :math:`\\{a, b, c, d\\}` such that: a. Parameter is given as :math:`W = \\sqrt{\\alpha} \\cdot (n_0 / n)^a \\cdot w` where :math:`w` is the learnable parameter. b. Learnable parameter is initialized as :math:`w \\sim \\mathcal{N}(0, \\sigma \\cdot (n_0 / n)^b / \\sqrt{\\alpha})`. c. The effective learning rate for :math:`W` is :math:`\\alpha \\cdot (n_0 / n)^{2a} \\cdot (n_0 / n)^c \\cdot \\eta` for some global learning rate :math:`\\eta`. In this implementation, :math:`c` is equal to :math:`0`. d. The gradients of :math:`w` are scaled by :math:`(n / n_0)^d`. Args: data: The tensor data. width: The width :math:`n` of the tensor. a: The :math:`a` parameter. b: The :math:`b` parameter. d: The :math:`d` parameter. init_std: The initialization standard deviation :math:`\\sigma` at base width :math:`n_0`. lr_scale: The learning rate scale factor :math:`\\alpha` at base width :math:`n_0`. base_width: The base width :math:`n_0`. """ data: torch.Tensor | None = None width: int = 1 a: float = 0.0 b: float = 0.0 d: float = 0.0 init_std: float = 1.0 lr_scale: float = 1.0 base_width: int = 1
[docs] class MuLinear(nn.Module): r""" Linear layer with a maximal update parametrization. The maximal update parametrization for SGD is defined by: +-----------+----------------+----------------+----------------+ | | Input & Biases | Hidden | Output | +-----------+----------------+----------------+----------------+ | :math:`a` | :math:`-0.5` | :math:`0` | :math:`0.5` | +-----------+----------------+----------------+----------------+ | :math:`b` | :math:`0.5` | :math:`0.5` | :math:`0.5` | +-----------+----------------+----------------+----------------+ | :math:`c` | :math:`0` | :math:`0` | :math:`0` | +-----------+----------------+----------------+----------------+ | :math:`d` | :math:`0` | :math:`0` | :math:`0` | +-----------+----------------+----------------+----------------+ | :math:`n` | out_features | in_features | in_features | +-----------+----------------+----------------+----------------+ The maximal update parametrization for Adam and AdamW is defined by +-----------+----------------+----------------+----------------+ | | Input & Biases | Hidden | Output | +-----------+----------------+----------------+----------------+ | :math:`a` | :math:`0` | :math:`1` | :math:`1` | +-----------+----------------+----------------+----------------+ | :math:`b` | :math:`0` | :math:`-0.5` | :math:`0` | +-----------+----------------+----------------+----------------+ | :math:`c` | :math:`0` | :math:`0` | :math:`0` | +-----------+----------------+----------------+----------------+ | :math:`d` | :math:`1` | :math:`2` | :math:`1` | +-----------+----------------+----------------+----------------+ | :math:`n` | out_features | in_features | in_features | +-----------+----------------+----------------+----------------+ Since in this implementation :math:`c` always equals 0, regular PyTorch optimizers can be used. **References:** 1. `Feature Learning in Infinite-Width Neural Networks (Yang et al.) <https://arxiv.org/pdf/2011.14522.pdf>`_. 2. `Tensor Programs IVb: Adaptive Optimization in the ∞-Width Limit (Yang et al.) <https://arxiv.org/pdf/2308.01814.pdf>`_. Args: in_features: Size of each input sample. out_features: Size of each output sample. bias: If set to ``False``, the layer will not learn an additive bias. layer: Layer type. optimizer: Optimizer type. weight_init_std: The standard deviation of the weight initialization at base width. bias_init_std: The standard deviation of the bias initialization at base width. lr_scale: The learning rate scaling factor for the weight and the bias. base_width: The base width of the layer. """ def __init__( self, in_features: int, out_features: int, bias: bool, layer: Literal["input", "hidden", "output"], optimizer: Literal["sgd", "adam", "adamw"], weight_init_std: float = 1.0, bias_init_std: float = 0.0, lr_scale: float = 1.0, base_width: int = 1, ) -> None: super().__init__() # inf dim if layer == "output": bias_width = base_width else: bias_width = out_features if layer == "input": weight_width = out_features else: weight_width = in_features # c = 0 for all layers by design # a, b, d if optimizer == "sgd": # bias scaling: Θ(1) bias_a = -0.5 bias_b = 0.5 bias_d = 0.0 # weight if layer == "input": # scaling: Θ(1) weight_a = -0.5 weight_b = 0.5 weight_d = 0.0 elif layer == "hidden": # scaling: Θ(1 / sqrt(n)) weight_a = 0.0 weight_b = 0.5 weight_d = 0.0 elif layer == "output": # scaling: Θ(1 / n) weight_a = 0.5 weight_b = 0.5 weight_d = 0.0 elif optimizer in ["adam", "adamw"]: # bias scaling: Θ(1) bias_a = 0.0 bias_b = 0.0 bias_d = 1.0 # weight if layer == "input": # scaling: Θ(1) weight_a = 0.0 weight_b = 0.0 weight_d = 1.0 elif layer == "hidden": # scaling: Θ(1 / sqrt(n)) weight_a = 1.0 weight_b = -0.5 weight_d = 2.0 elif layer == "output": # scaling: Θ(1 / n) weight_a = 1.0 weight_b = 0.0 weight_d = 1.0 else: raise ValueError(f"Optimizer must be either 'sgd', 'adam', or 'adamw'. Got {optimizer!r}") self.in_features = in_features self.out_features = out_features self.layer = layer self.optimizer = optimizer self.weight_init_std = weight_init_std self.bias_init_std = bias_init_std self.lr_scale = lr_scale self.base_width = base_width self.weight = abcdParameter( # type: ignore[assignment] torch.empty(out_features, in_features), width=weight_width, a=weight_a, b=weight_b, d=weight_d, init_std=weight_init_std, lr_scale=lr_scale, base_width=base_width, ) if bias: self.bias = abcdParameter( # type: ignore[assignment] torch.empty(out_features), width=bias_width, a=bias_a, b=bias_b, d=bias_d, init_std=bias_init_std, lr_scale=lr_scale, base_width=base_width, ) else: self.bias = abcdParameter(None) # type: ignore[assignment] @staticmethod def scale_grad(base_width: int, width: int, d: float) -> Callable[[torch.Tensor], torch.Tensor]: def hook(grad: torch.Tensor) -> torch.Tensor: if d == 0: return grad return grad * (width / base_width) ** d return hook @property def weight(self) -> torch.Tensor: """ The weights of the module of shape ``(out_features, in_features)``. The weight-specific learning rate and the initialization standard deviation are scaled with the width of the layer according to the table above. """ return self.weight_multiplier * self.weight_unscaled @weight.setter def weight(self, value: abcdParameter) -> None: assert value.data is not None self.weight_multiplier = value.lr_scale**0.5 * (value.base_width / value.width) ** value.a self.weight_unscaled = nn.Parameter(value.data) if value.init_std == 0: self.weight_unscaled.data.zero_() else: std = value.init_std * (value.base_width / value.width) ** value.b / value.lr_scale**0.5 self.weight_unscaled.data.normal_(mean=0.0, std=std) self.weight_unscaled.register_hook(self.scale_grad(value.base_width, value.width, value.d)) @property def bias(self) -> torch.Tensor | None: """ The bias of the module of shape ``(out_features)``. If :attr:`bias` is ``True``, the bias-specific learning rate and the initialization standard deviation are scaled with the width of the layer according to the table above. """ if self.bias_unscaled is None: return None return self.bias_multiplier * self.bias_unscaled @bias.setter def bias(self, value: abcdParameter) -> None: self.bias_unscaled: nn.Parameter | None if value.data is None: self.register_parameter("bias_unscaled", None) return self.bias_multiplier = value.lr_scale**0.5 * (value.base_width / value.width) ** value.a self.bias_unscaled = nn.Parameter(value.data) if value.init_std == 0: self.bias_unscaled.data.zero_() else: std = value.init_std * (value.base_width / value.width) ** value.b / value.lr_scale**0.5 self.bias_unscaled.data.normal_(mean=0.0, std=std) self.bias_unscaled.register_hook(self.scale_grad(value.base_width, value.width, value.d)) def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.nn.functional.linear(input, self.weight, self.bias) def extra_repr(self) -> str: return ( f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" f", layer={self.layer}, optimizer={self.optimizer}, weight_init_std={self.weight_init_std}" f", bias_init_std={self.bias_init_std}, lr_scale={self.lr_scale}, base_width={self.base_width}" )