Callbacks

class cellarium.ml.callbacks.ComputeNorm(layer_name: str | None = None)[source]

Bases: Callback

A callback to compute the model-wise and per-layer norms of the parameters and gradients.

Note

This callback does not support sharded model training.

Parameters:

layer_name (str | None) – The name of the layer for which to compute the per-layer norm. If None, the callback computes the model-wise norm only.

on_before_backward(trainer: Trainer, pl_module: LightningModule, loss: Tensor) None[source]

Compute the model-wise norm of the parameters.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

  • loss (Tensor)

Return type:

None

on_before_optimizer_step(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None[source]

Compute the model-wise and per-layer norms of the gradients.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

  • optimizer (Optimizer)

Return type:

None
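
A minimal usage sketch, assuming the standard lightning.pytorch namespace; the layer name "encoder" and the module/datamodule objects are hypothetical placeholders:

import lightning.pytorch as pl

from cellarium.ml.callbacks import ComputeNorm

# Log the model-wise parameter and gradient norms plus the per-layer norm of a
# layer named "encoder" (hypothetical); layer_name=None would log the
# model-wise norms only.
compute_norm = ComputeNorm(layer_name="encoder")

trainer = pl.Trainer(callbacks=[compute_norm])
# trainer.fit(module, datamodule=datamodule)  # placeholders defined elsewhere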

class cellarium.ml.callbacks.GetCoordData(layer_name_to_multiplier_name: dict[str, str] | None = None)[source]

Bases: Callback

A callback that records the L1 norm of the output and parameter values of each layer in the model.

Parameters:

layer_name_to_multiplier_name (dict[str, str] | None) – A dictionary mapping layer names to the names of their corresponding multipliers. If not provided, all layers are assigned a multiplier of 1.0.
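
A hedged construction sketch; the layer and multiplier names below are hypothetical and depend on the model being checked:

from cellarium.ml.callbacks import GetCoordData

# Map layer names to the names of their multipliers (hypothetical names).
# Omitting the mapping assigns every layer a multiplier of 1.0.
coord_data = GetCoordData(
    layer_name_to_multiplier_name={"attention": "attention_multiplier"}
)
# The callback is then passed to the Trainer via callbacks=[coord_data].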

class cellarium.ml.callbacks.LossScaleMonitor[source]

Bases: Callback

A callback that logs the loss scale during mixed-precision training.
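
A minimal sketch, assuming mixed precision is enabled through the Trainer's standard precision argument:

import lightning.pytorch as pl

from cellarium.ml.callbacks import LossScaleMonitor

# The loss scale is only meaningful when a gradient scaler is active,
# e.g. under 16-bit mixed precision.
trainer = pl.Trainer(precision="16-mixed", callbacks=[LossScaleMonitor()])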

class cellarium.ml.callbacks.PredictionWriter(output_dir: Path | str, prediction_size: int | None = None, key: str = 'x_ng', gzip: bool = True, max_threadpool_workers: int = 8)[source]

Bases: BasePredictionWriter

Write predictions to a CSV file. The CSV file has one row per prediction and as many columns as the prediction size, with the first column containing the ID of each cell.

Note

To prevent an out-of-memory error, set the return_predictions argument of the Trainer to False. In the config file, this is accomplished by including return_predictions: false at the top indentation level. For example,

trainer:
  ...
model:
  ...
data:
  ...
return_predictions: false
Parameters:
  • output_dir (Path | str) – The directory to write the predictions to.

  • prediction_size (int | None) – The size of the prediction. If None, the entire prediction will be written. If not None, only the first prediction_size columns will be written.

key (str) – The key in the output of predict() whose value PredictionWriter will write.

  • gzip (bool) – Whether to compress the CSV file using gzip.

  • max_threadpool_workers (int) – The maximum number of threads to use to write the predictions using a ThreadPoolExecutor.

__del__()[source]

Ensure the executor shuts down on object deletion.
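
A hedged sketch of a predict run with the writer attached; return_predictions=False mirrors the config-file note above, and the module and datamodule objects are placeholders:

import lightning.pytorch as pl

from cellarium.ml.callbacks import PredictionWriter

# Write the first 50 columns of the "x_ng" prediction key to gzipped CSV files
# under the "predictions" directory.
writer = PredictionWriter(
    output_dir="predictions",
    prediction_size=50,
    key="x_ng",
    gzip=True,
)

trainer = pl.Trainer(callbacks=[writer])
# trainer.predict(module, datamodule=datamodule, return_predictions=False)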

class cellarium.ml.callbacks.VarianceMonitor(total_variance: float | None = None)[source]

Bases: Callback

Automatically monitors and logs the variance explained by the model during training.

Parameters:

total_variance (float | None) – Total variance of the data. Used to calculate the explained variance ratio.

on_train_start(trainer: Trainer, pl_module: LightningModule) None[source]

Called when training begins.

Raises:
  • AssertionError – If pl_module.model is not a ProbabilisticPCA instance.

  • MisconfigurationException – If Trainer has no logger.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, *args: Any, **kwargs: Any) None[source]

Called when the training batch ends.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

  • args (Any)

  • kwargs (Any)

Return type:

None
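
A hedged sketch: on_train_start requires the LightningModule's model attribute to be a ProbabilisticPCA instance and the Trainer to have a logger, so the module below is schematic and total_variance is a placeholder value assumed to be precomputed from the data:

import lightning.pytorch as pl

from cellarium.ml.callbacks import VarianceMonitor

# total_variance is assumed to be precomputed from the training data
# (placeholder value); it is used to report the explained variance ratio.
variance_monitor = VarianceMonitor(total_variance=123.4)

trainer = pl.Trainer(callbacks=[variance_monitor])
# trainer.fit(ppca_module, datamodule=datamodule)  # ppca_module wraps a ProbabilisticPCA model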