Ruby/nervous_system/meta_learning.py

14 lines
447 B
Python

import torch
class MetaLearner:
    """Handles online, first-order meta-updates to the cortex.

    Wraps ``model`` together with a dedicated Adam optimizer built over
    ``model.parameters()``. Each call to :meth:`meta_update` performs one
    plain gradient step: gradients are taken with an ordinary ``backward()``
    (no ``create_graph``), so no higher-order terms are computed.

    Attributes:
        model: The module whose parameters are updated.
        meta_optimizer: The ``torch.optim.Adam`` instance stepping those
            parameters.
    """

    def __init__(self, model: torch.nn.Module, lr: float = 1e-4) -> None:
        """Bind the learner to ``model`` and create its optimizer.

        Args:
            model: Module whose parameters receive the meta-updates.
            lr: Learning rate passed to ``torch.optim.Adam``.
        """
        self.model = model
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def meta_update(
        self, loss: torch.Tensor, *, retain_graph: bool = True
    ) -> None:
        """Apply one optimizer step computed from ``loss``.

        Args:
            loss: Tensor to back-propagate; ``loss.backward`` is called on it
                without an explicit gradient argument.
            retain_graph: Forwarded to ``loss.backward``. The default
                (``True``) preserves the historical behaviour: the autograd
                graph is kept after the backward pass so the caller can
                back-propagate through it again. Pass ``False`` when the
                loss is not reused, so the graph's saved buffers are
                released as soon as this update finishes.
        """
        # Clear gradients left over from any previous backward pass so this
        # step is driven by ``loss`` alone.
        self.meta_optimizer.zero_grad()
        loss.backward(retain_graph=retain_graph)
        self.meta_optimizer.step()