Ruby/model/trainer.py
2025-04-25 23:16:18 -04:00

41 lines
1.2 KiB
Python

import os
import time

import torch

from model.brain_state import model, tokenizer, DEVICE, loss_fn
from context.context import add_to_context, get_recent_context
from model.dynamic_expand import expand_model_if_needed, get_optimizer
from model.brainmap import update_brainmap
# Append-only CSV-style log written by log_loss(): one "<timestamp>,<loss>" row per call.
LOSS_FILE = "data/logs/loss.log"


def log_loss(value: float) -> None:
    """Append one loss sample to ``LOSS_FILE``.

    Each call writes a single line of the form ``<time.time()>,<loss>``,
    where the loss is rounded to four decimal places.

    Args:
        value: The loss value to record.
    """
    # Create the log directory on demand so the first call on a fresh
    # checkout does not fail with FileNotFoundError.
    log_dir = os.path.dirname(LOSS_FILE)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    with open(LOSS_FILE, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{round(value, 4)}\n")
def train_on_message(text: str):
    """Run a single training step on ``text`` combined with recent context.

    The message is appended to the entries returned by
    ``get_recent_context(3)``, joined with spaces and tokenized. If that
    yields fewer than two tokens the function returns early: nothing is
    trained, logged, or added to the context. Otherwise the brainmap is
    updated with the whitespace-split detokenized text, one optimizer step
    is taken on the shifted token sequence, the loss is logged and ``text``
    is stored in the context.

    Args:
        text: The new message to train on.
    """
    # Expansion runs before anything else; the optimizer is fetched only
    # afterwards (further down), once the forward pass has been computed.
    expand_model_if_needed()
    model.train()

    def as_batch(ids):
        # Build a long tensor on DEVICE and add a leading dimension of 1.
        return torch.tensor(ids, dtype=torch.long, device=DEVICE).unsqueeze(0)

    history = get_recent_context(3)
    tokens = tokenizer.tokenize(" ".join(history + [text]))
    if len(tokens) < 2:
        # Need at least one (input, target) pair to compute a loss.
        return

    update_brainmap(tokenizer.detokenize(tokens).split())

    # Shift by one: the input drops the last token, the target drops the first.
    inputs = as_batch(tokens[:-1])
    targets = as_batch(tokens[1:])

    prediction = model(inputs)
    last_dim = prediction.size(-1)
    # Flatten both sides so loss_fn sees one row per position.
    loss = loss_fn(prediction.view(-1, last_dim), targets.view(-1))

    optimizer = get_optimizer()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    log_loss(loss.item())
    add_to_context(text)