Callbacks

class cellarium.ml.callbacks.LossScaleMonitor

Bases: Callback

A callback that logs the loss scale during mixed-precision training.
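The loss scale is the factor by which the loss is multiplied before backpropagation in float16 mixed precision, so that small gradients do not underflow. It changes during training, which is why it is worth monitoring. The following toy, dependency-free sketch of dynamic loss scaling is modeled on the update rule of PyTorch's GradScaler and uses its default parameters. The class ToyDynamicLossScaler is hypothetical and is not part of cellarium-ml or PyTorch. It only illustrates the kind of value a loss-scale monitor records.

```python
class ToyDynamicLossScaler:
    """Hypothetical toy loss scaler; not part of cellarium-ml or PyTorch."""

    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ) -> None:
        # Defaults follow the documented defaults of PyTorch's GradScaler.
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps with finite gradients

    def update(self, found_inf: bool) -> float:
        """Adjust the scale after one optimizer step and return the new value."""
        if found_inf:
            # Scaled gradients overflowed. Real AMP training would skip this
            # optimizer step; here we only shrink the scale and reset the counter.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                # A long run of finite gradients: try a larger scale.
                self.scale *= self.growth_factor
                self._good_steps = 0
        return self.scale


# A short growth interval makes the behavior visible in a few steps.
scaler = ToyDynamicLossScaler(growth_interval=3)
history = [scaler.update(found_inf) for found_inf in [False, False, False, True, False]]
print(history)  # [65536.0, 65536.0, 131072.0, 65536.0, 65536.0]
```

The scale grows only after a sustained run of overflow-free steps and halves immediately on overflow. A logged loss-scale curve therefore typically looks like a sawtooth.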

class cellarium.ml.callbacks.PredictionWriter(output_dir: Path | str, prediction_size: int | None = None)

Bases: BasePredictionWriter

Write predictions to a CSV file. The file has one row per prediction (one per cell). The first column holds the ID of each cell, and the remaining columns hold the prediction values: all of them by default, or only the first prediction_size columns if prediction_size is set.

Note

To prevent an out-of-memory error, set the return_predictions argument of Trainer.predict() to False, so that the Trainer does not also accumulate every prediction in memory.

Parameters:
  • output_dir (Path | str) – The directory to write the predictions to.

  • prediction_size (int | None) – The size of the prediction. If None, the entire prediction will be written. If not None, only the first prediction_size columns will be written.
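The row and column layout described above can be sketched with the standard csv module. write_predictions_csv is a hypothetical helper, not PredictionWriter's implementation. It makes no assumptions about the file names, compression, or header handling that the real callback uses inside output_dir, since those are not described here.

```python
from __future__ import annotations

import csv
import io
from collections.abc import Sequence
from typing import TextIO


def write_predictions_csv(
    out: TextIO,
    ids: Sequence[str],
    predictions: Sequence[Sequence[float]],
    prediction_size: int | None = None,
) -> None:
    """Hypothetical helper: write one CSV row per cell, ID first, then prediction values."""
    if len(ids) != len(predictions):
        raise ValueError("need exactly one ID per prediction row")
    writer = csv.writer(out, lineterminator="\n")
    for cell_id, row in zip(ids, predictions):
        values = list(row)
        if prediction_size is not None:
            # Keep only the first prediction_size columns, as the parameter describes.
            values = values[:prediction_size]
        writer.writerow([cell_id, *values])


buf = io.StringIO()
write_predictions_csv(buf, ["cell_a", "cell_b"], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], prediction_size=2)
print(buf.getvalue(), end="")
# cell_a,0.1,0.2
# cell_b,0.4,0.5
```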

class cellarium.ml.callbacks.VarianceMonitor(total_variance: float | None = None)

Bases: Callback

Automatically monitors and logs the variance explained by the model during training.

Parameters:
  • total_variance (float | None) – Total variance of the data. Used to calculate the explained variance ratio.
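The total_variance argument is a property of the training data, not of the model. This sketch assumes the common definition of total variance: the sum of per-feature variances, which is the trace of the covariance matrix and the total against which PCA explained-variance ratios are usually measured. It uses population variance (dividing by n). Whether VarianceMonitor expects that or the sample variance is not stated here. The helper name total_variance is hypothetical. For a real dataset, you would compute the same quantity with a vectorized or streaming method.

```python
from collections.abc import Sequence
from statistics import pvariance


def total_variance(x: Sequence[Sequence[float]]) -> float:
    """Hypothetical helper: sum of per-feature (column) variances of an n x g data matrix.

    This equals the trace of the population covariance matrix.
    """
    n_features = len(x[0])
    return sum(pvariance([row[j] for row in x]) for j in range(n_features))


x = [
    [1.0, 0.0],
    [3.0, 2.0],
    [5.0, 4.0],
]
# Each column has population variance 8/3, so the total is 16/3 (about 5.333).
print(total_variance(x))
```

The resulting float is the kind of value you would pass as VarianceMonitor(total_variance=...).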

on_train_start(trainer: Trainer, pl_module: LightningModule) → None

Called when training begins.

Raises:
  • AssertionError – If pl_module.model is not a ProbabilisticPCA instance.

  • MisconfigurationException – If the Trainer has no logger.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None
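The following dependency-free sketch makes the two documented failure modes concrete. Everything in it is a hypothetical stand-in: ProbabilisticPCAStub replaces the real model class, MisconfigurationError replaces Lightning's MisconfigurationException, and a plain list replaces the Trainer's loggers. Because the checks run in on_train_start, a misconfigured run fails once, before the first training batch, rather than on every batch.

```python
class MisconfigurationError(Exception):
    """Hypothetical stand-in for Lightning's MisconfigurationException."""


class ProbabilisticPCAStub:
    """Hypothetical stand-in for cellarium-ml's ProbabilisticPCA model class."""


def check_variance_monitor_preconditions(model: object, loggers: list) -> None:
    """Hypothetical sketch of the two documented on_train_start checks."""
    # Documented: AssertionError if the model is not a ProbabilisticPCA instance.
    assert isinstance(model, ProbabilisticPCAStub), "VarianceMonitor expects a ProbabilisticPCA model"
    # Documented: MisconfigurationException if the Trainer has no logger to write metrics to.
    if not loggers:
        raise MisconfigurationError("VarianceMonitor needs a Trainer with at least one logger")


check_variance_monitor_preconditions(ProbabilisticPCAStub(), loggers=["csv"])  # passes silently
```

Python removes assert statements when run with -O, so an assert-based check like the first one is skipped in that mode.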

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, *args: Any, **kwargs: Any) → None

Called when a training batch ends.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

  • args (Any)

  • kwargs (Any)

Return type:

None
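The exact quantities VarianceMonitor logs after each batch are not specified here. This sketch illustrates one natural meaning of "explained variance" for a probabilistic PCA model, which is an assumption. The model covariance is WWᵀ + σ²I, so the variance captured by the g × k loading matrix W is trace(WWᵀ), the sum of its squared entries. Dividing by the data's total variance gives an explained variance ratio. The helper names and metric keys are hypothetical, and this is not the callback's implementation. As with the total_variance parameter, the ratio is only produced when a total variance is supplied.

```python
from __future__ import annotations

from collections.abc import Sequence


def ppca_explained_variance(w: Sequence[Sequence[float]]) -> float:
    """trace(W @ W.T) for a g x k loading matrix W: the sum of its squared entries."""
    return sum(value * value for row in w for value in row)


def variance_metrics(w: Sequence[Sequence[float]], total_variance: float | None = None) -> dict[str, float]:
    """Hypothetical per-batch metrics; key names are illustrative only."""
    explained = ppca_explained_variance(w)
    metrics = {"explained_variance": explained}
    if total_variance is not None:
        # The ratio is only available when the data's total variance was supplied.
        metrics["explained_variance_ratio"] = explained / total_variance
    return metrics


w = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]  # 3 features, 2 latent dimensions
print(variance_metrics(w, total_variance=14.0))
# {'explained_variance': 7.0, 'explained_variance_ratio': 0.5}
```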