# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import re
from collections import defaultdict
import lightning.pytorch as pl
import torch
from lightning.fabric.utilities.rank_zero import rank_zero_only
class ComputeNorm(pl.Callback):
"""
A callback to compute the model wise and per layer norm of the parameters and gradients.
.. note::
This callback does not support sharded model training.
Args:
layer_name:
The name of the layer to compute the per layer norm.
If ``None``, the callback will compute the model wise norm only.
"""
    def __init__(self, layer_name: str | None = None) -> None:
        self.layer_pattern: re.Pattern[str] | None
        if layer_name is not None:
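            # e.g. layer_name="blocks" matches a parameter named "encoder.blocks.3.attn.weight";
            # group(2) captures the numeric layer index ("3").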
            self.layer_pattern = re.compile(r".*(" + layer_name + r"\.)(\d+)(\.).*")
        else:
            self.layer_pattern = None

    @rank_zero_only
    def on_before_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule, loss: torch.Tensor) -> None:
        """Compute the model-wise norm of the parameters."""
        param_norm = torch.tensor(0.0, device=pl_module.device)
        for _, param in pl_module.named_parameters():
            if param.requires_grad:
                param_norm += torch.pow(torch.norm(param.detach()), 2.0)
        pl_module.log("model_wise_param_norm", torch.sqrt(param_norm).item())

    @rank_zero_only
    def on_before_optimizer_step(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: torch.optim.Optimizer
    ) -> None:
        """Compute the model-wise and per-layer norms of the gradients."""
        model_wise_grad_norm = torch.tensor(0.0, device=pl_module.device)
        per_layer_grad_norm: dict[str, torch.Tensor] = defaultdict(lambda: torch.tensor(0.0, device=pl_module.device))
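        # Sum squared gradient norms over the whole model and, when a layer pattern is set,
        # bucket the same squared norms by the layer index extracted from the parameter name.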
        for name, param in pl_module.named_parameters():
            if param.grad is None:
                continue
            param_grad_norm = torch.pow(torch.norm(param.grad), 2.0)
            model_wise_grad_norm += param_grad_norm
            # get a match if the module name contains `*.layer_name.i.*` where i is the layer number
            if self.layer_pattern:
                match = self.layer_pattern.match(name)
                if match:
                    layer_id = match.group(2)
                    per_layer_grad_norm[layer_id] += param_grad_norm
pl_module.log("model_wise_grad_norm", torch.sqrt(model_wise_grad_norm).item())
        if per_layer_grad_norm:
            pl_module.log_dict(
                {
                    f"per_layer_grad_norm/layer_{layer_id}": torch.sqrt(per_layer_grad_norm[layer_id]).item()
                    for layer_id in per_layer_grad_norm
                }
            )
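

# Minimal usage sketch (illustrative, not part of the library API): attaching this callback to a
# ``lightning.pytorch.Trainer``. ``MyLightningModule``, ``my_datamodule``, and the
# ``layer_name="transformer.blocks"`` argument are hypothetical placeholders, assuming the model
# keeps its repeated layers under a numerically indexed submodule
# (e.g. ``transformer.blocks.0``, ``transformer.blocks.1``, ...).
#
#     trainer = pl.Trainer(
#         max_epochs=1,
#         callbacks=[ComputeNorm(layer_name="transformer.blocks")],
#     )
#     trainer.fit(MyLightningModule(), datamodule=my_datamodule)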