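"""RubyHeart: a small Transformer text generator built around a hybrid
word/character tokenizer, top-k / nucleus sampling, and single-example
online training that checkpoints the generator weights after each update."""
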
import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from models.discriminator import Discriminator
from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer


class RubyHeart:
    def __init__(
        self,
        books_dir="data/books",
        vocab_file="vocab/vocab.json",
        model_file="models/best_gen.pt",
        device=None,
    ):
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # tokenizer & vocab
        self.tokenizer = HybridTokenizer(vocab_file)
        if not self.tokenizer.char_to_id:
            self._build_vocab(books_dir)

        # model init
        vs = len(self.tokenizer.word_to_id) + len(self.tokenizer.char_to_id)
        self.model = TransformerGenerator(
            vocab_size=vs,
            embed_dim=256,
            num_heads=8,
            mlp_dim=512,
            num_layers=4,
            max_seq_len=128,
        ).to(self.device)

        self.model_file = model_file
        self._load_checkpoint(model_file)

        # online trainer
        self.trainer = self._make_trainer()

    def _build_vocab(self, books_dir):
        paths = glob.glob(os.path.join(books_dir, "*.txt"))
        texts = []
        for p in paths:
            with open(p, encoding="utf-8") as f:
                texts.append(f.read())
        self.tokenizer.build_vocab(texts)

    def _load_checkpoint(self, path):
        if os.path.isfile(path):
            state = torch.load(
                path, map_location=self.device, weights_only=True
            )
            self.model.load_state_dict(state)
        # else: start from scratch

    def _make_trainer(self, lr=1e-5):
        opt = optim.Adam(self.model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        return {"opt": opt, "loss": loss_fn}

    @staticmethod
    def _top_k_top_p(logits, top_k=50, top_p=0.9):
        """Filter a [batch, vocab] logits tensor with top-k, then nucleus (top-p) filtering."""
        if top_k > 0:
            # keep only the top_k largest logits
            kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if top_p > 0.0:
            sorted_logits, indices = torch.sort(logits, descending=True)
            cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            # drop tokens once the cumulative probability exceeds top_p,
            # but always keep the highest-probability token
            mask = cum_probs > top_p
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False
            # map the mask from sorted order back to vocabulary order
            remove = mask.scatter(-1, indices, mask)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    def generate(self, prompt, max_len=64, temp=1.0, top_k=50, top_p=0.9):
        self.model.eval()
        ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([ids], device=self.device)
        with torch.no_grad():
            for _ in range(max_len):
                # feed at most the last 128 tokens, matching the
                # generator's configured max_seq_len
                context = input_ids[:, -128:]
                logits = self.model(context)[:, -1, :] / temp
                filt = self._top_k_top_p(logits, top_k, top_p)
                probs = F.softmax(filt, dim=-1)
                nxt = torch.multinomial(probs, 1)
                input_ids = torch.cat([input_ids, nxt], dim=-1)
        return self.tokenizer.decode(input_ids[0].cpu().tolist())

    def train_on(self, text):
        ids = self.tokenizer.encode(text)
        if len(ids) < 2:
            return
        # next-token prediction: inputs and targets are shifted by one
        inp = torch.tensor([ids[:-1]], device=self.device)
        tgt = torch.tensor([ids[1:]], device=self.device)
        self.model.train()
        self.trainer["opt"].zero_grad()
        logits = self.model(inp)
        loss = self.trainer["loss"](
            logits.view(-1, logits.size(-1)),
            tgt.view(-1),
        )
        loss.backward()
        self.trainer["opt"].step()
        # persist the updated weights after every online step
        torch.save(self.model.state_dict(), self.model_file)
        self.model.eval()
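

# Minimal usage sketch: it assumes the default data/books, vocab/ and models/
# paths exist and that HybridTokenizer and TransformerGenerator behave as
# used by the class above.
if __name__ == "__main__":
    bot = RubyHeart()

    # sample a continuation with temperature + top-k / nucleus filtering
    print(bot.generate("Once upon a time", max_len=32, temp=0.8))

    # take one online training step on new text, then sample again
    bot.train_on("The rain in Spain stays mainly in the plain.")
    print(bot.generate("The rain in Spain", max_len=32))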