going back to a base state
This commit is contained in:
parent
232e62962e
commit
bf6706c72c
2
.gitignore
vendored
2
.gitignore
vendored
@ -168,6 +168,6 @@ cython_debug/
|
|||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
data/*
|
books/*
|
||||||
*.json
|
*.json
|
||||||
models/best_gen.pt
|
models/best_gen.pt
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
import random
|
|
||||||
import copy
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def mutate(model, mutation_rate=0.01):
|
|
||||||
new_model = copy.deepcopy(model)
|
|
||||||
for param in new_model.parameters():
|
|
||||||
if random.random() < mutation_rate:
|
|
||||||
noise = torch.randn_like(param) * 0.1
|
|
||||||
param.data += noise
|
|
||||||
return new_model
|
|
||||||
|
|
||||||
|
|
||||||
def crossover(parent1, parent2):
|
|
||||||
child = copy.deepcopy(parent1)
|
|
||||||
for p_child, p2 in zip(child.parameters(), parent2.parameters()):
|
|
||||||
mask = torch.rand_like(p_child) < 0.5
|
|
||||||
p_child.data[mask] = p2.data[mask]
|
|
||||||
return child
|
|
||||||
|
|
||||||
|
|
||||||
def evolve(population, fitnesses, retain_ratio=0.2, mutation_rate=0.1):
|
|
||||||
# rank by fitness (higher is better)
|
|
||||||
paired = sorted(zip(fitnesses, population), key=lambda x: x[0], reverse=True)
|
|
||||||
retain_len = int(len(paired) * retain_ratio)
|
|
||||||
parents = [ind for _, ind in paired[:retain_len]]
|
|
||||||
next_gen = parents.copy()
|
|
||||||
while len(next_gen) < len(population):
|
|
||||||
p1, p2 = random.sample(parents, 2)
|
|
||||||
child = crossover(p1, p2)
|
|
||||||
child = mutate(child, mutation_rate)
|
|
||||||
next_gen.append(child)
|
|
||||||
return next_gen
|
|
132
main.py
132
main.py
@ -1,132 +0,0 @@
|
|||||||
import os
|
|
||||||
import glob
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import discord
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from models.transformer import TransformerGenerator
|
|
||||||
from utils.tokenizer import HybridTokenizer
|
|
||||||
|
|
||||||
# ──────── Setup ────────
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
TOKEN = os.getenv("DISCORD_TOKEN")
|
|
||||||
if not TOKEN:
|
|
||||||
raise RuntimeError("Missing DISCORD_TOKEN in .env")
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
# print(f"[INFO] Using device: {device}")
|
|
||||||
|
|
||||||
# ──────── Tokenizer & Vocab ────────
|
|
||||||
|
|
||||||
vocab_file = os.path.join("vocab", "vocab.json")
|
|
||||||
tokenizer = HybridTokenizer(vocab_file)
|
|
||||||
|
|
||||||
# If vocab.json doesn’t exist yet, build it from your books:
|
|
||||||
if not tokenizer.char_to_id:
|
|
||||||
book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
|
|
||||||
texts = []
|
|
||||||
for path in book_paths:
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
|
||||||
texts.append(f.read())
|
|
||||||
tokenizer.build_vocab(texts)
|
|
||||||
print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
|
|
||||||
f"{len(tokenizer.char_to_id)} chars)")
|
|
||||||
|
|
||||||
# ──────── Model Setup ────────
|
|
||||||
|
|
||||||
vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
|
|
||||||
embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
|
|
||||||
max_seq_len = 128
|
|
||||||
|
|
||||||
model = TransformerGenerator(
|
|
||||||
vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
ckpt = os.path.join("models", "best_gen.pt")
|
|
||||||
if os.path.isfile(ckpt):
|
|
||||||
state = torch.load(ckpt, map_location=device)
|
|
||||||
model.load_state_dict(state)
|
|
||||||
print("[INFO] Loaded checkpoint models/best_gen.pt")
|
|
||||||
else:
|
|
||||||
print("[INFO] No checkpoint found; starting from random weights")
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# ──────── Online Trainer ────────
|
|
||||||
|
|
||||||
class OnlineTrainer:
|
|
||||||
"""Fine-tune the generator on each new exchange."""
|
|
||||||
|
|
||||||
def __init__(self, model, lr=1e-5):
|
|
||||||
self.model = model
|
|
||||||
self.optimizer = optim.Adam(model.parameters(), lr=lr)
|
|
||||||
self.criterion = nn.CrossEntropyLoss()
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def train_example(self, text: str):
|
|
||||||
# simple causal training: predict each next token in `text`
|
|
||||||
token_ids = tokenizer.encode(text)
|
|
||||||
if len(token_ids) < 2:
|
|
||||||
return
|
|
||||||
inp = torch.tensor([token_ids[:-1]], device=self.device)
|
|
||||||
tgt = torch.tensor([token_ids[1:]], device=self.device)
|
|
||||||
|
|
||||||
self.model.train()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
logits = self.model(inp) # (1, seq_len-1, vocab_size)
|
|
||||||
loss = self.criterion(
|
|
||||||
logits.view(-1, logits.size(-1)),
|
|
||||||
tgt.view(-1)
|
|
||||||
)
|
|
||||||
loss.backward()
|
|
||||||
self.optimizer.step()
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
# persist updated weights
|
|
||||||
os.makedirs("models", exist_ok=True)
|
|
||||||
torch.save(self.model.state_dict(), ckpt)
|
|
||||||
|
|
||||||
trainer = OnlineTrainer(model)
|
|
||||||
|
|
||||||
# ──────── Discord Client ────────
|
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
|
||||||
intents.message_content = True
|
|
||||||
client = discord.Client(intents=intents)
|
|
||||||
|
|
||||||
|
|
||||||
@client.event
|
|
||||||
async def on_ready():
|
|
||||||
print(f"Ruby is online as {client.user}")
|
|
||||||
|
|
||||||
|
|
||||||
@client.event
|
|
||||||
async def on_message(message):
|
|
||||||
# ignore Ruby’s own messages
|
|
||||||
if message.author == client.user:
|
|
||||||
return
|
|
||||||
|
|
||||||
content = message.content.strip()
|
|
||||||
if not content:
|
|
||||||
return
|
|
||||||
|
|
||||||
# → Generate Ruby’s reply
|
|
||||||
ids = tokenizer.encode(content)
|
|
||||||
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
|
||||||
with torch.no_grad():
|
|
||||||
out_ids = model(inp).argmax(-1).squeeze().cpu().tolist()
|
|
||||||
reply = tokenizer.decode(out_ids)
|
|
||||||
|
|
||||||
await message.channel.send(reply)
|
|
||||||
|
|
||||||
# → Optionally train on this new example
|
|
||||||
sample = f"User: {content}\nRuby: {reply}"
|
|
||||||
trainer.train_example(sample)
|
|
||||||
|
|
||||||
# ──────── Run ────────
|
|
||||||
|
|
||||||
client.run(TOKEN)
|
|
@ -1,40 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import discord
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from ruby_heart import RubyHeart
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
TOKEN = os.getenv("DISCORD_TOKEN")
|
|
||||||
if not TOKEN:
|
|
||||||
raise RuntimeError("DISCORD_TOKEN missing in .env")
|
|
||||||
|
|
||||||
# instantiate your “Ruby” engine
|
|
||||||
ruby = RubyHeart() # uses GPU if available
|
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
|
||||||
intents.message_content = True
|
|
||||||
client = discord.Client(intents=intents)
|
|
||||||
|
|
||||||
|
|
||||||
@client.event
|
|
||||||
async def on_ready():
|
|
||||||
print(f"Ruby is online as {client.user}")
|
|
||||||
|
|
||||||
|
|
||||||
@client.event
|
|
||||||
async def on_message(message):
|
|
||||||
if message.author == client.user:
|
|
||||||
return
|
|
||||||
content = message.content.strip()
|
|
||||||
if not content:
|
|
||||||
return
|
|
||||||
|
|
||||||
# generate + train in one call
|
|
||||||
reply = ruby.generate(content)
|
|
||||||
await message.channel.send(reply)
|
|
||||||
ruby.train_on(f"User: {content}\nRuby: {reply}")
|
|
||||||
|
|
||||||
|
|
||||||
client.run(TOKEN)
|
|
@ -1,80 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadSelfAttention(nn.Module):
|
|
||||||
def __init__(self, embed_dim, num_heads):
|
|
||||||
super().__init__()
|
|
||||||
assert embed_dim % num_heads == 0
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = embed_dim // num_heads
|
|
||||||
self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
|
|
||||||
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# x: (batch, seq_len, embed_dim)
|
|
||||||
b, t, e = x.size()
|
|
||||||
qkv = self.qkv_proj(x) # (b, t, 3*e)
|
|
||||||
q, k, v = qkv.chunk(3, dim=-1)
|
|
||||||
# reshape for multi-head
|
|
||||||
q = q.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
k = k.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
|
|
||||||
attn = torch.softmax(attn, dim=-1)
|
|
||||||
out = torch.matmul(attn, v).transpose(1, 2).contiguous()
|
|
||||||
out = out.view(b, t, e)
|
|
||||||
return self.out_proj(out)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerBlock(nn.Module):
|
|
||||||
def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
|
|
||||||
super().__init__()
|
|
||||||
self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
|
|
||||||
self.ln1 = nn.LayerNorm(embed_dim)
|
|
||||||
self.ff = nn.Sequential(
|
|
||||||
nn.Linear(embed_dim, mlp_dim),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(mlp_dim, embed_dim),
|
|
||||||
)
|
|
||||||
self.ln2 = nn.LayerNorm(embed_dim)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x + self.dropout(self.attn(self.ln1(x)))
|
|
||||||
x = x + self.dropout(self.ff(self.ln2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerGenerator(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size: int,
|
|
||||||
embed_dim: int,
|
|
||||||
num_heads: int,
|
|
||||||
mlp_dim: int,
|
|
||||||
num_layers: int,
|
|
||||||
max_seq_len: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.token_emb = nn.Embedding(vocab_size, embed_dim)
|
|
||||||
self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
|
|
||||||
self.layers = nn.ModuleList(
|
|
||||||
[
|
|
||||||
TransformerBlock(embed_dim, num_heads, mlp_dim)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.ln = nn.LayerNorm(embed_dim)
|
|
||||||
self.head = nn.Linear(embed_dim, vocab_size)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# x: (batch, seq_len)
|
|
||||||
b, t = x.size()
|
|
||||||
positions = torch.arange(t, device=x.device).unsqueeze(0)
|
|
||||||
x = self.token_emb(x) + self.pos_emb(positions)
|
|
||||||
for layer in self.layers:
|
|
||||||
x = layer(x)
|
|
||||||
x = self.ln(x)
|
|
||||||
return self.head(x)
|
|
116
ruby_heart.py
116
ruby_heart.py
@ -1,116 +0,0 @@
|
|||||||
import glob
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
|
|
||||||
from models.transformer import TransformerGenerator
|
|
||||||
from models.discriminator import Discriminator
|
|
||||||
from utils.tokenizer import HybridTokenizer
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class RubyHeart:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
books_dir="data/books",
|
|
||||||
vocab_file="vocab/vocab.json",
|
|
||||||
model_file="models/best_gen.pt",
|
|
||||||
device=None,
|
|
||||||
):
|
|
||||||
self.device = device or torch.device(
|
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
)
|
|
||||||
# tokenizer & vocab
|
|
||||||
self.tokenizer = HybridTokenizer(vocab_file)
|
|
||||||
if not self.tokenizer.char_to_id:
|
|
||||||
self._build_vocab(books_dir)
|
|
||||||
|
|
||||||
# model init
|
|
||||||
vs = (
|
|
||||||
len(self.tokenizer.word_to_id)
|
|
||||||
+ len(self.tokenizer.char_to_id)
|
|
||||||
)
|
|
||||||
self.model = TransformerGenerator(
|
|
||||||
vocab_size=vs,
|
|
||||||
embed_dim=256,
|
|
||||||
num_heads=8,
|
|
||||||
mlp_dim=512,
|
|
||||||
num_layers=4,
|
|
||||||
max_seq_len=128,
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
self.model_file = model_file
|
|
||||||
self._load_checkpoint(model_file)
|
|
||||||
|
|
||||||
# online trainer
|
|
||||||
self.trainer = self._make_trainer()
|
|
||||||
|
|
||||||
def _build_vocab(self, books_dir):
|
|
||||||
paths = glob.glob(os.path.join(books_dir, "*.txt"))
|
|
||||||
texts = [open(p, encoding="utf-8").read() for p in paths]
|
|
||||||
self.tokenizer.build_vocab(texts)
|
|
||||||
|
|
||||||
def _load_checkpoint(self, path):
|
|
||||||
if os.path.isfile(path):
|
|
||||||
state = torch.load(path, map_location=self.device,
|
|
||||||
weights_only=True)
|
|
||||||
self.model.load_state_dict(state)
|
|
||||||
# else: start from scratch
|
|
||||||
|
|
||||||
def _make_trainer(self, lr=1e-5):
|
|
||||||
opt = optim.Adam(self.model.parameters(), lr=lr)
|
|
||||||
loss_fn = nn.CrossEntropyLoss()
|
|
||||||
return {"opt": opt, "loss": loss_fn}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _top_k_top_p(logits, top_k=50, top_p=0.9):
|
|
||||||
# (same filtering code as before)
|
|
||||||
if top_k > 0:
|
|
||||||
kth = torch.topk(logits, top_k)[0][..., -1, None]
|
|
||||||
logits = torch.where(
|
|
||||||
logits < kth, float("-inf"), logits
|
|
||||||
)
|
|
||||||
if top_p > 0.0:
|
|
||||||
sorted_logits, indices = torch.sort(
|
|
||||||
logits, descending=True
|
|
||||||
)
|
|
||||||
cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
|
|
||||||
mask = cum_probs > top_p
|
|
||||||
mask[..., 1:] = mask[..., :-1].clone()
|
|
||||||
mask[..., 0] = False
|
|
||||||
remove = indices[mask]
|
|
||||||
logits[remove] = float("-inf")
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def generate(self, prompt, max_len=64, temp=1.0, top_k=50, top_p=0.9):
|
|
||||||
self.model.eval()
|
|
||||||
ids = self.tokenizer.encode(prompt)
|
|
||||||
input_ids = torch.tensor([ids], device=self.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(max_len):
|
|
||||||
logits = self.model(input_ids)[:, -1, :] / temp
|
|
||||||
filt = self._top_k_top_p(logits, top_k, top_p)
|
|
||||||
probs = F.softmax(filt, dim=-1)
|
|
||||||
nxt = torch.multinomial(probs, 1)
|
|
||||||
input_ids = torch.cat([input_ids, nxt], dim=-1)
|
|
||||||
return self.tokenizer.decode(input_ids[0].cpu().tolist())
|
|
||||||
|
|
||||||
def train_on(self, text):
|
|
||||||
ids = self.tokenizer.encode(text)
|
|
||||||
if len(ids) < 2:
|
|
||||||
return
|
|
||||||
inp = torch.tensor([ids[:-1]], device=self.device)
|
|
||||||
tgt = torch.tensor([ids[1:]], device=self.device)
|
|
||||||
self.model.train()
|
|
||||||
self.trainer["opt"].zero_grad()
|
|
||||||
logits = self.model(inp)
|
|
||||||
loss = self.trainer["loss"](
|
|
||||||
logits.view(-1, logits.size(-1)),
|
|
||||||
tgt.view(-1),
|
|
||||||
)
|
|
||||||
loss.backward()
|
|
||||||
self.trainer["opt"].step()
|
|
||||||
torch.save(self.model.state_dict(), self.model_file)
|
|
||||||
self.model.eval()
|
|
@ -1,93 +0,0 @@
|
|||||||
import glob
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
|
|
||||||
from evolution.ga import evolve
|
|
||||||
from models.transformer import TransformerGenerator
|
|
||||||
from models.discriminator import Discriminator
|
|
||||||
from utils.tokenizer import HybridTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def chunked(lst, size):
|
|
||||||
"""Yield successive chunks from a list."""
|
|
||||||
for i in range(0, len(lst), size):
|
|
||||||
yield lst[i:i + size]
|
|
||||||
|
|
||||||
|
|
||||||
def train():
|
|
||||||
vocab_file = os.path.join('vocab', 'vocab.json')
|
|
||||||
tokenizer = HybridTokenizer(vocab_file)
|
|
||||||
book_paths = glob.glob(os.path.join('data', 'books', '*.txt'))
|
|
||||||
texts = []
|
|
||||||
for path in book_paths:
|
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
|
||||||
texts.append(f.read())
|
|
||||||
|
|
||||||
if not tokenizer.char_to_id:
|
|
||||||
tokenizer.build_vocab(texts)
|
|
||||||
|
|
||||||
seq_len = 128
|
|
||||||
sequences = []
|
|
||||||
for text in texts:
|
|
||||||
token_ids = tokenizer.encode(text)
|
|
||||||
for i in range(0, len(token_ids) - seq_len, seq_len):
|
|
||||||
sequences.append(
|
|
||||||
torch.tensor(token_ids[i:i + seq_len], dtype=torch.long)
|
|
||||||
)
|
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
pop_size, generations = 10, 50
|
|
||||||
vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
|
|
||||||
embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
|
|
||||||
|
|
||||||
population = [
|
|
||||||
TransformerGenerator(
|
|
||||||
vocab_size, embed_dim, num_heads, mlp_dim, num_layers, seq_len
|
|
||||||
).to(device)
|
|
||||||
for _ in range(pop_size)
|
|
||||||
]
|
|
||||||
discriminator = Discriminator(vocab_size, embed_dim).to(device)
|
|
||||||
disc_opt = optim.Adam(discriminator.parameters(), lr=1e-4)
|
|
||||||
bce = nn.BCEWithLogitsLoss()
|
|
||||||
|
|
||||||
for gen_idx in range(generations):
|
|
||||||
# Evaluate fitness
|
|
||||||
fitnesses = []
|
|
||||||
for g in population:
|
|
||||||
inp = torch.randint(0, vocab_size, (1, seq_len), device=device)
|
|
||||||
out = g(inp).argmax(-1)
|
|
||||||
score = discriminator(out)
|
|
||||||
fitnesses.append(-bce(score, torch.ones_like(score)).item())
|
|
||||||
|
|
||||||
# Train discriminator
|
|
||||||
for batch in chunked(sequences, 8):
|
|
||||||
real = torch.stack(batch).to(device)
|
|
||||||
fake_in = torch.randint(0, vocab_size, real.shape, device=device)
|
|
||||||
fake = population[0](fake_in).argmax(-1).detach()
|
|
||||||
|
|
||||||
disc_opt.zero_grad()
|
|
||||||
loss_r = bce(
|
|
||||||
discriminator(real),
|
|
||||||
torch.ones(real.size(0), 1, device=device)
|
|
||||||
)
|
|
||||||
loss_f = bce(
|
|
||||||
discriminator(fake),
|
|
||||||
torch.zeros(fake.size(0), 1, device=device)
|
|
||||||
)
|
|
||||||
(loss_r + loss_f).div_(2).backward()
|
|
||||||
disc_opt.step()
|
|
||||||
|
|
||||||
# Evolve population
|
|
||||||
population = evolve(population, fitnesses)
|
|
||||||
print(f'Gen {gen_idx:03d}: best fitness = {max(fitnesses):.4f}')
|
|
||||||
|
|
||||||
os.makedirs('models', exist_ok=True)
|
|
||||||
best = population[fitnesses.index(max(fitnesses))]
|
|
||||||
torch.save(best.state_dict(), 'models/best_gen.pt')
|
|
||||||
|
|
||||||
|
|
||||||
# kick off training immediately (no __main__ guard)
|
|
||||||
train()
|
|
@ -1,106 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import unicodedata
|
|
||||||
|
|
||||||
|
|
||||||
class HybridTokenizer:
|
|
||||||
"""Hybrid word/character tokenizer with vocab persistence."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_file,
|
|
||||||
min_word_freq=5,
|
|
||||||
max_vocab_size=10000
|
|
||||||
):
|
|
||||||
self.vocab_file = vocab_file
|
|
||||||
if os.path.exists(vocab_file):
|
|
||||||
with open(vocab_file, 'r', encoding='utf-8') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
self.word_to_id = data.get('word_to_id', {})
|
|
||||||
self.char_to_id = data.get('char_to_id', {})
|
|
||||||
else:
|
|
||||||
self.word_to_id = {'<unk>': 0}
|
|
||||||
self.char_to_id = {}
|
|
||||||
self.min_word_freq = min_word_freq
|
|
||||||
self.max_vocab_size = max_vocab_size
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _clean_text(text):
|
|
||||||
text = unicodedata.normalize('NFKC', text)
|
|
||||||
text = re.sub(r'[\r\n\t]+', ' ', text)
|
|
||||||
text = ''.join(ch for ch in text if ch.isprintable())
|
|
||||||
return text
|
|
||||||
|
|
||||||
def build_vocab(self, texts):
|
|
||||||
"""Build word and character vocabs from a list of texts."""
|
|
||||||
word_freq = {}
|
|
||||||
char_set = set()
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
text = self._clean_text(text)
|
|
||||||
for word in text.split():
|
|
||||||
# Preserve Title-case words, lowercase everything else
|
|
||||||
if word[0].isupper() and word[1:].islower():
|
|
||||||
norm = word
|
|
||||||
else:
|
|
||||||
norm = word.lower()
|
|
||||||
word_freq[norm] = word_freq.get(norm, 0) + 1
|
|
||||||
char_set.update(norm)
|
|
||||||
|
|
||||||
# Pick top words by freq
|
|
||||||
words = [
|
|
||||||
w for w, f in sorted(
|
|
||||||
word_freq.items(),
|
|
||||||
key=lambda x: x[1],
|
|
||||||
reverse=True
|
|
||||||
) if f >= self.min_word_freq
|
|
||||||
]
|
|
||||||
avail = self.max_vocab_size - len(self.word_to_id)
|
|
||||||
for w in words[:avail]:
|
|
||||||
if w not in self.word_to_id:
|
|
||||||
self.word_to_id[w] = len(self.word_to_id)
|
|
||||||
|
|
||||||
# Now assign chars after all words
|
|
||||||
idx = len(self.word_to_id)
|
|
||||||
for ch in sorted(char_set):
|
|
||||||
if ch not in self.char_to_id:
|
|
||||||
self.char_to_id[ch] = idx
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(self.vocab_file), exist_ok=True)
|
|
||||||
with open(self.vocab_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump({
|
|
||||||
'word_to_id': self.word_to_id,
|
|
||||||
'char_to_id': self.char_to_id
|
|
||||||
}, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
"""Convert text into a list of token IDs."""
|
|
||||||
text = self._clean_text(text)
|
|
||||||
ids = []
|
|
||||||
for word in text.split():
|
|
||||||
if word[0].isupper() and word[1:].islower():
|
|
||||||
norm = word
|
|
||||||
else:
|
|
||||||
norm = word.lower()
|
|
||||||
if norm in self.word_to_id:
|
|
||||||
ids.append(self.word_to_id[norm])
|
|
||||||
else:
|
|
||||||
for ch in norm:
|
|
||||||
ids.append(
|
|
||||||
self.char_to_id.get(ch, self.word_to_id['<unk>'])
|
|
||||||
)
|
|
||||||
return ids
|
|
||||||
|
|
||||||
def decode(self, ids):
|
|
||||||
"""Convert a list of token IDs back into text."""
|
|
||||||
inv_word = {v: k for k, v in self.word_to_id.items()}
|
|
||||||
inv_char = {v: k for k, v in self.char_to_id.items()}
|
|
||||||
tokens = []
|
|
||||||
for i in ids:
|
|
||||||
if i in inv_word:
|
|
||||||
tokens.append(inv_word[i])
|
|
||||||
else:
|
|
||||||
tokens.append(inv_char.get(i, '<unk>'))
|
|
||||||
return ' '.join(tokens)
|
|
Loading…
x
Reference in New Issue
Block a user