Ruby/model/train.py
2025-04-24 13:17:08 -04:00

20 lines
574 B
Python

import torch
from model.brain import model, optimizer, loss_fn, tokenizer, DEVICE
def train_on_message(text: str):
    """Perform one optimizer step using a single message as training data.

    The message is tokenized with the shared ``tokenizer``. The input
    sequence is every token except the last and the target sequence is
    every token except the first, so the target at each position is the
    token that follows the corresponding input token. Both are wrapped in
    a leading batch dimension of size 1 and placed on ``DEVICE``.

    Nothing beyond switching the model to training mode happens when the
    message yields fewer than two tokens, since no input/target pair can
    be formed from it.

    Args:
        text: The raw message to train on.
    """

    def _as_batch(ids):
        # Shape (len(ids),) -> (1, len(ids)).
        # Assumes tokenizer.tokenize yields integer ids -- TODO confirm.
        return torch.tensor(ids, dtype=torch.long, device=DEVICE).unsqueeze(0)

    model.train()

    token_ids = tokenizer.tokenize(text)
    if len(token_ids) < 2:
        return

    inputs = _as_batch(token_ids[:-1])
    targets = _as_batch(token_ids[1:])

    predictions = model(inputs)

    # Flatten so the loss sees one row per position: predictions become
    # (positions, last_dim) and targets become (positions,).
    # Presumably last_dim is the vocabulary size -- verify against the model.
    last_dim = predictions.size(-1)
    loss = loss_fn(predictions.view(-1, last_dim), targets.view(-1))

    # Clear gradients from any previous step before accumulating new ones.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()