Callbacks
- class cellarium.ml.callbacks.ComputeNorm(layer_name: str | None = None)
Bases: Callback
A callback that computes the model-wise and per-layer norms of the parameters and gradients.
Note
This callback does not support sharded model training.
- Parameters:
layer_name (str | None) – The name of the layer for which to compute the per-layer norm. If None, the callback computes only the model-wise norm.
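To make the two quantities concrete, here is a minimal sketch that does not use the callback itself. It assumes the "model-wise norm" is the global L2 norm over all parameters (or gradients) taken together, the same total norm used for gradient clipping. It also assumes a per-layer norm is that quantity restricted to tensors whose names start with `layer_name` followed by a dot. Tensors are plain lists of floats keyed by hypothetical names, and `compute_norms` is an illustrative helper, not part of cellarium.ml.

```python
from __future__ import annotations

import math


def l2_norm(values: list[float]) -> float:
    """L2 norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def compute_norms(
    tensors: dict[str, list[float]], layer_name: str | None = None
) -> dict[str, float]:
    """Return the model-wise L2 norm and, optionally, one per-layer L2 norm.

    Assumption: the model-wise norm is sqrt(sum of squared per-tensor norms),
    which equals the L2 norm of all values concatenated.
    """
    result = {"model_norm": math.sqrt(sum(l2_norm(t) ** 2 for t in tensors.values()))}
    if layer_name is not None:
        # Hypothetical convention: a layer owns every tensor named "<layer_name>.<...>".
        layer_tensors = [t for name, t in tensors.items() if name.startswith(layer_name + ".")]
        result[f"{layer_name}_norm"] = math.sqrt(sum(l2_norm(t) ** 2 for t in layer_tensors))
    return result


params = {
    "encoder.weight": [3.0, 4.0],  # norm 5
    "encoder.bias": [0.0],         # norm 0
    "decoder.weight": [12.0],      # norm 12
}
print(compute_norms(params, layer_name="encoder"))
# {'model_norm': 13.0, 'encoder_norm': 5.0}
```

Passing the gradients instead of the parameter values gives the gradient norms in the same way.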
- class cellarium.ml.callbacks.GetCoordData(layer_name_to_multiplier_name: dict[str, str] | None = None)
Bases: Callback
A callback that records the L1 norm of the output and parameter values of each layer in the model.
- Parameters:
layer_name_to_multiplier_name (dict[str, str] | None) – A dictionary mapping layer names to the names of their corresponding multipliers. If not provided, all layers have a multiplier of 1.0.
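The following is a hypothetical sketch of the kind of per-layer record this callback produces. The helper `record_l1_norms`, the `multiplier_values` lookup from multiplier names to numbers, and the rule that the multiplier scales a layer's output before the L1 norm (sum of absolute values) is taken are all assumptions for illustration. The same norm would apply to a layer's parameter values. As the parameter description states, layers without a multiplier default to 1.0.

```python
from __future__ import annotations


def l1_norm(values: list[float]) -> float:
    """L1 norm: the sum of absolute values."""
    return sum(abs(v) for v in values)


def record_l1_norms(
    layer_outputs: dict[str, list[float]],
    layer_name_to_multiplier_name: dict[str, str] | None = None,
    multiplier_values: dict[str, float] | None = None,
) -> dict[str, float]:
    """Return {layer name: L1 norm of its (scaled) output}."""
    mapping = layer_name_to_multiplier_name or {}
    values = multiplier_values or {}
    records = {}
    for name, output in layer_outputs.items():
        # Layers without a mapped multiplier fall back to 1.0.
        multiplier = values[mapping[name]] if name in mapping else 1.0
        # Assumption: the multiplier scales the output before the norm is taken.
        records[name] = l1_norm([multiplier * v for v in output])
    return records


outputs = {"embed": [1.0, -2.0], "readout": [0.5, -0.5]}
print(record_l1_norms(outputs))
# {'embed': 3.0, 'readout': 1.0}
print(record_l1_norms(outputs, {"readout": "output_mult"}, {"output_mult": 4.0}))
# {'embed': 3.0, 'readout': 4.0}
```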
- class cellarium.ml.callbacks.LossScaleMonitor
Bases: Callback
A callback that logs the loss scale during mixed-precision training.
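The loss scale such a monitor reports is typically not constant. Under dynamic loss scaling, as in PyTorch's `GradScaler` (defaults: initial scale 2**16, growth factor 2.0, backoff factor 0.5, growth interval 2000 steps), the scale shrinks whenever gradients overflow and grows after a run of clean steps. The pure-Python simulation below shows that update rule. It does not use the callback or PyTorch, and `simulate_loss_scale` is a hypothetical helper.

```python
def simulate_loss_scale(
    overflow_steps: list,
    init_scale: float = 2.0**16,
    growth_factor: float = 2.0,
    backoff_factor: float = 0.5,
    growth_interval: int = 2000,
) -> list:
    """Return the loss scale after each step under dynamic loss scaling."""
    scale = init_scale
    good_steps = 0  # consecutive steps without inf/NaN gradients
    history = []
    for overflowed in overflow_steps:
        if overflowed:
            # Gradients overflowed: the optimizer step is skipped and the scale shrinks.
            scale *= backoff_factor
            good_steps = 0
        else:
            good_steps += 1
            if good_steps == growth_interval:
                # After enough clean steps, try a larger scale.
                scale *= growth_factor
                good_steps = 0
        history.append(scale)
    return history


print(simulate_loss_scale([False, True, False, False], growth_interval=2))
# [65536.0, 32768.0, 32768.0, 65536.0]
```

A monitor that logs this value each step makes overflow bursts visible as drops in the curve.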
- class cellarium.ml.callbacks.PredictionWriter(output_dir: Path | str, prediction_size: int | None = None, key: str = 'x_ng', gzip: bool = True, max_threadpool_workers: int = 8)
Bases: BasePredictionWriter
Write predictions to a CSV file. The file has one row per prediction. The first column holds the ID of each cell, and the number of prediction columns equals the prediction size.
Note
To prevent an out-of-memory error, set the return_predictions argument of Trainer.predict() to False. In the config file, do this by including return_predictions: false at indent level 0. For example,

    trainer:
      ...
    model:
      ...
    data:
      ...
    return_predictions: false
- Parameters:
output_dir (Path | str) – The directory to write the predictions to.
prediction_size (int | None) – The size of the prediction. If None, the entire prediction is written. Otherwise, only the first prediction_size columns are written.
key (str) – The key in the output of predict() that PredictionWriter writes.
gzip (bool) – Whether to compress the CSV file using gzip.
max_threadpool_workers (int) – The maximum number of threads that the ThreadPoolExecutor uses to write the predictions.
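As a rough illustration of the output layout and the threading parameter, the standard-library sketch below writes one row per cell (ID first, then at most `prediction_size` values), optionally gzip-compressed, and writes batches concurrently on a `ThreadPoolExecutor`. It is not the class's implementation. The helper names, the per-batch file naming (`batch_<i>.csv.gz`) and the absence of a header row are all assumptions.

```python
from __future__ import annotations

import csv
import gzip
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def write_predictions_csv(
    output_dir: str,
    file_name: str,
    cell_ids: list[str],
    predictions: list[list[float]],
    prediction_size: int | None = None,
    use_gzip: bool = True,
) -> str:
    """Write one row per cell: the cell ID followed by its prediction values."""
    os.makedirs(output_dir, exist_ok=True)
    suffix = ".csv.gz" if use_gzip else ".csv"
    path = os.path.join(output_dir, file_name + suffix)
    opener = gzip.open if use_gzip else open
    with opener(path, "wt", newline="") as f:
        writer = csv.writer(f)
        for cell_id, row in zip(cell_ids, predictions):
            # Keep only the first prediction_size values when it is set.
            values = row if prediction_size is None else row[:prediction_size]
            writer.writerow([cell_id, *values])
    return path


def write_batches(
    output_dir: str,
    batches: list[tuple[list[str], list[list[float]]]],
    max_threadpool_workers: int = 8,
    **kwargs,
) -> list[str]:
    """Write each (cell_ids, predictions) batch to its own file on a thread pool."""
    with ThreadPoolExecutor(max_workers=max_threadpool_workers) as executor:
        futures = [
            executor.submit(write_predictions_csv, output_dir, f"batch_{i}", ids, preds, **kwargs)
            for i, (ids, preds) in enumerate(batches)
        ]
        # result() re-raises any exception raised inside a worker thread.
        return [future.result() for future in futures]


with tempfile.TemporaryDirectory() as tmp:
    batch = (["cell_a", "cell_b"], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
    paths = write_batches(tmp, [batch], prediction_size=2)
    with gzip.open(paths[0], "rt", newline="") as f:
        print(list(csv.reader(f)))
# [['cell_a', '0.1', '0.2'], ['cell_b', '0.4', '0.5']]
```

Writing on a thread pool lets file I/O and compression overlap with further prediction batches, which is the usual motivation for such a parameter.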
- class cellarium.ml.callbacks.VarianceMonitor(total_variance: float | None = None)
Bases: Callback
Automatically monitors and logs the variance explained by the model during training.
- Parameters:
total_variance (float | None) – Total variance of the data. Used to calculate the explained variance ratio.
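The ratio mentioned above is the explained variance divided by the total variance of the data. The sketch below assumes the explained variance is the sum of per-component variances. It also assumes that, without a `total_variance`, only the explained variance itself is reported. The function name and metric keys are hypothetical and are not the callback's logged names.

```python
from __future__ import annotations


def variance_metrics(
    component_variances: list[float], total_variance: float | None = None
) -> dict[str, float]:
    """Summarize how much of the data's variance the model's components explain."""
    explained = sum(component_variances)
    metrics = {"explained_variance": explained}
    if total_variance is not None:
        # The ratio is only defined relative to a known, positive total variance.
        if total_variance <= 0:
            raise ValueError("total_variance must be positive")
        metrics["explained_variance_ratio"] = explained / total_variance
    return metrics


print(variance_metrics([6.0, 2.0], total_variance=10.0))
# {'explained_variance': 8.0, 'explained_variance_ratio': 0.8}
```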
- on_train_start(trainer: Trainer, pl_module: LightningModule) → None
Called when training begins.
- Raises:
AssertionError – If pl_module.model is not a ProbabilisticPCA instance.
MisconfigurationException – If the Trainer has no logger.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- Return type:
None
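The two documented failure modes of on_train_start amount to a pair of preconditions. The sketch below mirrors them with stand-in classes so that it runs without cellarium.ml or Lightning installed. `ProbabilisticPCA` and `MisconfigurationException` here are local placeholders, and `check_variance_monitor_preconditions` is a hypothetical helper, not the callback's code.

```python
from __future__ import annotations


class ProbabilisticPCA:
    """Local stand-in for cellarium.ml's ProbabilisticPCA model (illustration only)."""


class MisconfigurationException(Exception):
    """Local stand-in for Lightning's MisconfigurationException."""


def check_variance_monitor_preconditions(model: object, logger: object | None) -> None:
    """Raise the same exception types that VarianceMonitor.on_train_start documents."""
    # The callback only knows how to read explained variance from a PPCA model.
    assert isinstance(model, ProbabilisticPCA), "VarianceMonitor requires a ProbabilisticPCA model"
    # Without a logger there is nowhere to send the monitored values.
    if logger is None:
        raise MisconfigurationException("VarianceMonitor requires the Trainer to have a logger")


# Passes silently: correct model type and a (dummy) logger.
check_variance_monitor_preconditions(ProbabilisticPCA(), logger=object())
```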