Core
- class cellarium.ml.core.CellariumAnnDataDataModule(dadc: DistributedAnnDataCollection | AnnData, batch_keys: dict[str, dict[str, AnnDataField] | AnnDataField] | None = None, batch_size: int = 1, iteration_strategy: Literal['same_order', 'cache_efficient'] = 'cache_efficient', shuffle: bool = False, shuffle_seed: int = 0, drop_last_indices: bool = False, drop_incomplete_batch: bool = False, train_size: float | int | None = None, val_size: float | int | None = None, worker_seed: int | None = None, test_mode: bool = False, num_workers: int = 0, prefetch_factor: int | None = None, persistent_workers: bool = False)[source]
Bases: LightningDataModule
DataModule for IterableDistributedAnnDataCollectionDataset.
Example:
>>> from cellarium.ml import CellariumAnnDataDataModule
>>> from cellarium.ml.data import DistributedAnnDataCollection
>>> from cellarium.ml.utilities.data import AnnDataField, densify
>>> dm = CellariumAnnDataDataModule(
...     DistributedAnnDataCollection(
...         "gs://bucket-name/folder/adata{000..005}.h5ad",
...         shard_size=10_000,
...         max_cache_size=2,
...     ),
...     batch_keys={
...         "x_ng": AnnDataField(attr="X", convert_fn=densify),
...         "var_names_g": AnnDataField(attr="var_names"),
...     },
...     batch_size=5000,
...     iteration_strategy="cache_efficient",
...     shuffle=True,
...     shuffle_seed=0,
...     drop_last_indices=True,
...     num_workers=4,
... )
>>> dm.setup()
>>> for batch in dm.train_dataloader():
...     print(batch.keys())  # x_ng, var_names_g
- Parameters:
  - dadc (DistributedAnnDataCollection | AnnData) – An instance of DistributedAnnDataCollection or AnnData.
  - batch_keys (dict[str, dict[str, AnnDataField] | AnnDataField] | None) – Dictionary that specifies which attributes and keys of the dadc to return in the batch data and how to convert them. Keys must correspond to the input keys of the transforms or the model. Values must be instances of cellarium.ml.utilities.data.AnnDataField.
  - batch_size (int) – How many samples per batch to load.
  - iteration_strategy (Literal['same_order', 'cache_efficient']) – Strategy to use for iterating through the dataset. Options are same_order and cache_efficient. same_order iterates through the dataset in the same order independent of the number of replicas and workers. cache_efficient tries to minimize the number of AnnData files fetched by each worker.
  - shuffle (bool) – If True, the data is reshuffled at every epoch.
  - shuffle_seed (int) – Random seed used to shuffle the sampler if shuffle=True.
  - drop_last_indices (bool) – If True, the sampler drops the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler adds extra indices to make the data evenly divisible across the replicas.
  - drop_incomplete_batch (bool) – If True, the dataloader drops the incomplete batch if the dataset size is not divisible by the batch size.
  - train_size (float | int | None) – Size of the train split. If float, it should be between 0.0 and 1.0 and represents the proportion of the dataset to include in the train split. If int, it represents the absolute number of train samples. If None, the value is automatically set to the complement of val_size.
  - val_size (float | int | None) – Size of the validation split. If float, it should be between 0.0 and 1.0 and represents the proportion of the dataset to include in the validation split. If int, it represents the absolute number of validation samples. If None, the value is set to the complement of train_size. If train_size is also None, it is set to 0. (A split usage sketch follows this parameter list.)
  - worker_seed (int | None) – Random seed used to seed the workers. If None, the workers are not seeded. The seed of each individual worker is computed from the worker_seed, the global worker id, and the epoch. Note that this seed affects cpu_transforms when they are used. When resuming training, the seed should be set to a different value to ensure that the workers are not seeded with the same seed as the previous run.
  - test_mode (bool) – If True, enables tracking of cache and worker information.
  - num_workers (int) – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
  - prefetch_factor (int | None) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. The default depends on num_workers: if num_workers=0, the default is None; otherwise the default is 2.
  - persistent_workers (bool) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive.
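A minimal sketch of the train/validation split parameters, assuming the same imports as in the example above; the bucket path is a placeholder, and the availability of val_dataloader() when val_size > 0 is an assumption:
>>> dm = CellariumAnnDataDataModule(
...     DistributedAnnDataCollection(
...         "gs://bucket-name/folder/adata{000..005}.h5ad",  # placeholder path
...         shard_size=10_000,
...         max_cache_size=2,
...     ),
...     batch_keys={"x_ng": AnnDataField(attr="X", convert_fn=densify)},
...     batch_size=5000,
...     train_size=0.8,  # 80% of cells go to the train split
...     val_size=0.2,    # remaining 20% go to the validation split
...     shuffle=True,
... )
>>> dm.setup()
>>> train_loader = dm.train_dataloader()
>>> val_loader = dm.val_dataloader()  # assumed available when val_size > 0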
- class cellarium.ml.core.CellariumModule(cpu_transforms: Iterable[Module] | None = None, transforms: Iterable[Module] | None = None, model: CellariumModel | None = None, optim_fn: type[Optimizer] | None = None, optim_kwargs: dict[str, Any] | None = None, scheduler_fn: type[LRScheduler] | None = None, scheduler_kwargs: dict[str, Any] | None = None, is_initialized: bool = False)[source]
Bases: LightningModule
CellariumModule organizes code into the following sections:
  - cpu_transforms: A list of transforms to apply to the input data as part of the dataloader on CPU.
  - transforms: A list of transforms to apply to the input data before passing it to the model.
  - module_pipeline: A CellariumPipeline that applies all transforms (minus the CPU transforms if they are handled by a CellariumAnnDataDataModule) and the model.
  - model: A CellariumModel to train with the training_step() method and epoch end hooks.
  - optim_fn and optim_kwargs: A PyTorch optimizer class and its keyword arguments.
  - scheduler_fn and scheduler_kwargs: A PyTorch LR scheduler class and its keyword arguments.
- Parameters:
  - cpu_transforms (Iterable[Module] | None) – A list of transforms to apply to the input data as part of the dataloader on CPU. These transforms get applied before the other transforms. If None, no transforms are applied as part of the dataloader.
  - transforms (Iterable[Module] | None) – A list of transforms to apply to the input data before passing it to the model. If None, no transforms are applied.
  - model (CellariumModel | None) – A CellariumModel to train.
  - optim_fn (type[Optimizer] | None) – A PyTorch optimizer class, e.g., Adam. If None, no optimizer is used.
  - optim_kwargs (dict[str, Any] | None) – Keyword arguments for the optimizer.
  - scheduler_fn (type[LRScheduler] | None) – A PyTorch LR scheduler class, e.g., CosineAnnealingLR.
  - scheduler_kwargs (dict[str, Any] | None) – Keyword arguments for the LR scheduler.
  - is_initialized (bool) – Whether the model has been initialized. This is set to False by default under the assumption that the torch.device("meta") context was used, and is set to True after the first call to configure_model().
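A minimal construction sketch (not taken from the library docs) using the transforms and model shown in the CellariumPipeline example below; the optimizer settings are illustrative assumptions:
>>> import torch
>>> from cellarium.ml import CellariumModule
>>> from cellarium.ml.transforms import NormalizeTotal, Log1p
>>> from cellarium.ml.models import IncrementalPCA
>>> module = CellariumModule(
...     transforms=[NormalizeTotal(), Log1p()],
...     model=IncrementalPCA(var_names_g=[f"gene_{i}" for i in range(20)], n_components=10),
...     optim_fn=torch.optim.Adam,   # optional; omit for models that are not gradient-trained
...     optim_kwargs={"lr": 1e-3},
... )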
- configure_model() → None [source]
Note
This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent, i.e., after the first time the hook is called, subsequent calls to it should be a no-op.
Steps involved in configuring the model:
  - Freeze the transforms if they are instances of CellariumModule.
  - Make a copy of the modules on the meta device and assign it to hparams.
  - Send the original modules to the host device and add them to the pipeline.
  - Reset the model parameters if it has not been initialized before.
  - Assemble the full pipeline by concatenating the CPU transforms, transforms, and the model.
If the training is handled by the pl.Trainer and the dataloader is an instance of CellariumAnnDataDataModule, then the CPU transforms are dispatched to the dataloader's collate_fn and the module_pipeline calls only the (GPU) transforms and the model. Otherwise, the module_pipeline calls the full pipeline.
For more context, see the discussion at https://dev-discuss.pytorch.org/t/state-of-model-creation-initialization-seralization-in-pytorch-core/1240
Benefits of this approach:
  - The checkpoint stores modules on the meta device.
  - Loading from a checkpoint skips the wasteful step of initializing module parameters before loading the state_dict.
  - The module parameters are directly initialized on the host GPU device instead of being initialized on the CPU and then moved to the GPU device (given that the modules were instantiated under the torch.device("meta") context; see the sketch below).
- Return type:
None
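The meta-device mechanics referenced above are generic PyTorch; the following minimal sketch uses a plain nn.Linear and is not Cellarium's actual configure_model code:
>>> import torch
>>> from torch import nn
>>> with torch.device("meta"):
...     layer = nn.Linear(20, 10)  # parameters are created as meta tensors with no storage
>>> layer.weight.is_meta
True
>>> layer = layer.to_empty(device="cpu")  # materialize uninitialized storage on the target device
>>> layer.reset_parameters()  # initialize parameters directly on that device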
- property model: CellariumModel
The model
- property transforms: CellariumPipeline
The transforms pipeline
- property cpu_transforms: CellariumPipeline
The CPU transforms pipeline to be applied by the dataloader
- property module_pipeline: CellariumPipeline
The pipeline applied by training_step(), validation_step(), and forward()
- training_step(batch: dict[str, dict[str, ndarray | Tensor] | ndarray | Tensor], batch_idx: int) → Tensor | None [source]
Forward pass for training step.
- Parameters:
batch (dict[str, dict[str, ndarray | Tensor] | ndarray | Tensor]) – A dictionary containing the batch data.
batch_idx (int) – The index of the batch.
- Returns:
  Loss tensor or None if no loss.
- Return type:
  Tensor | None
- forward(batch: dict[str, dict[str, ndarray | Tensor] | ndarray | Tensor]) → dict[str, dict[str, ndarray | Tensor] | ndarray | Tensor] [source]
Forward pass for inference step.
- Parameters:
batch (dict[str, dict[str, ndarray | Tensor] | ndarray | Tensor]) – A dictionary containing the batch data.
- Returns:
A dictionary containing the batch data and inference outputs.
- Return type:
dict[str, dict[str, ndarray | Tensor] | ndarray | Tensor]
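A hedged usage sketch, assuming module and dm were built as in the sketches above; the output keys depend on the configured batch_keys and the model:
>>> module.configure_model()  # if not already configured by a pl.Trainer
>>> batch = next(iter(dm.train_dataloader()))
>>> output = module(batch)  # the batch dict updated with inference outputs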
- validation_step(batch: dict[str, Any], batch_idx: int) → None [source]
Forward pass for validation step.
- Parameters:
batch (dict[str, Any]) – A dictionary containing the batch data.
batch_idx (int) – The index of the batch.
- Returns:
None
- Return type:
None
- configure_optimizers() → OptimizerLRSchedulerConfig | None [source]
Configure optimizers for the model.
- Return type:
OptimizerLRSchedulerConfig | None
- on_train_epoch_start() → None [source]
Calls the set_epoch method on the iterable dataset of the given dataloader.
If the dataset is an IterableDataset and has a set_epoch method defined, then set_epoch must be called at the beginning of every epoch to ensure that shuffling applies a new ordering. This has no effect if shuffling is off (see the sketch below).
- Return type:
  None
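A minimal sketch of the set_epoch protocol this hook relies on, using a generic IterableDataset rather than the Cellarium dataset:
>>> import torch
>>> from torch.utils.data import IterableDataset
>>> class EpochAwareDataset(IterableDataset):
...     def __init__(self, n: int) -> None:
...         self.n = n
...         self.epoch = 0
...     def set_epoch(self, epoch: int) -> None:
...         self.epoch = epoch  # called by the module at the start of every epoch
...     def __iter__(self):
...         g = torch.Generator().manual_seed(self.epoch)  # epoch-dependent shuffling
...         yield from torch.randperm(self.n, generator=g).tolist()
>>> ds = EpochAwareDataset(5)
>>> ds.set_epoch(1)
>>> order = list(ds)  # the ordering differs from epoch 0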
- on_train_start() → None [source]
Calls the on_train_start method on the model attribute. If the model attribute has an on_train_start method defined, then on_train_start must be called at the beginning of training.
- Return type:
  None
- on_train_epoch_end() → None [source]
Calls the set_resume_step method on the iterable dataset of the given dataloader.
If the dataset is an IterableDataset and has a set_resume_step method defined, then set_resume_step must be called at the end of every epoch to ensure that the dataset is in the correct state for resuming training.
Calls the on_train_epoch_end method on the model attribute. If the model attribute has an on_train_epoch_end method defined, then on_train_epoch_end must be called at the end of every epoch.
- Return type:
  None
- class cellarium.ml.core.CellariumPipeline(modules: Iterable[Module] | None = None)[source]
Bases: ModuleList
A pipeline of modules. Modules are expected to return a dictionary. The input dictionary is sequentially passed to (piped through) each module and updated with its output dictionary.
When used within cellarium.ml.core.CellariumModule, the last module in the pipeline is expected to be a model (cellarium.ml.models.CellariumModel) and any preceding modules are expected to be data transforms.
Example
>>> from cellarium.ml import CellariumPipeline
>>> from cellarium.ml.transforms import NormalizeTotal, Log1p
>>> from cellarium.ml.models import IncrementalPCA
>>> pipeline = CellariumPipeline([
...     NormalizeTotal(),
...     Log1p(),
...     IncrementalPCA(var_names_g=[f"gene_{i}" for i in range(20)], n_components=10),
... ])
>>> batch = {"x_ng": x_ng, "total_mrna_umis_n": total_mrna_umis_n, "var_names_g": var_names_g}
>>> output = pipeline(batch)  # or pipeline.predict(batch)
- Parameters:
modules (Iterable[Module] | None) – Modules to be executed sequentially.
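Conceptually, the piping behavior described above amounts to the following sketch; this is an illustration only, and the real pipeline may route to each module just the dictionary keys it needs:
>>> def pipe(batch: dict, modules: list) -> dict:
...     # each module consumes the accumulated dictionary and its outputs are merged back in
...     for module in modules:
...         batch = {**batch, **module(batch)}
...     return batch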