import torch
import torch.nn.functional as F

from utils import update_model_vocab


def online_train_step(model, optimizer, tokenizer, message, device):
    # Ensure the model's embeddings cover the tokenizer's current vocab
    update_model_vocab(model, tokenizer)
    model.train()

    # Freeze the tokenizer so its vocab doesn't grow mid-step
    tokens = tokenizer.encode(message, return_tensors=True, freeze=True).to(device)
    if tokens.size(1) < 2:
        # Need at least two tokens to form an (input, target) pair
        return 0.0

    # Truncate long inputs to the positional-embedding length, keeping the most recent tokens
    max_len = model.pos_emb.size(1)
    if tokens.size(1) > max_len:
        tokens = tokens[:, -max_len:]

    # Next-token prediction: shift targets by one position
    x = tokens[:, :-1]
    y = tokens[:, 1:]

    # Hard stop if any target id exceeds the model's vocab
    vocab_size = model.token_emb.num_embeddings
    assert y.max().item() < vocab_size, (
        f"y contains token id >= vocab_size ({y.max().item()} >= {vocab_size})"
    )

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
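

# Usage sketch (a hypothetical caller, not part of this module): the model
# class, optimizer settings, and `incoming_messages` stream below are
# assumptions for illustration; any language model exposing `token_emb` and
# `pos_emb` as used above, plus a tokenizer supporting `encode(..., freeze=True)`,
# would fit.
#
#     model = MyTransformerLM(vocab_size=len(tokenizer.vocab)).to(device)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     for message in incoming_messages:
#         loss = online_train_step(model, optimizer, tokenizer, message, device)
#         print(f"step loss: {loss:.4f}")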