Core

class cellarium.ml.core.CellariumAnnDataDataModule(dadc: DistributedAnnDataCollection | AnnData, batch_keys: dict[str, AnnDataField] | None = None, batch_size: int = 1, iteration_strategy: Literal['same_order', 'cache_efficient'] = 'cache_efficient', shuffle: bool = False, seed: int = 0, drop_last_indices: bool = False, train_size: float | int | None = None, val_size: float | int | None = None, test_mode: bool = False, num_workers: int = 0, drop_last_batch: bool = False, prefetch_factor: int | None = None, persistent_workers: bool = False)

Bases: LightningDataModule

DataModule for IterableDistributedAnnDataCollectionDataset.

Example:

>>> from cellarium.ml import CellariumAnnDataDataModule
>>> from cellarium.ml.data import DistributedAnnDataCollection
>>> from cellarium.ml.utilities.data import AnnDataField, densify

>>> dm = CellariumAnnDataDataModule(
...     DistributedAnnDataCollection(
...         "gs://bucket-name/folder/adata{000..005}.h5ad",
...         shard_size=10_000,
...         max_cache_size=2,
...     ),
...     batch_keys={
...         "x_ng": AnnDataField(attr="X", convert_fn=densify),
...         "var_names_g": AnnDataField(attr="var_names"),
...     },
...     batch_size=5000,
...     iteration_strategy="cache_efficient",
...     shuffle=True,
...     seed=0,
...     drop_last_indices=True,
...     num_workers=4,
... )
>>> dm.setup(stage="fit")
>>> for batch in dm.train_dataloader():
...     print(batch.keys())  # x_ng, var_names_g
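The batch_keys mapping in the example pairs each output key with a description of where to read its value. The sketch below is a simplified, hypothetical stand-in for AnnDataField (the Field class and collect_batch helper are illustrative names, not part of cellarium.ml). It assumes a field reads an attribute, optionally indexes it by key, and then applies convert_fn; a SimpleNamespace plays the role of an AnnData object.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Callable, Optional


@dataclass
class Field:
    """Hypothetical stand-in for AnnDataField: which attribute (and optional key) to read."""

    attr: str
    key: Optional[str] = None
    convert_fn: Optional[Callable[[Any], Any]] = None

    def __call__(self, adata: Any) -> Any:
        value = getattr(adata, self.attr)  # e.g. adata.X or adata.var_names
        if self.key is not None:
            value = value[self.key]  # e.g. adata.obs["total"]
        if self.convert_fn is not None:
            value = self.convert_fn(value)  # e.g. densify a sparse matrix
        return value


def collect_batch(adata: Any, batch_keys: dict[str, Field]) -> dict[str, Any]:
    # Each output key (e.g. "x_ng") is filled by its own field specification.
    return {name: field(adata) for name, field in batch_keys.items()}


# A toy object standing in for an AnnData shard.
adata = SimpleNamespace(X=[[1, 0], [0, 2]], var_names=["g0", "g1"], obs={"total": [1, 2]})
batch = collect_batch(
    adata,
    {
        "x_ng": Field(attr="X"),
        "var_names_g": Field(attr="var_names"),
        "total_n": Field(attr="obs", key="total", convert_fn=sum),
    },
)
print(sorted(batch))  # ['total_n', 'var_names_g', 'x_ng']
```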
Parameters:
  • dadc (DistributedAnnDataCollection | AnnData) – An instance of DistributedAnnDataCollection or AnnData.

  • batch_keys (dict[str, AnnDataField] | None) – Dictionary that specifies which attributes and keys of the dadc to return in the batch data and how to convert them. Keys must correspond to the input keys of the transforms or the model. Values must be instances of cellarium.ml.utilities.data.AnnDataField.

  • batch_size (int) – How many samples per batch to load.

  • iteration_strategy (Literal['same_order', 'cache_efficient']) – Strategy to use for iterating through the dataset. Options are same_order and cache_efficient. same_order iterates through the dataset in the same order regardless of the number of replicas and workers. cache_efficient tries to minimize the number of AnnData files fetched by each worker.

  • shuffle (bool) – If True, the data is reshuffled at every epoch.

  • seed (int) – Random seed used to shuffle the sampler if shuffle=True.

  • drop_last_indices (bool) – If True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas.

  • train_size (float | int | None) – Size of the train split. If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split. If int, represents the absolute number of train samples. If None, the value is automatically set to the complement of the val_size.

  • val_size (float | int | None) – Size of the validation split. If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the validation split. If int, represents the absolute number of validation samples. If None, the value is set to the complement of the train_size. If train_size is also None, it will be set to 0.

  • test_mode (bool) – If True, enables tracking of cache and worker information.

  • num_workers (int) – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.

  • drop_last_batch (bool) – If True, the dataloader will drop the last incomplete batch if the dataset size is not divisible by the batch size.

  • prefetch_factor (int | None) – Number of batches loaded in advance by each worker. For example, 2 means there will be a total of 2 * num_workers batches prefetched across all workers. The default depends on num_workers: None if num_workers=0, otherwise 2.

  • persistent_workers (bool) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive.
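The drop_last_indices and train_size/val_size rules above can be made concrete with a small sketch. distribute_indices follows the convention of torch.utils.data.DistributedSampler (truncate, or pad by wrapping around to the start, then interleave by rank); whether the cellarium.ml sampler assigns indices in exactly this interleaved order is an assumption. resolve_split_sizes is a hypothetical helper, and rounding float proportions with round() is also an assumption.

```python
import math
from typing import Optional, Union

Size = Optional[Union[float, int]]


def distribute_indices(n: int, num_replicas: int, rank: int, drop_last: bool) -> list[int]:
    """Split range(n) so that every replica receives the same number of indices."""
    if drop_last:
        # Drop the tail so the total is a multiple of num_replicas.
        total = (n // num_replicas) * num_replicas
        indices = list(range(total))
    else:
        # Pad by cycling through indices from the start until the total divides evenly.
        total = math.ceil(n / num_replicas) * num_replicas
        indices = list(range(n))
        indices += [indices[i % n] for i in range(total - n)]
    # Replica `rank` takes every num_replicas-th index, starting at its rank.
    return indices[rank:total:num_replicas]


def resolve_split_sizes(n: int, train_size: Size = None, val_size: Size = None) -> tuple[int, int]:
    """Hypothetical helper: turn train_size/val_size into absolute sample counts."""

    def to_count(size: Size) -> Optional[int]:
        if size is None:
            return None
        if isinstance(size, float):
            if not 0.0 <= size <= 1.0:
                raise ValueError("float sizes must lie in [0.0, 1.0]")
            return round(size * n)  # proportion of the dataset (rounding is an assumption)
        return size  # already an absolute number of samples

    train, val = to_count(train_size), to_count(val_size)
    if train is None and val is None:
        val = 0  # both unspecified: everything goes to training
    if train is None:
        train = n - val  # complement of the validation split
    if val is None:
        val = n - train  # complement of the training split
    if train + val > n:
        raise ValueError("train_size + val_size exceeds the dataset size")
    return train, val
```

For example, with 10 samples and 4 replicas, drop_last=True gives each replica 2 indices (the last 2 samples are dropped), while drop_last=False gives each replica 3 indices (samples 0 and 1 are repeated).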

setup(stage: str | None = None) → None

Note

setup is called from every process across all the nodes. Setting state here is recommended.

Note

val_dataset is not shuffled and uses the same_order iteration strategy.

Parameters:

stage (str | None)

Return type:

None

train_dataloader() → DataLoader

Training dataloader.

Return type:

DataLoader

val_dataloader() → DataLoader

Validation dataloader.

Return type:

DataLoader

predict_dataloader() → DataLoader

Prediction dataloader.

Return type:

DataLoader

class cellarium.ml.core.CellariumModule(cpu_transforms: Iterable[Module] | None = None, transforms: Iterable[Module] | None = None, model: CellariumModel | None = None, optim_fn: type[Optimizer] | None = None, optim_kwargs: dict[str, Any] | None = None, scheduler_fn: type[LRScheduler] | None = None, scheduler_kwargs: dict[str, Any] | None = None, is_initialized: bool = False)

Bases: LightningModule

CellariumModule organizes code into the following sections:

  • cpu_transforms: A list of transforms to apply to the input data as part of the dataloader on CPU.

  • transforms: A list of transforms to apply to the input data before passing it to the model.

  • module_pipeline: A CellariumPipeline that applies the transforms and then the model. The CPU transforms are included only when they are not already applied by a CellariumAnnDataDataModule.

  • model: A CellariumModel to train with the training_step() method and epoch end hooks.

  • optim_fn and optim_kwargs: A PyTorch optimizer class and its keyword arguments.

  • scheduler_fn and scheduler_kwargs: A PyTorch learning rate scheduler class and its keyword arguments.

Parameters:
  • cpu_transforms (Iterable[Module] | None) – A list of transforms to apply to the input data as part of the dataloader on CPU. These transforms get applied before other transforms. If None, no transforms are applied as part of the dataloader.

  • transforms (Iterable[Module] | None) – A list of transforms to apply to the input data before passing it to the model. If None, no transforms are applied.

  • model (CellariumModel | None) – A CellariumModel to train.

  • optim_fn (type[Optimizer] | None) – A PyTorch optimizer class, e.g., Adam. If None, no optimizer is used.

  • optim_kwargs (dict[str, Any] | None) – Keyword arguments for the optimizer.

  • scheduler_fn (type[LRScheduler] | None) – A PyTorch learning rate scheduler class, e.g., CosineAnnealingLR.

  • scheduler_kwargs (dict[str, Any] | None) – Keyword arguments for the learning rate scheduler.

  • is_initialized (bool) – Whether the model has been initialized. This is set to False by default under the assumption that torch.device("meta") context was used and is set to True after the first call to configure_model().

configure_model() → None

Note

This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent, i.e., after the first call, subsequent calls to it should be a no-op.

Steps involved in configuring the model:

  1. Freeze the transforms if they are instances of CellariumModule.

  2. Make a copy of modules on the meta device and assign to hparams.

  3. Send the original modules to the host device and add to pipeline.

  4. Reset the model parameters if it has not been initialized before.

  5. Assemble the full pipeline by concatenating the CPU transforms, transforms, and the model.

  6. If the training is handled by the pl.Trainer and the dataloader is an instance of CellariumAnnDataDataModule, then the CPU transforms are dispatched to the dataloader’s collate_fn and the module_pipeline calls only the (GPU) transforms and the model. Otherwise, the module_pipeline calls the full pipeline.
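Steps 4 to 6 can be illustrated with a plain-Python sketch in which lists stand in for CellariumPipeline and strings stand in for modules. ConfigureSketch and its cpu_transforms_in_dataloader flag are hypothetical; the guard that turns repeated calls into a no-op is one possible way to meet the idempotency requirement in the note above, not necessarily the one cellarium.ml uses.

```python
from typing import Optional


class ConfigureSketch:
    """Hypothetical illustration of idempotent model configuration and pipeline assembly."""

    def __init__(self, cpu_transforms, transforms, model, is_initialized: bool = False) -> None:
        self.cpu_transforms = list(cpu_transforms or [])
        self.transforms = list(transforms or [])
        self.model = model
        self.is_initialized = is_initialized
        self.pipeline: Optional[list] = None
        self.module_pipeline: Optional[list] = None
        self.reset_count = 0

    def reset_parameters(self) -> None:
        self.reset_count += 1  # stands in for initializing the model parameters

    def configure_model(self, cpu_transforms_in_dataloader: bool) -> None:
        if self.pipeline is not None:
            return  # already configured: subsequent calls are a no-op
        if not self.is_initialized:
            self.reset_parameters()  # step 4: initialize only if not done before
            self.is_initialized = True
        # Step 5: the full pipeline is CPU transforms + transforms + model.
        self.pipeline = self.cpu_transforms + self.transforms + [self.model]
        # Step 6: skip the CPU transforms if the dataloader's collate_fn applies them.
        if cpu_transforms_in_dataloader:
            self.module_pipeline = self.transforms + [self.model]
        else:
            self.module_pipeline = self.pipeline


sketch = ConfigureSketch(["cpu_transform"], ["gpu_transform"], "model")
sketch.configure_model(cpu_transforms_in_dataloader=True)
sketch.configure_model(cpu_transforms_in_dataloader=False)  # ignored: already configured
```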

For more context, see the discussion at https://dev-discuss.pytorch.org/t/state-of-model-creation-initialization-seralization-in-pytorch-core/1240

Benefits of this approach:

  1. The checkpoint stores modules on the meta device.

  2. Loading from a checkpoint skips a wasteful step of initializing module parameters before loading the state_dict.

  3. The module parameters are directly initialized on the host GPU device instead of being initialized on the CPU and then moved to the GPU device (given that modules were instantiated under the torch.device("meta") context).
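The meta-device workflow behind these benefits can be reproduced with plain PyTorch (2.0 or later, where torch.device works as a context manager). This sketch uses torch.nn.Linear as a stand-in for a CellariumModel: the module is created on the meta device without allocating storage, then materialized with Module.to_empty() on the host device and initialized in place with reset_parameters().

```python
import torch

# Under the meta device, parameters get shapes and dtypes but no storage.
with torch.device("meta"):
    layer = torch.nn.Linear(4, 3)

print(layer.weight.device)  # meta

# Allocate (uninitialized) storage directly on the target device...
target = "cuda" if torch.cuda.is_available() else "cpu"
layer = layer.to_empty(device=target)
# ...then initialize the parameters in place, exactly once.
layer.reset_parameters()
```

Nothing is allocated or initialized twice: there is no throwaway CPU initialization followed by a copy to the GPU.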

Return type:

None

property model: CellariumModel

The model

property transforms: CellariumPipeline

The transforms pipeline

property cpu_transforms: CellariumPipeline

The CPU transforms pipeline to be applied by the dataloader

property module_pipeline: CellariumPipeline

The pipeline applied by training_step(), validation_step(), and forward()

training_step(batch: dict[str, ndarray | Tensor], batch_idx: int) → Tensor | None

Forward pass for training step.

Parameters:
  • batch (dict[str, ndarray | Tensor]) – A dictionary containing the batch data.

  • batch_idx (int) – The index of the batch.

Returns:

Loss tensor or None if no loss.

Return type:

Tensor | None

forward(batch: dict[str, ndarray | Tensor]) → dict[str, ndarray | Tensor]

Forward pass for inference step.

Parameters:

batch (dict[str, ndarray | Tensor]) – A dictionary containing the batch data.

Returns:

A dictionary containing the batch data and inference outputs.

Return type:

dict[str, ndarray | Tensor]

validation_step(batch: dict[str, Any], batch_idx: int) → None

Forward pass for validation step.

Parameters:
  • batch (dict[str, Any]) – A dictionary containing the batch data.

  • batch_idx (int) – The index of the batch.

Returns:

None

Return type:

None

configure_optimizers() → OptimizerLRSchedulerConfig | None

Configure optimizers for the model.

Return type:

OptimizerLRSchedulerConfig | None
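As a sketch of how optim_fn, optim_kwargs, scheduler_fn, and scheduler_kwargs could be combined into the dictionary format that Lightning accepts from configure_optimizers(), consider the hypothetical helper below. The key names "optimizer", "lr_scheduler", and "scheduler" are Lightning's; the exact contents of the config cellarium.ml returns (for example, scheduler interval settings) are not specified here.

```python
from typing import Any, Iterable, Optional

import torch


def build_optimizer_config(
    params: Iterable[torch.nn.Parameter],
    optim_fn: Optional[type] = None,
    optim_kwargs: Optional[dict[str, Any]] = None,
    scheduler_fn: Optional[type] = None,
    scheduler_kwargs: Optional[dict[str, Any]] = None,
) -> Optional[dict[str, Any]]:
    """Hypothetical helper: build a Lightning-style optimizer (and scheduler) config."""
    if optim_fn is None:
        return None  # no optimizer requested
    optimizer = optim_fn(params, **(optim_kwargs or {}))
    config: dict[str, Any] = {"optimizer": optimizer}
    if scheduler_fn is not None:
        # The scheduler wraps the optimizer whose learning rate it adjusts.
        scheduler = scheduler_fn(optimizer, **(scheduler_kwargs or {}))
        config["lr_scheduler"] = {"scheduler": scheduler}
    return config


model = torch.nn.Linear(4, 2)
config = build_optimizer_config(
    model.parameters(),
    optim_fn=torch.optim.Adam,
    optim_kwargs={"lr": 1e-3},
    scheduler_fn=torch.optim.lr_scheduler.CosineAnnealingLR,
    scheduler_kwargs={"T_max": 100},
)
```

Returning None when no optimizer class is given mirrors the optional return type documented above.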

on_train_epoch_start() → None

Calls the set_epoch method on the iterable dataset of the given dataloader.

If the dataset is an IterableDataset and defines a set_epoch method, then set_epoch must be called at the beginning of every epoch to ensure that shuffling applies a new ordering. This has no effect if shuffling is off.

Return type:

None
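A minimal sketch of the set_epoch pattern, assuming a hypothetical dataset class (EpochShuffledRange) that seeds its shuffle with seed + epoch, as torch.utils.data.DistributedSampler does. Because every replica and worker derives the same permutation from the same seed and epoch, calling set_epoch before each epoch changes the order between epochs while keeping it consistent across processes. The on_train_epoch_start function here is a simplified stand-in for the hook, not its actual implementation.

```python
import random
from typing import Iterator


class EpochShuffledRange:
    """Hypothetical iterable dataset whose order depends only on (seed, epoch)."""

    def __init__(self, n: int, seed: int = 0, shuffle: bool = True) -> None:
        self.n = n
        self.seed = seed
        self.shuffle = shuffle
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        indices = list(range(self.n))
        if self.shuffle:
            # Same seed and epoch give the same permutation in every process.
            random.Random(self.seed + self.epoch).shuffle(indices)
        return iter(indices)


def on_train_epoch_start(dataset: object, current_epoch: int) -> None:
    # Call set_epoch only if the dataset defines it.
    set_epoch = getattr(dataset, "set_epoch", None)
    if callable(set_epoch):
        set_epoch(current_epoch)


dataset = EpochShuffledRange(8, seed=0)
on_train_epoch_start(dataset, current_epoch=1)
first = list(dataset)
second = list(dataset)  # same epoch, so the same order
```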

on_train_start() → None

Calls the on_train_start method on the model attribute. If the model defines an on_train_start method, it must be called at the beginning of training.

Return type:

None

on_train_epoch_end() → None

Calls the on_train_epoch_end method on the model attribute. If the model defines an on_train_epoch_end method, it must be called at the end of every epoch.

Return type:

None

on_train_batch_end(outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int) → None

Calls the on_train_batch_end method on the model attribute, if it is defined.

Parameters:
  • outputs (Tensor | Mapping[str, Any] | None)

  • batch (Any)

  • batch_idx (int)

Return type:

None
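The on_train_start and on_train_epoch_end hooks are documented as calling a method on the model only when the model defines it. The sketch below shows that "call it if it exists" pattern with getattr and callable; call_optional_hook and ModelWithEpochHook are hypothetical names, and the single trainer argument passed to the hook is an assumption about the hook signature.

```python
from typing import Any


def call_optional_hook(obj: Any, name: str, *args: Any) -> bool:
    """Call obj.<name>(*args) if it exists and is callable; report whether it ran."""
    hook = getattr(obj, name, None)
    if callable(hook):
        hook(*args)
        return True
    return False


class ModelWithEpochHook:
    """Hypothetical model that defines only on_train_epoch_end."""

    def __init__(self) -> None:
        self.epochs_seen = 0

    def on_train_epoch_end(self, trainer: Any) -> None:
        self.epochs_seen += 1


model = ModelWithEpochHook()
ran_end = call_optional_hook(model, "on_train_epoch_end", None)  # defined, so it runs
ran_start = call_optional_hook(model, "on_train_start", None)  # not defined, so skipped
```

Models therefore only need to implement the hooks they care about.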

class cellarium.ml.core.CellariumPipeline(modules: Iterable[Module] | None = None)

Bases: ModuleList

A pipeline of modules. Modules are expected to return a dictionary. The input dictionary is sequentially passed to (piped through) each module and updated with its output dictionary.

When used within cellarium.ml.core.CellariumModule, the last module in the pipeline is expected to be a model (cellarium.ml.models.CellariumModel) and any preceding modules are expected to be data transforms.
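The piping behavior can be sketched in plain Python. run_pipeline is a hypothetical helper, and selecting each module's inputs from the batch by parameter name (via inspect.signature) is an assumption about how modules receive their arguments; for a torch.nn.Module one would inspect module.forward rather than the module itself. Plain functions stand in for NormalizeTotal-like transforms.

```python
import inspect
from typing import Any, Callable


def run_pipeline(modules: list[Callable[..., dict[str, Any]]], batch: dict[str, Any]) -> dict[str, Any]:
    batch = dict(batch)  # leave the caller's dictionary untouched
    for module in modules:
        # Assumption: each module receives only the batch entries named in its signature.
        params = inspect.signature(module).parameters
        kwargs = {k: v for k, v in batch.items() if k in params}
        batch.update(module(**kwargs))  # merge the module's outputs into the batch
    return batch


def normalize_total(x_ng, total_mrna_umis_n, target_count=10.0):
    # Rescale each row so that it sums to target_count.
    return {"x_ng": [[v * target_count / t for v in row] for row, t in zip(x_ng, total_mrna_umis_n)]}


def row_sums(x_ng):
    return {"sums_n": [sum(row) for row in x_ng]}


out = run_pipeline(
    [normalize_total, row_sums],
    {"x_ng": [[1.0, 3.0], [2.0, 2.0]], "total_mrna_umis_n": [4.0, 4.0]},
)
```

Entries that no module overwrites, such as total_mrna_umis_n here, pass through to the output unchanged, which matches the update-with-output behavior described above.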

Example

>>> import numpy as np
>>> import torch
>>> from cellarium.ml import CellariumPipeline
>>> from cellarium.ml.transforms import NormalizeTotal, Log1p
>>> from cellarium.ml.models import IncrementalPCA
>>> n, g = 100, 20  # number of cells and genes
>>> var_names_g = np.array([f"gene_{i}" for i in range(g)])
>>> x_ng = torch.poisson(torch.rand(n, g) * 10)  # toy count matrix
>>> total_mrna_umis_n = x_ng.sum(dim=1)
>>> pipeline = CellariumPipeline([
...     NormalizeTotal(),
...     Log1p(),
...     IncrementalPCA(var_names_g=var_names_g, n_components=10),
... ])
>>> batch = {"x_ng": x_ng, "total_mrna_umis_n": total_mrna_umis_n, "var_names_g": var_names_g}
>>> output = pipeline(batch)  # or pipeline.predict(batch)
Parameters:

modules (Iterable[Module] | None) – Modules to be executed sequentially.