"""Online training hook: one gradient step per incoming message, plus periodic daydreams."""
import torch
import torch.nn as nn
import random
import time

from model.brain import model, tokenizer, DEVICE, optimizer, loss_fn, daydream
from context.context import get_recent_context, add_to_context

# Wall-clock time (time.time()) of the most recent daydream burst. Initialized
# at import, so the first burst can only happen once the interval has elapsed.
_last_thought = time.time()

# Seconds that must pass between daydream bursts, and daydream() calls per burst.
_DAYDREAM_INTERVAL = 15
_DAYDREAMS_PER_BURST = 3


def train_on_message(text: str) -> None:
    """Take one next-token training step on *text* prefixed by recent context.

    The last three context entries (``get_recent_context(3)``) and *text* are
    joined with single spaces and tokenized. The model is fed ``tokens[:-1]``
    and scored with ``loss_fn`` against ``tokens[1:]``, followed by a single
    optimizer step. *text* is then recorded with ``add_to_context``. If more
    than ``_DAYDREAM_INTERVAL`` seconds have passed since the last burst,
    ``daydream()`` is called ``_DAYDREAMS_PER_BURST`` times.

    If the joined text yields fewer than two tokens there is no
    (input, target) pair, so the training step and the daydream check are
    skipped. The message is still added to the context so it is not lost from
    the conversation history.

    Args:
        text: The new message to learn from.
    """
    global _last_thought

    model.train()

    context_texts = get_recent_context(3)
    augmented_text = " ".join(context_texts + [text])
    tokens = tokenizer.tokenize(augmented_text)

    if len(tokens) < 2:
        # Nothing to predict, but keep the message in the history anyway.
        add_to_context(text)
        return

    # Shift-by-one pairs: position i of the input predicts position i of the
    # target. Assumes tokenize() returns a flat sequence of integer ids — TODO
    # confirm. unsqueeze(0) adds a batch dimension of size 1.
    input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
    target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

    output = model(input_tensor)
    # Flatten logits to (N, last_dim) and targets to (N,). The output is
    # presumably (1, seq, vocab) — verify against the model. reshape (rather
    # than view) also works if the model returns a non-contiguous tensor.
    loss = loss_fn(output.reshape(-1, output.size(-1)), target_tensor.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    add_to_context(text)

    now = time.time()
    if now - _last_thought > _DAYDREAM_INTERVAL:
        for _ in range(_DAYDREAMS_PER_BURST):
            daydream()
        _last_thought = now