# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import os
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from queue import Queue
import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
def write_prediction(
prediction: torch.Tensor,
obs_names_n: np.ndarray,
output_dir: Path | str,
postfix: int | str,
fields: Mapping[str, np.ndarray | torch.Tensor] | None = None,
gzip: bool = True,
executor: ThreadPoolExecutor | None = None,
) -> None:
"""
Write prediction to a CSV file.
Args:
prediction:
The prediction to write.
obs_names_n:
The IDs of the cells.
output_dir:
The directory to write the prediction to.
postfix:
A postfix to add to the CSV file name.
fields:
Additional fields to write to the CSV file. The keys of the mapping will be used as
column names, and the values will be written as columns in the CSV file. The values
must have the same number of rows as the prediction. If ``None``, no additional fields will be written.
gzip:
Whether to compress the CSV file using gzip.
executor:
The executor used to write the prediction. If ``None``, no executor will be used.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
df = pd.DataFrame(prediction.cpu())
if fields is not None:
for field_name, field_data in fields.items():
df.insert(0, field_name, field_data)
df.insert(0, "obs_names_n", obs_names_n)
output_path = os.path.join(output_dir, f"batch_{postfix}.csv" + (".gz" if gzip else ""))
to_csv_kwargs: dict[str, str | bool] = {"header": False, "index": False}
if gzip:
to_csv_kwargs |= {"compression": "gzip"}
def _write_csv(frame: pd.DataFrame, path: str) -> None:
frame.to_csv(path, **to_csv_kwargs)
if executor is None:
_write_csv(df, output_path)
else:
executor.submit(_write_csv, df, output_path)
class BoundedThreadPoolExecutor(ThreadPoolExecutor):
"""ThreadPoolExecutor with a bounded queue for task submissions.
This class is used to prevent the queue from growing indefinitely when tasks are submitted,
which can lead to an out-of-memory error.
"""
def __init__(self, max_workers: int, max_queue_size: int):
# Use a bounded queue for task submissions
self._queue: Queue = Queue(max_queue_size)
super().__init__(max_workers=max_workers)
def submit(self, fn, /, *args, **kwargs):
# Block if the queue is full to prevent task overload
self._queue.put(None)
future = super().submit(fn, *args, **kwargs)
# When the task completes, remove a marker from the queue
def done_callback(_):
self._queue.get()
future.add_done_callback(done_callback)
return future
[docs]
class PredictionWriter(pl.callbacks.BasePredictionWriter):
"""
Write predictions to a CSV file. The CSV file will have the same number of rows as the
number of predictions, and the number of columns will be the same as the prediction size.
The first column will be the ID of each cell.
.. note::
To prevent an out-of-memory error, set the ``return_predictions`` argument of the
:class:`~lightning.pytorch.Trainer` to ``False``. This is accomplished in the config
file by including ``return_predictions: false`` at indent level 0. For example,
.. code-block:: yaml
trainer:
...
model:
...
data:
...
return_predictions: false
Args:
output_dir:
The directory to write the predictions to.
prediction_size:
The size of the prediction. If ``None``, the entire prediction will be
written. If not ``None``, only the first ``prediction_size`` columns will be written.
key:
PredictionWriter will write this key from the output of `predict()`.
gzip:
Whether to compress the CSV file using gzip.
max_threadpool_workers:
The maximum number of threads to use to write the predictions using a ThreadPoolExecutor.
"""
def __init__(
self,
output_dir: Path | str,
prediction_size: int | None = None,
key: str = "x_ng",
gzip: bool = True,
max_threadpool_workers: int = 8,
field_names: Sequence[str] | None = None,
) -> None:
super().__init__(write_interval="batch")
self.output_dir = output_dir
self.prediction_size = prediction_size
self.field_names = field_names
self.key = key
self.executor = BoundedThreadPoolExecutor(
max_workers=max_threadpool_workers,
max_queue_size=max_threadpool_workers * 2,
)
self.gzip = gzip
[docs]
def __del__(self):
"""Ensure the executor shuts down on object deletion."""
self.executor.shutdown(wait=True)
def write_on_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
prediction: dict[str, torch.Tensor],
batch_indices: Sequence[int] | None,
batch: dict[str, np.ndarray | torch.Tensor],
batch_idx: int,
dataloader_idx: int,
) -> None:
if self.key not in batch.keys():
raise ValueError(
f"PredictionWriter callback specified the key '{self.key}' as the relevant output of `predict()`,"
" but the key is not present. Specify a different key as an input argument to the callback, or"
" modify the output keys of `predict()`."
)
prediction_np = prediction[self.key]
if self.prediction_size is not None:
prediction_np = prediction_np[:, : self.prediction_size]
if "obs_names_n" not in batch.keys():
raise ValueError(
"PredictionWriter callback requires the batch_key 'obs_names_n'. Add this to the YAML config."
)
assert isinstance(batch["obs_names_n"], np.ndarray)
if self.field_names is None:
fields = None
else:
fields = {field_name: batch[field_name] for field_name in self.field_names}
write_prediction(
prediction=prediction_np,
obs_names_n=batch["obs_names_n"],
output_dir=self.output_dir,
postfix=batch_idx * trainer.world_size + trainer.global_rank,
fields=fields,
gzip=self.gzip,
executor=self.executor,
)