import os
import time

import torch

# Local imports kept in their original order in case these modules have
# import-time side effects that depend on one another.
from model.brain_state import model, tokenizer, DEVICE, optimizer, loss_fn
from context.context import add_to_context, get_recent_context
from model.dynamic_expand import expand_model_if_needed
from model.brainmap import update_brainmap

LOSS_FILE = "data/logs/loss.log"


def log_loss(value: float) -> None:
    """Append one ``<unix timestamp>,<loss rounded to 4 places>`` line to LOSS_FILE.

    The parent directory is created on demand so that the first call on a
    fresh checkout does not raise FileNotFoundError.
    """
    log_dir = os.path.dirname(LOSS_FILE)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(LOSS_FILE, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{round(value, 4)}\n")


def train_on_message(text: str) -> None:
    """Run a single next-token training step on *text* plus recent context.

    Steps:
      1. Give the model a chance to grow (``expand_model_if_needed``).
      2. Put the model in training mode.
      3. Join the texts from ``get_recent_context(3)`` with *text*, separated
         by spaces, and tokenize the result.
      4. Feed the detokenized words to ``update_brainmap``.
      5. Do one forward/backward/optimizer step, predicting token ``i + 1``
         from tokens ``0..i``.
      6. Log the loss and add *text* to the context store.

    If the combined text yields fewer than 2 tokens, there is no
    (input, target) pair to train on and the function returns immediately.

    Args:
        text: The new message to learn from.
    """
    # NOTE(review): `model` and `optimizer` are bound at import time. If
    # expand_model_if_needed() replaces them in model.brain_state (rather than
    # mutating them in place), this function keeps training the old objects.
    # Confirm.
    expand_model_if_needed()
    model.train()

    # Presumably the most recent context entries (up to 3) -- verify against
    # context.context. Unpacking accepts any iterable, not only a list.
    context_texts = get_recent_context(3)
    augmented_text = " ".join([*context_texts, text])
    tokens = tokenizer.tokenize(augmented_text)

    # A shifted language-model objective needs at least one (input, target) pair.
    # NOTE(review): this early return also skips add_to_context(text), so very
    # short messages never enter the context. Confirm this is intended.
    if len(tokens) < 2:
        return

    words = tokenizer.detokenize(tokens).split()
    update_brainmap(words)

    # Inputs are tokens[:-1] and targets are tokens[1:]; unsqueeze(0) adds a
    # batch dimension of size 1.
    input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
    target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

    output = model(input_tensor)
    # Flatten to (positions, vocab) vs (positions,) for the loss. reshape (not
    # view) also works when the model returns a non-contiguous tensor.
    loss = loss_fn(output.reshape(-1, output.size(-1)), target_tensor.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    log_loss(loss.item())
    add_to_context(text)