Ruby/model/rehearsal.py

17 lines
456 B
Python

import torch
from model.brain import model, tokenizer, DEVICE
from model.train import train_on_message
def simulate_conversation():
    """Generate text from a random seed and train on it when it is long enough.

    Draws a ``(1, 5)`` tensor of random token ids in ``[0, tokenizer.next_id)``,
    feeds it through ``model``, takes the argmax over the last dimension of the
    output, and detokenizes the resulting ids. If the decoded text is non-empty
    and contains at least three whitespace-separated words, it is passed to
    ``train_on_message``.

    Returns:
        None. The only effect is the optional call to ``train_on_message``.
    """
    vocab_size = tokenizer.next_id
    if vocab_size <= 0:
        # torch.randint requires the upper bound to exceed the lower bound;
        # with no ids to sample from there is nothing to rehearse.
        return

    seed = torch.randint(0, vocab_size, (1, 5), device=DEVICE)

    # The forward pass is used only to pick token ids, so no autograd graph is
    # needed. Training below must stay OUTSIDE this context.
    with torch.no_grad():
        output = model(seed)
        preds = torch.argmax(output, dim=-1).squeeze().tolist()

    # squeeze() can collapse to a 0-d tensor, whose tolist() is a bare int.
    if isinstance(preds, int):
        preds = [preds]

    text = tokenizer.detokenize(preds)
    if text and len(text.split()) >= 3:
        train_on_message(text)