20 lines
574 B
Python
20 lines
574 B
Python
import torch
|
|
from model.brain import model, optimizer, loss_fn, tokenizer, DEVICE
|
|
|
|
|
|
def train_on_message(text: str):
    """Run one online training step of the shared model on a single message.

    The message is tokenized and used for next-token prediction: the inputs
    are ``tokens[:-1]`` and the targets are ``tokens[1:]``, each fed as a
    batch containing one sequence. One optimizer step is taken.

    The model's train/eval mode is restored afterwards, so callers that use
    the same model for inference are not left with train-mode behaviour.

    Args:
        text: Raw message text, passed to ``tokenizer.tokenize``.

    Returns:
        The loss of this step as a Python float. Returns ``None`` when the
        message yields fewer than two tokens; nothing can be predicted then,
        so no update is made.
    """
    tokens = tokenizer.tokenize(text)

    # Next-token prediction needs at least one (input, target) pair.
    if len(tokens) < 2:
        return None

    # NOTE(review): assumes tokenize() returns a sequence of integer token ids
    # that are valid indices for the model/loss — confirm against the tokenizer.
    input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
    target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

    # `model` is shared module-level state, so remember its current mode and
    # put it back afterwards instead of leaving it in training mode.
    was_training = model.training
    model.train()
    try:
        output = model(input_tensor)
        # Flatten to (positions, classes) for the loss; the last dimension of
        # the output is treated as the class/vocabulary dimension. reshape
        # (rather than view) also works if the model output is non-contiguous.
        loss = loss_fn(output.reshape(-1, output.size(-1)), target_tensor.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    finally:
        model.train(was_training)

    return loss.item()
|