openpilot/tinygrad_repo/examples/mlperf/lr_schedulers.py
Vehicle Researcher c5d5c5d1f3 openpilot v0.10.1 release
date: 2025-10-24T00:30:59
master commit: 405631baf9685e171a0dd19547cb763f1b163d18
2025-10-24 00:31:03 -07:00

39 lines
1.9 KiB
Python

import math
from tinygrad import dtypes
from tinygrad.nn.optim import Optimizer
from extra.lr_scheduler import LR_Scheduler
# https://github.com/mlcommons/training/blob/e237206991d10449d9675d95606459a3cb6c21ad/image_classification/tensorflow2/lars_util.py
class PolynomialDecayWithWarmup(LR_Scheduler):
  """Learning-rate schedule with a linear warmup followed by a polynomial decay.

  While ``epoch_counter <= warmup`` the rate is ``epoch_counter / warmup * initial_lr``.
  Afterwards it is ``(initial_lr - end_lr) * x ** power + end_lr`` with
  ``x = 1 - (epoch_counter - warmup) / (train_steps - warmup + 1)``.

  NOTE(review): ``self.optimizer`` and ``self.epoch_counter`` are presumably provided by
  ``LR_Scheduler.__init__`` -- confirm against extra.lr_scheduler.
  """

  def __init__(self, optimizer: Optimizer, initial_lr, end_lr, train_steps, warmup, power=2):
    """Store the schedule parameters and write the initial learning rate into the optimizer.

    Args:
      optimizer: optimizer whose ``lr`` tensor this scheduler assigns to.
      initial_lr: rate reached at the end of warmup.
      end_lr: additive floor of the decay term.
      train_steps: total step count used by the decay; must be > 0.
      warmup: number of warmup steps; must be > 0 and is capped at ``train_steps``.
      power: exponent applied to the decay term.
    """
    super().__init__(optimizer)
    # step counter is cast to float32 (presumably so the schedule math runs in floating point)
    self.epoch_counter = self.epoch_counter.cast(dtypes.float32)
    assert train_steps > 0 and warmup > 0
    # warmup can never be longer than the whole run
    self.warmup = min(warmup, train_steps)
    self.initial_lr = initial_lr
    self.end_lr = end_lr
    self.epochs = train_steps
    self.power = power

    # set lr for first warmup step
    first_lr = self.get_lr()
    self.optimizer.lr.assign(first_lr).realize()

  def get_lr(self):
    """Return the learning rate for the current ``epoch_counter``, cast to the optimizer's lr dtype."""
    step = self.epoch_counter
    # LR is 0 on the first step, matching the reference.
    warmup_fraction = step * (1.0 / self.warmup)
    warmup_lr = warmup_fraction * self.initial_lr
    # x falls linearly from 1 as the step counter moves past the warmup boundary
    x = (1 - (step - self.warmup) / (self.epochs - self.warmup + 1))
    decay_lr = (self.initial_lr - self.end_lr) * x ** self.power + self.end_lr
    in_warmup = step <= self.warmup
    return in_warmup.where(warmup_lr, decay_lr).cast(self.optimizer.lr.dtype)
class CosineAnnealingLRWithWarmup(LR_Scheduler):
  """Learning-rate schedule with a linear warmup followed by cosine annealing.

  While ``epoch_counter < warmup_steps`` the rate is ``(epoch_counter + 1) / warmup_steps * base_lr``.
  Afterwards it is
  ``end_lr + 0.5 * (base_lr - end_lr) * (1 + cos(pi * (epoch_counter + 1 - warmup_steps) / decay_steps))``.

  NOTE(review): ``self.optimizer`` and ``self.epoch_counter`` are presumably provided by
  ``LR_Scheduler.__init__`` -- confirm against extra.lr_scheduler.
  """

  def __init__(self, optimizer:Optimizer, base_lr, end_lr, warmup_steps:int, decay_steps:int):
    """Store the schedule parameters and write the initial learning rate into the optimizer.

    Args:
      optimizer: optimizer whose ``lr`` tensor this scheduler assigns to.
      base_lr: rate reached at the end of warmup.
      end_lr: value the cosine term is offset by.
      warmup_steps: number of warmup steps; must be > 0.
      decay_steps: divisor of the cosine phase; must be > 0.
    """
    assert warmup_steps > 0 and decay_steps > 0
    super().__init__(optimizer)
    self.base_lr, self.end_lr = base_lr, end_lr
    self.warmup_steps, self.decay_steps = warmup_steps, decay_steps
    # set lr for first warmup step
    first_lr = self.get_lr()
    self.optimizer.lr.assign(first_lr).realize()

  def get_lr(self):
    """Return the learning rate for the current ``epoch_counter``, cast to the optimizer's lr dtype."""
    # the schedule is evaluated at epoch_counter + 1 in both phases
    next_step = self.epoch_counter+1
    warmup_lr = (next_step / self.warmup_steps) * self.base_lr
    # position within the decay phase, scaled by decay_steps
    progress = (next_step-self.warmup_steps)/self.decay_steps
    cosine = (progress * math.pi).cos()
    decay_lr = self.end_lr + 0.5 * (self.base_lr-self.end_lr) * (1 + cosine)
    in_warmup = self.epoch_counter < self.warmup_steps
    return in_warmup.where(warmup_lr, decay_lr).cast(self.optimizer.lr.dtype)