import torch
import torch.nn.functional as F

from utils import update_model_vocab

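# ---------------------------------------------------------------------------
# Hypothetical sketch of what utils.update_model_vocab is assumed to do: grow
# the model's token embedding when the tokenizer has learned new tokens. The
# real helper lives in utils and may differ; `len(tokenizer)` as the vocab-size
# API and random init for the new rows are assumptions. A real version would
# also need to resize the output head and register any new parameters with the
# optimizer.
# ---------------------------------------------------------------------------
def _sketch_update_model_vocab(model, tokenizer):
    new_vocab = len(tokenizer)  # assumption: tokenizer exposes its size this way
    old_vocab = model.token_emb.num_embeddings
    if new_vocab <= old_vocab:
        return
    emb_dim = model.token_emb.embedding_dim
    device = model.token_emb.weight.device
    new_emb = torch.nn.Embedding(new_vocab, emb_dim).to(device)
    with torch.no_grad():
        new_emb.weight[:old_vocab] = model.token_emb.weight  # keep learned rows
    model.token_emb = new_emb
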
def online_train_step(model, optimizer, tokenizer, message, device):
    # Ensure the model's embedding/output layers cover the tokenizer's
    # current vocabulary before training on this message.
    update_model_vocab(model, tokenizer)

    # Freeze the tokenizer so its vocab doesn't grow mid-step.
    tokens = tokenizer.encode(message, return_tensors=True, freeze=True).to(device)
    if tokens.size(1) < 2:
        # Need at least two tokens to form an (input, target) pair.
        return 0.0

    # Truncate to the model's maximum context length, keeping the most
    # recent tokens.
    max_len = model.pos_emb.size(1)
    if tokens.size(1) > max_len:
        tokens = tokens[:, -max_len:]

    # Next-token prediction: inputs are tokens[:-1], targets are tokens[1:].
    x = tokens[:, :-1]
    y = tokens[:, 1:]

    # Hard stop if any target id falls outside the model's vocabulary.
    vocab_size = model.token_emb.num_embeddings
    assert y.max().item() < vocab_size, (
        f"y contains token id >= vocab_size ({y.max().item()} >= {vocab_size})"
    )

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
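
# Usage sketch (hypothetical: `GPT`, `build_tokenizer`, and `chat_stream` are
# stand-ins, not part of this file). Note that if update_model_vocab adds new
# parameters after the optimizer was constructed, a real loop would also need
# to refresh the optimizer's param groups when the vocab grows.
#
#     tokenizer = build_tokenizer()
#     model = GPT(vocab_size=len(tokenizer)).to(device)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     for message in chat_stream():
#         loss = online_train_step(model, optimizer, tokenizer, message, device)
#         print(f"online step loss: {loss:.4f}")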