Callbacks
- class cellarium.ml.callbacks.LossScaleMonitor
Bases: Callback
A callback that logs the loss scale during mixed-precision training.
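In fp16 mixed-precision training, the loss is multiplied by a loss scale before the backward pass so that small gradients do not underflow to zero. With dynamic scaling, that factor changes over the course of training, which is why it is worth logging. The sketch below is a simplified, hypothetical DynamicLossScaler (not part of cellarium-ml or PyTorch). It follows the update rule of PyTorch's GradScaler: multiply the scale by backoff_factor when gradients overflow, and by growth_factor after growth_interval consecutive overflow-free steps. It only illustrates the kind of value LossScaleMonitor records and assumes nothing about how the callback reads the scale.

```python
class DynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler (illustration only)."""

    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ) -> None:
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without gradient overflow

    def update(self, found_overflow: bool) -> float:
        """Update the scale after one optimizer step and return the new value."""
        if found_overflow:
            # inf/NaN gradients: shrink the scale and restart the clean-step count.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                # A long enough run of clean steps: try a larger scale.
                self.scale *= self.growth_factor
                self._good_steps = 0
        return self.scale


if __name__ == "__main__":
    scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=3)
    history = [scaler.update(overflow) for overflow in [False, False, False, True, False]]
    print(history)  # [1024.0, 1024.0, 2048.0, 1024.0, 1024.0]
```

With growth_interval=3, three clean steps double the scale, and the overflow on the fourth step halves it again. A logged curve of these values shows how often training hits fp16 overflow.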
- class cellarium.ml.callbacks.PredictionWriter(output_dir: Path | str, prediction_size: int | None = None)
Bases: BasePredictionWriter
Write predictions to a CSV file. The file has one row per prediction (cell). The first column holds the cell ID, and the remaining columns hold the prediction values, whose count equals the prediction size.
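Note
To prevent an out-of-memory error, set the return_predictions argument of Trainer.predict to False. Otherwise, Lightning also accumulates every prediction in memory in addition to writing it to disk.
- Parameters: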
output_dir (Path | str) – The directory to write the predictions to.
prediction_size (int | None) – The number of prediction columns to write. If None, the entire prediction is written; otherwise, only the first prediction_size columns are written.
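The row layout described above can be illustrated with a small, standard-library-only sketch. write_predictions_csv is a hypothetical helper, not part of cellarium-ml, and it writes to any text stream instead of a file in output_dir. It writes one row per cell with the cell ID first, and it truncates each prediction to its first prediction_size values when that argument is set.

```python
from __future__ import annotations

import csv
import io
from typing import Sequence, TextIO


def write_predictions_csv(
    ids: Sequence[str],
    predictions: Sequence[Sequence[float]],
    out: TextIO,
    prediction_size: int | None = None,
) -> None:
    """Hypothetical helper: one CSV row per cell, cell ID first, then prediction values."""
    if len(ids) != len(predictions):
        raise ValueError("ids and predictions must have the same length")
    writer = csv.writer(out)
    for cell_id, row in zip(ids, predictions):
        values = list(row)
        if prediction_size is not None:
            # Keep only the first `prediction_size` columns of the prediction.
            values = values[:prediction_size]
        writer.writerow([cell_id, *values])


if __name__ == "__main__":
    buf = io.StringIO()
    write_predictions_csv(
        ["cell_a", "cell_b"], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], buf, prediction_size=2
    )
    print(buf.getvalue())
```

Streaming rows out as they arrive, instead of collecting them first, is what makes a writer like this compatible with return_predictions=False.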
- class cellarium.ml.callbacks.VarianceMonitor(total_variance: float | None = None)
Bases: Callback
Automatically monitors and logs the variance explained by the model during training.
- Parameters:
total_variance (float | None) – Total variance of the data. Used to calculate the explained variance ratio.
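The explained variance ratio divides the variance the model explains by the total variance of the data. For a cells-by-features matrix, the total variance is the sum of the per-feature variances, which is the trace of its covariance matrix. The following is a minimal standard-library sketch under stated assumptions. The function names are hypothetical, population variance is assumed, and the model's explained variance is passed in as a plain number because this sketch does not model how the callback obtains it.

```python
from statistics import pvariance
from typing import Sequence


def total_variance(data: Sequence[Sequence[float]]) -> float:
    """Sum of per-feature (column) population variances of an n x g matrix."""
    columns = zip(*data)  # transpose rows of cells into feature columns
    return sum(pvariance(col) for col in columns)


def explained_variance_ratio(explained_variance: float, total: float) -> float:
    """Fraction of the total variance captured by the model."""
    if total <= 0:
        raise ValueError("total variance must be positive")
    return explained_variance / total


if __name__ == "__main__":
    data = [[1.0, 2.0], [3.0, 6.0]]  # 2 cells x 2 features
    total = total_variance(data)  # 1.0 + 4.0
    print(total, explained_variance_ratio(4.0, total))  # 5.0 0.8
```

Supplying total_variance up front avoids a separate pass over the dataset just to compute the denominator.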
- on_train_start(trainer: Trainer, pl_module: LightningModule) → None
Called when training begins.
- Raises:
AssertionError – If pl_module.model is not a ProbabilisticPCA instance.
MisconfigurationException – If the Trainer has no logger.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- Return type:
None
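The two documented failure modes of on_train_start can be illustrated with a small, hypothetical precondition check. ProbabilisticPCA, MisconfigurationException, StubModule and StubTrainer below are local stand-ins defined only so that the sketch runs without Lightning or cellarium-ml installed. The function mirrors the documented Raises list and is not the library's source.

```python
from dataclasses import dataclass, field
from typing import Any


class MisconfigurationException(Exception):
    """Local stand-in for Lightning's exception of the same name."""


class ProbabilisticPCA:
    """Local stand-in for cellarium-ml's ProbabilisticPCA model."""


@dataclass
class StubModule:
    model: Any


@dataclass
class StubTrainer:
    loggers: list = field(default_factory=list)


def check_variance_monitor_preconditions(trainer: StubTrainer, pl_module: StubModule) -> None:
    """Hypothetical check mirroring the documented Raises section."""
    # The wrapped model must be a ProbabilisticPCA instance.
    assert isinstance(pl_module.model, ProbabilisticPCA), "VarianceMonitor requires ProbabilisticPCA"
    # Without a logger there is nowhere to send the variance metrics.
    if not trainer.loggers:
        raise MisconfigurationException("VarianceMonitor requires a Trainer with a logger")


if __name__ == "__main__":
    check_variance_monitor_preconditions(StubTrainer(loggers=["csv"]), StubModule(ProbabilisticPCA()))
    try:
        check_variance_monitor_preconditions(StubTrainer(), StubModule(ProbabilisticPCA()))
    except MisconfigurationException as err:
        print("raised:", err)
```

In this sketch, the model check uses an assert statement, which Python skips when run with the -O flag.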