RubyOld/train_step.py
2025-04-08 19:52:01 -04:00

35 lines
987 B
Python

import torch
import torch.nn.functional as F
from utils import update_model_vocab
def online_train_step(model, optimizer, tokenizer, message, device):
    """Run a single online training step on one text message.

    Encodes *message* with the tokenizer frozen, truncates it to the
    model's context window, and performs one forward/backward/optimizer
    pass with a next-token cross-entropy loss.

    Args:
        model: Language model exposing ``pos_emb`` (positional embedding;
            assumed shape (1, max_len, dim) -- TODO confirm against the
            model definition) and ``token_emb`` (an ``nn.Embedding``),
            callable as ``model(x)`` -> logits of shape (batch, seq, vocab).
        optimizer: torch optimizer holding *model*'s parameters.
        tokenizer: Tokenizer whose ``encode(message, return_tensors=True,
            freeze=True)`` returns a (1, seq) integer tensor.
        message: Raw text string to train on.
        device: Device the encoded token tensor is moved to.

    Returns:
        float: Scalar loss value, or 0.0 when the message encodes to
        fewer than two tokens (no (input, target) pair exists).

    Raises:
        ValueError: If a target token id is outside the model's embedding
            range; raised eagerly here because it would otherwise surface
            as an opaque device-side failure inside ``cross_entropy``.
    """
    # Grow the model's embedding/output layers to the tokenizer's current
    # vocabulary before encoding, so new tokens have rows to train.
    update_model_vocab(model, tokenizer)

    # freeze=True stops the tokenizer from adding tokens mid-step, which
    # would desynchronize it from the just-resized model.
    tokens = tokenizer.encode(message, return_tensors=True, freeze=True).to(device)

    # Need at least two tokens to form an (input, target) pair.
    if tokens.size(1) < 2:
        return 0.0

    # Keep only the most recent max_len tokens so positions fit pos_emb.
    max_len = model.pos_emb.size(1)
    if tokens.size(1) > max_len:
        tokens = tokens[:, -max_len:]

    # Standard next-token shift: predict y[t] from x[..t].
    x = tokens[:, :-1]
    y = tokens[:, 1:]

    # Explicit check instead of `assert`: asserts are stripped under -O,
    # which would silently disable this guard.
    vocab_size = model.token_emb.num_embeddings
    max_id = int(y.max().item())
    if max_id >= vocab_size:
        raise ValueError(
            f"y contains token > vocab_size ({max_id} >= {vocab_size})"
        )

    logits = model(x)
    # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for cross_entropy.
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()