Callbacks
- class cellarium.ml.callbacks.ComputeNorm(layer_name: str | None = None)
Bases: Callback
A callback to compute the model-wise and per-layer L2 norms of the parameters and gradients.
- Parameters:
layer_name (str | None) – The name of the layer for which to compute the per-layer norm. If None, the callback computes only the model-wise norm.
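A minimal sketch of the two quantities, assuming the model-wise norm is the global L2 norm over all parameters (or gradients) flattened together, and that a layer's parameters are those whose names start with `layer_name`. Flat lists of floats stand in for tensors, and `compute_norms` is a hypothetical helper, not part of cellarium.ml.

```python
import math


def l2_norm(values):
    """L2 norm of a flat sequence of floats."""
    return math.sqrt(sum(v * v for v in values))


def compute_norms(named_values, layer_name=None):
    """Return the model-wise L2 norm and, optionally, one layer's L2 norm.

    named_values maps parameter names (e.g. "encoder.weight") to flat lists of
    floats, standing in for parameter or gradient tensors.
    """
    # Model-wise norm: sqrt of the sum of squares over every entry of every parameter.
    model_norm = l2_norm([v for values in named_values.values() for v in values])
    if layer_name is None:
        return model_norm, None
    # Per-layer norm: the same formula restricted to parameters under layer_name.
    layer_values = [
        v
        for name, values in named_values.items()
        if name == layer_name or name.startswith(layer_name + ".")
        for v in values
    ]
    return model_norm, l2_norm(layer_values)


params = {"encoder.weight": [3.0, 4.0], "decoder.weight": [12.0]}
print(compute_norms(params))             # (13.0, None)
print(compute_norms(params, "encoder"))  # (13.0, 5.0)
```

Because the global norm is the square root of the summed squares, it also equals the square root of the sum of the squared per-layer norms, so a per-layer value can never exceed the model-wise one.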
- class cellarium.ml.callbacks.GetCoordData(layer_name_to_multiplier_name: dict[str, str] | None = None)
Bases: Callback
A callback that records the L1 norm of the output and parameter values of each layer in the model.
- Parameters:
layer_name_to_multiplier_name (dict[str, str] | None) – A dictionary mapping layer names to the names of their corresponding multipliers. If not provided, every layer uses a multiplier of 1.0.
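The lookup implied by `layer_name_to_multiplier_name` can be sketched as follows, assuming each multiplier name refers to a numeric attribute on the model and that unmapped layers fall back to 1.0. The helpers `resolve_multipliers` and `layer_l1_norms` and the attribute `output_mult` are hypothetical, and the sketch does not reproduce how the callback applies the multiplier to the recorded values.

```python
from types import SimpleNamespace


def l1_norm(values):
    """L1 norm: the sum of absolute values."""
    return sum(abs(v) for v in values)


def resolve_multipliers(model, layer_names, layer_name_to_multiplier_name=None):
    """Map each layer name to its multiplier value, defaulting to 1.0."""
    mapping = layer_name_to_multiplier_name or {}
    multipliers = {}
    for layer in layer_names:
        multiplier_name = mapping.get(layer)
        # Unmapped layers (or no mapping at all) get a multiplier of 1.0.
        multipliers[layer] = 1.0 if multiplier_name is None else getattr(model, multiplier_name)
    return multipliers


def layer_l1_norms(layer_values):
    """Record the L1 norm of each layer's values (outputs or parameters)."""
    return {layer: l1_norm(values) for layer, values in layer_values.items()}


model = SimpleNamespace(output_mult=0.5)  # stand-in model with one multiplier attribute
print(resolve_multipliers(model, ["hidden", "readout"], {"readout": "output_mult"}))
# {'hidden': 1.0, 'readout': 0.5}
print(layer_l1_norms({"hidden": [1.0, -2.0], "readout": [-0.5]}))
# {'hidden': 3.0, 'readout': 0.5}
```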
- class cellarium.ml.callbacks.LossScaleMonitor
Bases: Callback
A callback that logs the loss scale during mixed-precision training.
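For context, the loss scale is the factor by which a dynamic loss scaler multiplies the loss before backpropagation in mixed-precision training. The sketch below simulates the update rule of PyTorch's `GradScaler` with its default settings (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000); it assumes that scaler is the one in use and is a pure-Python stand-in, not the callback itself.

```python
def simulate_loss_scale(overflow_flags, init_scale=2.0**16, growth_factor=2.0,
                        backoff_factor=0.5, growth_interval=2000):
    """Return the loss scale after each step, following a GradScaler-style rule."""
    scale, good_steps, history = init_scale, 0, []
    for overflow in overflow_flags:
        if overflow:
            # Inf/NaN gradients: shrink the scale and restart the growth counter.
            scale *= backoff_factor
            good_steps = 0
        else:
            good_steps += 1
            if good_steps == growth_interval:
                # A full interval without overflow: grow the scale.
                scale *= growth_factor
                good_steps = 0
        history.append(scale)  # the value a loss-scale monitor would log
    return history


# A small growth_interval makes the growth visible within a few steps.
print(simulate_loss_scale([True, False, False, False], growth_interval=3))
# [32768.0, 32768.0, 32768.0, 65536.0]
```

A logged loss scale that keeps halving usually signals repeated gradient overflows, while a steady value between doublings indicates stable training.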
- class cellarium.ml.callbacks.PredictionWriter(output_dir: Path | str, prediction_size: int | None = None, key: str = 'x_ng', gzip: bool = True, max_threadpool_workers: int = 8)
Bases: BasePredictionWriter
Write predictions to a CSV file. The file has one row per prediction. The first column holds the ID of each cell, followed by the prediction values, with as many columns as the prediction size.
Note
To prevent an out-of-memory error, set the return_predictions argument of Trainer.predict() to False. In the config file, this is done by including return_predictions: false at indent level 0. For example,

trainer:
  ...
model:
  ...
data:
  ...
return_predictions: false
- Parameters:
output_dir (Path | str) – The directory to write the predictions to.
prediction_size (int | None) – The size of the prediction. If None, the entire prediction is written. If not None, only the first prediction_size columns are written.
key (str) – The key in the output of predict() whose value PredictionWriter writes.
gzip (bool) – Whether to compress the CSV file using gzip.
max_threadpool_workers (int) – The maximum number of worker threads in the ThreadPoolExecutor used to write the predictions.
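The file layout described above (cell ID first, then at most prediction_size prediction columns, optionally gzip-compressed, written from a thread pool) can be sketched with the standard library alone. The helpers `write_batch` and `write_batches` and the per-batch file names (`batch_{i}.csv.gz`) are hypothetical, and the `gzip` flag is renamed `use_gzip` here only to avoid shadowing the `gzip` module; the real writer's naming scheme and layout may differ.

```python
import csv
import gzip
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def write_batch(path, ids, rows, prediction_size=None, use_gzip=True):
    """Write one batch: the cell ID first, then the (optionally truncated) prediction."""
    opener = gzip.open if use_gzip else open
    with opener(path, "wt", newline="") as f:
        writer = csv.writer(f)
        for cell_id, row in zip(ids, rows):
            # prediction_size=None keeps the whole row; otherwise keep the first columns.
            values = row if prediction_size is None else row[:prediction_size]
            writer.writerow([cell_id, *values])


def write_batches(output_dir, batches, prediction_size=None, use_gzip=True, max_workers=8):
    """Write each (ids, rows) batch to its own hypothetically named file in a thread pool."""
    os.makedirs(output_dir, exist_ok=True)
    suffix = ".csv.gz" if use_gzip else ".csv"
    paths = [os.path.join(output_dir, f"batch_{i}{suffix}") for i in range(len(batches))]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(write_batch, path, ids, rows, prediction_size, use_gzip)
            for path, (ids, rows) in zip(paths, batches)
        ]
        for future in futures:
            future.result()  # re-raise any exception from a worker thread
    return paths


out_dir = tempfile.mkdtemp()
batches = [(["cell_0", "cell_1"], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])]
paths = write_batches(out_dir, batches, prediction_size=2)
with gzip.open(paths[0], "rt") as f:
    print(f.read().splitlines())  # ['cell_0,0.1,0.2', 'cell_1,0.4,0.5']
```

Writing each batch to its own file lets the threads work independently without sharing a file handle, which is one reason a thread pool suits this kind of I/O-bound work.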
- class cellarium.ml.callbacks.VarianceMonitor(total_variance: float | None = None)
Bases: Callback
Automatically monitors and logs the variance explained by the model during training.
- Parameters:
total_variance (float | None) – Total variance of the data. Used to calculate the explained variance ratio.
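The total variance of a data set is the trace of its covariance matrix, that is, the sum of the per-feature variances, and the explained variance ratio is the variance explained by the model divided by that total. A minimal sketch, assuming population variance (dividing by n) and taking the model's explained variance as a given number; how the callback obtains the explained variance from the model is not shown, and both helper names are hypothetical.

```python
def total_variance(rows):
    """Sum of per-feature population variances (the trace of the covariance matrix)."""
    n = len(rows)
    total = 0.0
    for j in range(len(rows[0])):
        column = [row[j] for row in rows]
        mean = sum(column) / n
        total += sum((x - mean) ** 2 for x in column) / n
    return total


def explained_variance_ratio(explained_variance, total):
    """Fraction of the total variance explained by the model."""
    return explained_variance / total


data = [[1.0, 0.0], [3.0, 0.0], [1.0, 2.0], [3.0, 2.0]]
tv = total_variance(data)
print(tv)                                 # 2.0
print(explained_variance_ratio(1.5, tv))  # 0.75
```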
- on_train_start(trainer: Trainer, pl_module: LightningModule) → None
Called when training begins.
- Raises:
AssertionError – If pl_module.model is not a ProbabilisticPCA instance.
MisconfigurationException – If the Trainer has no logger.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- Return type:
None
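The two documented failure modes of on_train_start can be re-created with stand-ins, which may help when writing tests around this callback. `MisconfigurationException` and `ProbabilisticPCA` below are hypothetical local stand-ins for the real classes, the trainer is mocked as an object exposing a `loggers` list, the error messages are invented, and the order in which the checks run is an assumption.

```python
from types import SimpleNamespace


class MisconfigurationException(Exception):
    """Hypothetical stand-in for Lightning's MisconfigurationException."""


class ProbabilisticPCA:
    """Hypothetical stand-in for cellarium.ml's ProbabilisticPCA model."""


def check_variance_monitor_setup(trainer, pl_module):
    """Re-create the documented failure modes of VarianceMonitor.on_train_start."""
    # Only probabilistic PCA models are supported.
    assert isinstance(pl_module.model, ProbabilisticPCA), "model must be a ProbabilisticPCA instance"
    # Explained variance is logged, so the trainer needs at least one logger.
    if not trainer.loggers:
        raise MisconfigurationException("VarianceMonitor requires a Trainer with a logger")


def raised(exc_type, fn, *args):
    """Return True if fn(*args) raises exc_type, False if it returns normally."""
    try:
        fn(*args)
    except exc_type:
        return True
    return False


ppca_module = SimpleNamespace(model=ProbabilisticPCA())
print(raised(MisconfigurationException, check_variance_monitor_setup,
             SimpleNamespace(loggers=[]), ppca_module))  # True
print(raised(AssertionError, check_variance_monitor_setup,
             SimpleNamespace(loggers=["csv"]), SimpleNamespace(model=object())))  # True
```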