46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
import os
import random
import time

import torch
import torch.nn as nn

from model.brain import model, tokenizer, DEVICE, optimizer, loss_fn, daydream
from context.context import get_recent_context, add_to_context
|
|
|
|
# Wall-clock time (time.time()) of the most recent daydream burst. It is seeded
# at import, so no burst can fire until the threshold has elapsed after load.
_last_thought: float = time.time()

# Append-only loss log; log_loss writes one "timestamp,loss" line per call.
LOSS_FILE: str = "data/logs/loss.log"
|
|
|
|
|
|
def log_loss(value: float, path=None):
    """Append one ``<unix-timestamp>,<loss>`` line to the loss log.

    The loss is rounded to 4 decimal places before writing. The log file's
    parent directory is created if it does not exist yet, so the first write
    on a fresh checkout does not fail with ``FileNotFoundError``.

    Args:
        value: Loss value to record.
        path: File to append to. Defaults to ``LOSS_FILE`` when ``None``.
    """
    if path is None:
        path = LOSS_FILE
    directory = os.path.dirname(path)
    # dirname is "" for a bare filename, and os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "a", encoding="utf-8") as f:
        f.write(f"{time.time()},{round(value, 4)}\n")
|
|
|
|
|
|
def train_on_message(
    text: str,
    *,
    context_size: int = 3,
    daydream_interval: float = 15.0,
    daydream_count: int = 3,
):
    """Run one online training step on *text* and record it in the context.

    The *context_size* most recent entries from ``get_recent_context`` are
    joined with *text* by single spaces and tokenized. The token sequence is
    used as a next-token-prediction example: the model is fed tokens ``[:-1]``
    and scored against tokens ``[1:]``. The loss is appended to the loss log,
    one optimizer step is taken, and *text* is then added to the context.
    Finally, if more than *daydream_interval* seconds have passed since the
    last daydream burst, ``daydream()`` is called *daydream_count* times.

    When the combined text yields fewer than two tokens, there is no
    (input, target) pair, so only the training step is skipped. *text* is
    still added to the context and the daydream timer is still checked.

    Args:
        text: The incoming message.
        context_size: Number of recent context entries to prepend (default 3).
        daydream_interval: Minimum seconds between daydream bursts (default 15).
        daydream_count: Number of ``daydream()`` calls per burst (default 3).
    """
    model.train()

    context_texts = get_recent_context(context_size)
    augmented_text = " ".join(context_texts + [text])
    tokens = tokenizer.tokenize(augmented_text)

    # A shifted (input, target) pair needs at least two tokens.
    if len(tokens) >= 2:
        _train_step(tokens)

    add_to_context(text)
    _maybe_daydream(daydream_interval, daydream_count)


def _train_step(tokens) -> None:
    """Take one next-token-prediction optimizer step on a token sequence."""
    # NOTE(review): assumes tokenizer.tokenize returns integer token ids.
    input_tensor = torch.tensor(tokens[:-1], dtype=torch.long, device=DEVICE).unsqueeze(0)
    target_tensor = torch.tensor(tokens[1:], dtype=torch.long, device=DEVICE).unsqueeze(0)

    output = model(input_tensor)
    # Flatten to (positions, classes) for the loss. This presumably means the
    # output is (1, seq_len, vocab) -- TODO confirm. reshape (unlike view) also
    # works when the model output is non-contiguous.
    loss = loss_fn(output.reshape(-1, output.size(-1)), target_tensor.reshape(-1))
    log_loss(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def _maybe_daydream(interval: float, count: int) -> None:
    """Call daydream() *count* times if over *interval* s passed since the last burst."""
    global _last_thought
    now = time.time()
    if now - _last_thought > interval:
        for _ in range(count):
            daydream()
        _last_thought = now
|