# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import torch
from torch.optim.lr_scheduler import LambdaLR
class LinearLR(LambdaLR):
    """
    Learning rate scheduler with linear warmup followed by linear decay.

    The learning rate increases linearly from 0 to the initial lr set in the
    optimizer over the first ``num_warmup_steps`` steps. It then decreases
    linearly to 0, reaching 0 at step ``num_training_steps``, and stays at 0
    after that. The multiplier applied to the optimizer's initial lr at step
    ``s`` is::

        s / num_warmup_steps                                   if s < num_warmup_steps
        max(0, (num_training_steps - s)
               / (num_training_steps - num_warmup_steps))      otherwise

    Both denominators are clamped to at least 1. This avoids division by zero
    when ``num_warmup_steps == 0`` or ``num_training_steps == num_warmup_steps``.

    Args:
        optimizer:
            The optimizer for which to schedule the learning rate.
        num_warmup_steps:
            The number of steps for the warmup phase.
        num_training_steps:
            The total number of training steps.
        last_epoch:
            The index of the last epoch when resuming training. Forwarded
            unchanged to :class:`~torch.optim.lr_scheduler.LambdaLR`.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        num_training_steps: int,
        last_epoch: int = -1,
    ) -> None:
        # Assign these before ``super().__init__``: the base class may call
        # ``self._lr_lambda`` (which reads them) while it is being constructed.
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        super().__init__(optimizer, self._lr_lambda, last_epoch)

    def _lr_lambda(self, current_step: int) -> float:
        """
        Compute the lr multiplier for a given step.

        Args:
            current_step:
                The step index at which to evaluate the schedule.

        Returns:
            The factor by which the optimizer's initial lr is multiplied:
            a linear ramp from 0 during warmup, then a linear decay that is
            clamped at 0.
        """
        if current_step < self.num_warmup_steps:
            # Warmup: ramp from 0 toward 1. ``max(1, ...)`` guards a zero denominator.
            return float(current_step) / float(max(1, self.num_warmup_steps))
        # Decay: equals 1 at the end of warmup and 0 at num_training_steps.
        # The outer max() keeps the factor from going negative past the end of training.
        return max(
            0.0,
            float(self.num_training_steps - current_step)
            / float(max(1, self.num_training_steps - self.num_warmup_steps)),
        )