Ruby/ruby_heart.py

import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from models.discriminator import Discriminator
from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer


class RubyHeart:
    """Transformer-based text generator with sampling and simple online training."""

    def __init__(
        self,
        books_dir="data/books",
        vocab_file="vocab/vocab.json",
        model_file="models/best_gen.pt",
        device=None,
    ):
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Tokenizer & vocab: build the vocabulary from the book corpus
        # on first run, when no character vocabulary has been saved yet.
        self.tokenizer = HybridTokenizer(vocab_file)
        if not self.tokenizer.char_to_id:
            self._build_vocab(books_dir)

        # Model init: the vocabulary spans both word- and character-level ids.
        vocab_size = (
            len(self.tokenizer.word_to_id)
            + len(self.tokenizer.char_to_id)
        )
        self.model = TransformerGenerator(
            vocab_size=vocab_size,
            embed_dim=256,
            num_heads=8,
            mlp_dim=512,
            num_layers=4,
            max_seq_len=128,
        ).to(self.device)
        self.model_file = model_file
        self._load_checkpoint(model_file)

        # Online trainer used by train_on().
        self.trainer = self._make_trainer()

    def _build_vocab(self, books_dir):
        paths = glob.glob(os.path.join(books_dir, "*.txt"))
        texts = []
        for path in paths:
            with open(path, encoding="utf-8") as fh:
                texts.append(fh.read())
        self.tokenizer.build_vocab(texts)

    def _load_checkpoint(self, path):
        if os.path.isfile(path):
            state = torch.load(
                path, map_location=self.device, weights_only=True
            )
            self.model.load_state_dict(state)
        # else: start from scratch

    def _make_trainer(self, lr=1e-5):
        # Single optimizer/loss pair reused by train_on() for online updates.
        opt = optim.Adam(self.model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        return {"opt": opt, "loss": loss_fn}

    @staticmethod
    def _top_k_top_p(logits, top_k=50, top_p=0.9):
        """Apply top-k and nucleus (top-p) filtering to [batch, vocab] logits."""
        if top_k > 0:
            # Mask out everything below the k-th largest logit per row.
            kth = torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            # Drop tokens once the cumulative probability exceeds top_p,
            # always keeping at least the most likely token.
            mask = cum_probs > top_p
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False
            # Map the mask back from sorted order to vocabulary order.
            remove = mask.scatter(-1, sorted_indices, mask)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits
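
    # Illustrative note: the logits arriving from generate() have shape
    # [batch, vocab]; filtered positions are set to -inf, so F.softmax assigns
    # them zero probability and torch.multinomial can never sample them.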

    def generate(self, prompt, max_len=64, temp=1.0, top_k=50, top_p=0.9):
        self.model.eval()
        ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([ids], device=self.device)
        with torch.no_grad():
            for _ in range(max_len):
                # Keep the context within the generator's max_seq_len (128)
                # set in __init__.
                context = input_ids[:, -128:]
                logits = self.model(context)[:, -1, :] / temp
                filtered = self._top_k_top_p(logits, top_k, top_p)
                probs = F.softmax(filtered, dim=-1)
                nxt = torch.multinomial(probs, 1)
                input_ids = torch.cat([input_ids, nxt], dim=-1)
        return self.tokenizer.decode(input_ids[0].cpu().tolist())

    def train_on(self, text):
        ids = self.tokenizer.encode(text)
        if len(ids) < 2:
            return
        # Next-token prediction: target is the sequence shifted by one position.
        inp = torch.tensor([ids[:-1]], device=self.device)
        tgt = torch.tensor([ids[1:]], device=self.device)
        self.model.train()
        self.trainer["opt"].zero_grad()
        logits = self.model(inp)
        loss = self.trainer["loss"](
            logits.view(-1, logits.size(-1)),
            tgt.view(-1),
        )
        loss.backward()
        self.trainer["opt"].step()
        # Persist weights after every online update.
        torch.save(self.model.state_dict(), self.model_file)
        self.model.eval()
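

# A minimal usage sketch (illustrative only): it assumes the data/books,
# vocab/vocab.json, and models/ paths referenced above exist and that the
# models.* and utils.* companion modules are importable from this project.
if __name__ == "__main__":
    ruby = RubyHeart()
    print(ruby.generate("Once upon a time", max_len=32))
    # Online update from a single piece of feedback text (hypothetical input).
    ruby.train_on("Once upon a time there was a small red gem.")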