14 lines
447 B
Python
14 lines
447 B
Python
import torch
|
|
|
|
|
|
class MetaLearner:
    """Handles online, first-order meta-updates to the cortex.

    Wraps ``model`` with an Adam optimizer and applies exactly one optimizer
    step per call to :meth:`meta_update`. The update is first-order: the loss
    is differentiated once via ``backward()`` without ``create_graph``, so no
    higher-order gradient terms are built.
    """

    def __init__(self, model: torch.nn.Module, lr: float = 1e-4) -> None:
        """Create the meta-optimizer for ``model``.

        Args:
            model: Module being updated (presumably the "cortex" network;
                confirm against callers). The parameters yielded by
                ``model.parameters()`` at construction time are the ones
                registered with the optimizer.
            lr: Adam learning rate.
        """
        self.model = model
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def meta_update(
        self,
        loss: torch.Tensor,
        *,
        retain_graph: bool = True,
    ) -> None:
        """Apply one Adam step that descends ``loss``.

        Args:
            loss: Scalar tensor connected by autograd to the model's
                parameters (``backward()`` without an explicit gradient
                requires a scalar).
            retain_graph: Forwarded to ``loss.backward``. Defaults to True,
                matching the previously hard-coded behaviour, so callers can
                backpropagate through the same graph again afterwards. Pass
                False to let autograd free the graph's saved buffers right
                away when the graph is not reused.
        """
        # Clear gradients from the previous call so this step uses only
        # the gradient of the current loss (no accumulation across calls).
        self.meta_optimizer.zero_grad()
        loss.backward(retain_graph=retain_graph)
        self.meta_optimizer.step()