import torch


class MetaLearner:
    """Handles online, first-order meta-updates to the cortex.

    Wraps a single Adam optimizer over every parameter of ``model``. Each call
    to :meth:`meta_update` performs one zero-grad / backward / step cycle on
    the supplied loss.
    """

    def __init__(self, model: torch.nn.Module, lr: float = 1e-4) -> None:
        """Create the meta-learner.

        Args:
            model: Module whose parameters are updated by :meth:`meta_update`.
            lr: Adam learning rate.
        """
        self.model = model
        # The optimizer captures the parameter list at construction time;
        # parameters registered on ``model`` afterwards are not optimized.
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def meta_update(
        self, loss: torch.Tensor, *, retain_graph: bool = True
    ) -> None:
        """Apply one optimizer step driven by ``loss``.

        Args:
            loss: Tensor to backpropagate. It must be differentiable with
                respect to the model parameters, and ``Tensor.backward``
                without an explicit gradient requires it to be a scalar.
            retain_graph: Forwarded to ``loss.backward``. The default
                (``True``) keeps the autograd graph alive, as before, so the
                caller may backpropagate through it again. Pass ``False`` to
                free the graph's saved tensors immediately when the loss is
                not reused.
        """
        self.meta_optimizer.zero_grad()
        loss.backward(retain_graph=retain_graph)
        self.meta_optimizer.step()