# Ruby/model/train.py
# Snapshot: 2025-04-24 13:23:54 -04:00 (original listing: 33 lines, 857 B, Python)

import torch
import torch.nn as nn
import random
import time
from model.brain import model, tokenizer, DEVICE, optimizer, loss_fn, daydream
# Timestamp (from time.monotonic()) of the last daydream burst.
# A monotonic clock is used instead of time.time() so wall-clock
# adjustments (NTP sync, manual changes) cannot stall the schedule or
# fire it early.
_last_thought = time.monotonic()


def train_on_message(
    text: str,
    *,
    dream_interval: float = 15.0,
    dream_count: int = 3,
) -> None:
    """Take one next-token-prediction optimizer step on ``text``.

    The message is tokenized, and the model is trained to predict
    ``tokens[1:]`` from ``tokens[:-1]``. The sequence is passed as a batch
    of one. Afterwards, if more than ``dream_interval`` seconds have passed
    since the last burst, ``daydream()`` is called ``dream_count`` times.

    Dreaming is piggybacked on calls to this function; there is no
    background timer. It therefore only happens while messages are being
    trained on.

    Args:
        text: Raw message text to learn from.
        dream_interval: Minimum number of seconds between daydream bursts
            (keyword-only; defaults to the original 15 s).
        dream_count: Number of ``daydream()`` calls per burst
            (keyword-only; defaults to the original 3).

    Messages that tokenize to fewer than two tokens yield no
    (input, target) pair. They are skipped entirely, including the
    daydream check.
    """
    global _last_thought

    model.train()
    tokens = tokenizer.tokenize(text)
    if len(tokens) < 2:
        return

    # Shift by one token: position i of the input predicts position i of
    # the target. unsqueeze(0) adds a leading batch dimension of size 1.
    input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
    target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

    output = model(input_tensor)
    # Flatten to 2-D logits (rows, output.size(-1)) and 1-D targets for
    # loss_fn. This assumes output's last dim is the vocabulary; confirm
    # against model.brain. reshape() is used rather than view() because
    # view() raises if the model returns a non-contiguous tensor.
    loss = loss_fn(
        output.reshape(-1, output.size(-1)),
        target_tensor.reshape(-1),
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Throttled daydreaming, measured on the same monotonic clock as
    # _last_thought. The timestamp is taken before dreaming, so the time
    # spent dreaming counts toward the next interval.
    now = time.monotonic()
    if now - _last_thought > dream_interval:
        for _ in range(dream_count):
            daydream()
        _last_thought = now