Added all the main code to Ruby

parent eadbc4a91e
commit 23800bf323
.gitignore (vendored) · 5 changed lines
@@ -168,5 +168,6 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
-books/*
-*.json
+data/*
+*.json
+models/best_gen.pt
evolution/ga.py · 34 lines (new file)
@@ -0,0 +1,34 @@
import random
import copy

import torch


def mutate(model, mutation_rate=0.01):
    """Return a copy of `model` with Gaussian noise added to some parameters."""
    new_model = copy.deepcopy(model)
    for param in new_model.parameters():
        # mutation is decided per parameter tensor, not per individual weight
        if random.random() < mutation_rate:
            noise = torch.randn_like(param) * 0.1
            param.data += noise
    return new_model


def crossover(parent1, parent2):
    """Uniform crossover: each weight comes from either parent with equal probability."""
    child = copy.deepcopy(parent1)
    for p_child, p2 in zip(child.parameters(), parent2.parameters()):
        mask = torch.rand_like(p_child) < 0.5
        p_child.data[mask] = p2.data[mask]
    return child


def evolve(population, fitnesses, retain_ratio=0.2, mutation_rate=0.1):
    # rank by fitness (higher is better)
    paired = sorted(zip(fitnesses, population), key=lambda x: x[0], reverse=True)
    # keep at least two parents so random.sample(parents, 2) below cannot fail
    retain_len = max(2, int(len(paired) * retain_ratio))
    parents = [ind for _, ind in paired[:retain_len]]

    next_gen = parents.copy()
    while len(next_gen) < len(population):
        p1, p2 = random.sample(parents, 2)
        child = crossover(p1, p2)
        child = mutate(child, mutation_rate)
        next_gen.append(child)
    return next_gen
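For review purposes, a quick smoke test of evolve() with stand-in models — the tiny nn.Linear modules and made-up fitnesses are illustrative only, not part of the commit:

# Hypothetical smoke test; tiny Linear models stand in for the real generators.
import torch.nn as nn
from evolution.ga import evolve

population = [nn.Linear(4, 4) for _ in range(6)]
fitnesses = [0.1, 0.9, 0.3, 0.5, 0.2, 0.7]   # higher is better

next_gen = evolve(population, fitnesses, retain_ratio=0.3, mutation_rate=0.1)
assert len(next_gen) == len(population)       # population size is preserved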
main.py · 132 lines (new file)
@@ -0,0 +1,132 @@
import os
import glob

import torch
import torch.nn as nn
import torch.optim as optim
import discord
from dotenv import load_dotenv

from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer

# ──────── Setup ────────

load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN:
    raise RuntimeError("Missing DISCORD_TOKEN in .env")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"[INFO] Using device: {device}")

# ──────── Tokenizer & Vocab ────────

vocab_file = os.path.join("vocab", "vocab.json")
tokenizer = HybridTokenizer(vocab_file)

# If vocab.json doesn't exist yet, build it from your books:
if not tokenizer.char_to_id:
    book_paths = glob.glob(os.path.join("data", "books", "*.txt"))
    texts = []
    for path in book_paths:
        with open(path, "r", encoding="utf-8") as f:
            texts.append(f.read())
    tokenizer.build_vocab(texts)
    print(f"[INFO] Built vocab ({len(tokenizer.word_to_id)} words + "
          f"{len(tokenizer.char_to_id)} chars)")

# ──────── Model Setup ────────

vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4
max_seq_len = 128

model = TransformerGenerator(
    vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_seq_len
).to(device)

ckpt = os.path.join("models", "best_gen.pt")
if os.path.isfile(ckpt):
    state = torch.load(ckpt, map_location=device)
    model.load_state_dict(state)
    print("[INFO] Loaded checkpoint models/best_gen.pt")
else:
    print("[INFO] No checkpoint found; starting from random weights")

model.eval()

# ──────── Online Trainer ────────

class OnlineTrainer:
    """Fine-tune the generator on each new exchange."""

    def __init__(self, model, lr=1e-5):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def train_example(self, text: str):
        # simple causal training: predict each next token in `text`;
        # clamp to max_seq_len so the positional embedding is never exceeded
        token_ids = tokenizer.encode(text)[:max_seq_len]
        if len(token_ids) < 2:
            return
        inp = torch.tensor([token_ids[:-1]], device=self.device)
        tgt = torch.tensor([token_ids[1:]], device=self.device)

        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(inp)  # (1, seq_len-1, vocab_size)
        loss = self.criterion(
            logits.view(-1, logits.size(-1)),
            tgt.view(-1)
        )
        loss.backward()
        self.optimizer.step()
        self.model.eval()

        # persist updated weights
        os.makedirs("models", exist_ok=True)
        torch.save(self.model.state_dict(), ckpt)

trainer = OnlineTrainer(model)

# ──────── Discord Client ────────

intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)


@client.event
async def on_ready():
    print(f"Ruby is online as {client.user}")


@client.event
async def on_message(message):
    # ignore Ruby's own messages
    if message.author == client.user:
        return

    content = message.content.strip()
    if not content:
        return

    # → Generate Ruby's reply: a single greedy forward pass
    #   (one token predicted per input position, not autoregressive decoding)
    ids = tokenizer.encode(content)[:max_seq_len]
    inp = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        out_ids = model(inp).argmax(-1).squeeze(0).cpu().tolist()
    reply = tokenizer.decode(out_ids)

    await message.channel.send(reply)

    # → Optionally train on this new example
    sample = f"User: {content}\nRuby: {reply}"
    trainer.train_example(sample)

# ──────── Run ────────

client.run(TOKEN)
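The shifted input/target pairing in train_example is standard next-token training; a tiny worked example with made-up token IDs:

# Suppose tokenizer.encode("hi there ruby") returned these IDs (hypothetical):
token_ids = [12, 7, 42, 3]

inp = token_ids[:-1]   # [12, 7, 42]  — what the model sees
tgt = token_ids[1:]    # [ 7, 42, 3]  — what it must predict at each position
# position 0 predicts 7 from 12, position 1 predicts 42 from [12, 7], etc.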
models/discriminator.py · 40 lines (new file)
@@ -0,0 +1,40 @@
# NOTE: despite the filename, this module is a second Discord entry point built
# on RubyHeart; it does not define the `Discriminator` class that
# training/train.py imports from here.
import os

import discord
from dotenv import load_dotenv

from ruby_heart import RubyHeart

load_dotenv()
TOKEN = os.getenv("DISCORD_TOKEN")
if not TOKEN:
    raise RuntimeError("DISCORD_TOKEN missing in .env")

# instantiate your “Ruby” engine
ruby = RubyHeart()  # uses GPU if available

intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)


@client.event
async def on_ready():
    print(f"Ruby is online as {client.user}")


@client.event
async def on_message(message):
    if message.author == client.user:
        return
    content = message.content.strip()
    if not content:
        return

    # generate + train in one call
    reply = ruby.generate(content)
    await message.channel.send(reply)
    ruby.train_on(f"User: {content}\nRuby: {reply}")


client.run(TOKEN)
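training/train.py imports a `Discriminator` from this module that the file never defines. A minimal sketch of the interface its call sites assume — token-ID input of shape (batch, seq_len), one real/fake logit per sequence for nn.BCEWithLogitsLoss. The mean-pooling and layer sizes here are assumptions, not part of the commit:

import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Scores token sequences as real (books) vs. generated;
    returns raw logits of shape (batch, 1)."""

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) of ints
        x = self.embed(token_ids)   # (batch, seq_len, embed_dim)
        x = x.mean(dim=1)           # mean-pool over the sequence
        return self.fc(x)           # (batch, 1) logits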
models/transformer.py · 80 lines (new file)
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        b, t, e = x.size()
        qkv = self.qkv_proj(x)  # (b, t, 3*e)
        q, k, v = qkv.chunk(3, dim=-1)
        # reshape for multi-head
        q = q.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
        # causal mask so a position cannot attend to later tokens; without it
        # the next-token training objective leaks the answer
        causal = torch.triu(
            torch.ones(t, t, device=x.device, dtype=torch.bool), diagonal=1
        )
        attn = attn.masked_fill(causal, float("-inf"))
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous()
        out = out.view(b, t, e)
        return self.out_proj(out)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # pre-norm residual blocks
        x = x + self.dropout(self.attn(self.ln1(x)))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x


class TransformerGenerator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        num_layers: int,
        max_seq_len: int,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_dim, num_heads, mlp_dim)
                for _ in range(num_layers)
            ]
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len); seq_len must not exceed max_seq_len
        b, t = x.size()
        positions = torch.arange(t, device=x.device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(positions)
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        return self.head(x)
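A minimal shape check for the generator (toy hyperparameters, illustrative only):

import torch
from models.transformer import TransformerGenerator

model = TransformerGenerator(
    vocab_size=100, embed_dim=32, num_heads=4,
    mlp_dim=64, num_layers=2, max_seq_len=16,
)
x = torch.randint(0, 100, (2, 10))   # (batch=2, seq_len=10)
logits = model(x)
assert logits.shape == (2, 10, 100)  # (batch, seq_len, vocab_size)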
ruby_heart.py · 116 lines (new file)
@@ -0,0 +1,116 @@
import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from models.transformer import TransformerGenerator
from utils.tokenizer import HybridTokenizer


class RubyHeart:
    def __init__(
        self,
        books_dir="data/books",
        vocab_file="vocab/vocab.json",
        model_file="models/best_gen.pt",
        device=None,
    ):
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # tokenizer & vocab
        self.tokenizer = HybridTokenizer(vocab_file)
        if not self.tokenizer.char_to_id:
            self._build_vocab(books_dir)

        # model init
        vs = (
            len(self.tokenizer.word_to_id)
            + len(self.tokenizer.char_to_id)
        )
        self.max_seq_len = 128
        self.model = TransformerGenerator(
            vocab_size=vs,
            embed_dim=256,
            num_heads=8,
            mlp_dim=512,
            num_layers=4,
            max_seq_len=self.max_seq_len,
        ).to(self.device)

        self.model_file = model_file
        self._load_checkpoint(model_file)

        # online trainer
        self.trainer = self._make_trainer()

    def _build_vocab(self, books_dir):
        paths = glob.glob(os.path.join(books_dir, "*.txt"))
        texts = [open(p, encoding="utf-8").read() for p in paths]
        self.tokenizer.build_vocab(texts)

    def _load_checkpoint(self, path):
        if os.path.isfile(path):
            state = torch.load(path, map_location=self.device,
                               weights_only=True)
            self.model.load_state_dict(state)
        # else: start from scratch

    def _make_trainer(self, lr=1e-5):
        opt = optim.Adam(self.model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        return {"opt": opt, "loss": loss_fn}

    @staticmethod
    def _top_k_top_p(logits, top_k=50, top_p=0.9):
        # logits: (1, vocab_size); standard top-k / nucleus filtering
        if top_k > 0:
            kth = torch.topk(logits, top_k)[0][..., -1, None]
            logits = torch.where(
                logits < kth, torch.full_like(logits, float("-inf")), logits
            )
        if top_p > 0.0:
            sorted_logits, indices = torch.sort(
                logits, descending=True
            )
            cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            mask = cum_probs > top_p
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False
            # scatter the mask back to vocabulary order before applying it
            remove = mask.scatter(-1, indices, mask)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    def generate(self, prompt, max_len=64, temp=1.0, top_k=50, top_p=0.9):
        self.model.eval()
        ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([ids], device=self.device)
        with torch.no_grad():
            for _ in range(max_len):
                # keep only the last max_seq_len tokens as context so the
                # positional embedding is never exceeded
                context = input_ids[:, -self.max_seq_len:]
                logits = self.model(context)[:, -1, :] / temp
                filt = self._top_k_top_p(logits, top_k, top_p)
                probs = F.softmax(filt, dim=-1)
                nxt = torch.multinomial(probs, 1)
                input_ids = torch.cat([input_ids, nxt], dim=-1)
        return self.tokenizer.decode(input_ids[0].cpu().tolist())

    def train_on(self, text):
        ids = self.tokenizer.encode(text)[:self.max_seq_len]
        if len(ids) < 2:
            return
        inp = torch.tensor([ids[:-1]], device=self.device)
        tgt = torch.tensor([ids[1:]], device=self.device)
        self.model.train()
        self.trainer["opt"].zero_grad()
        logits = self.model(inp)
        loss = self.trainer["loss"](
            logits.view(-1, logits.size(-1)),
            tgt.view(-1),
        )
        loss.backward()
        self.trainer["opt"].step()
        torch.save(self.model.state_dict(), self.model_file)
        self.model.eval()
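An offline usage sketch for exercising RubyHeart without Discord; it assumes the default vocab and book paths from __init__ exist:

# Hypothetical local REPL loop, not part of the commit.
from ruby_heart import RubyHeart

ruby = RubyHeart()
while True:
    prompt = input("you> ")
    if not prompt.strip():
        break
    reply = ruby.generate(prompt, max_len=32)
    print("ruby>", reply)
    ruby.train_on(f"User: {prompt}\nRuby: {reply}")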
training/train.py · 93 lines (new file)
@@ -0,0 +1,93 @@
import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim

from evolution.ga import evolve
from models.transformer import TransformerGenerator
from models.discriminator import Discriminator
from utils.tokenizer import HybridTokenizer


def chunked(lst, size):
    """Yield successive chunks from a list."""
    for i in range(0, len(lst), size):
        yield lst[i:i + size]


def train():
    vocab_file = os.path.join('vocab', 'vocab.json')
    tokenizer = HybridTokenizer(vocab_file)
    book_paths = glob.glob(os.path.join('data', 'books', '*.txt'))
    texts = []
    for path in book_paths:
        with open(path, 'r', encoding='utf-8') as f:
            texts.append(f.read())

    if not tokenizer.char_to_id:
        tokenizer.build_vocab(texts)

    seq_len = 128
    sequences = []
    for text in texts:
        token_ids = tokenizer.encode(text)
        for i in range(0, len(token_ids) - seq_len, seq_len):
            sequences.append(
                torch.tensor(token_ids[i:i + seq_len], dtype=torch.long)
            )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pop_size, generations = 10, 50
    vocab_size = len(tokenizer.word_to_id) + len(tokenizer.char_to_id)
    embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 4

    population = [
        TransformerGenerator(
            vocab_size, embed_dim, num_heads, mlp_dim, num_layers, seq_len
        ).to(device)
        for _ in range(pop_size)
    ]
    discriminator = Discriminator(vocab_size, embed_dim).to(device)
    disc_opt = optim.Adam(discriminator.parameters(), lr=1e-4)
    bce = nn.BCEWithLogitsLoss()

    for gen_idx in range(generations):
        # Evaluate fitness: a generator scores higher the more its output
        # fools the discriminator (lower BCE against the "real" label)
        fitnesses = []
        with torch.no_grad():
            for g in population:
                inp = torch.randint(0, vocab_size, (1, seq_len), device=device)
                out = g(inp).argmax(-1)
                score = discriminator(out)
                fitnesses.append(-bce(score, torch.ones_like(score)).item())

        # Train discriminator (fake samples come from the first individual only)
        for batch in chunked(sequences, 8):
            real = torch.stack(batch).to(device)
            fake_in = torch.randint(0, vocab_size, real.shape, device=device)
            fake = population[0](fake_in).argmax(-1).detach()

            disc_opt.zero_grad()
            loss_r = bce(
                discriminator(real),
                torch.ones(real.size(0), 1, device=device)
            )
            loss_f = bce(
                discriminator(fake),
                torch.zeros(fake.size(0), 1, device=device)
            )
            (loss_r + loss_f).div_(2).backward()
            disc_opt.step()

        # Evolve population
        population = evolve(population, fitnesses)
        print(f'Gen {gen_idx:03d}: best fitness = {max(fitnesses):.4f}')

    os.makedirs('models', exist_ok=True)
    # evolve() seeds the next population with parents sorted by fitness,
    # so index 0 holds the best surviving individual
    best = population[0]
    torch.save(best.state_dict(), 'models/best_gen.pt')


if __name__ == '__main__':
    train()
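The fitness above is the negated BCE of the discriminator's logit against the "real" label, so values nearer 0 mean the generator fools the discriminator more. A numeric illustration with made-up discriminator logits:

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()
real_label = torch.ones(1, 1)

convincing = torch.tensor([[3.0]])     # discriminator logit: "looks real"
obvious_fake = torch.tensor([[-3.0]])  # discriminator logit: "looks fake"

print(-bce(convincing, real_label).item())    # ≈ -0.049 (high fitness)
print(-bce(obvious_fake, real_label).item())  # ≈ -3.049 (low fitness)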
utils/tokenizer.py · 106 lines (new file)
@@ -0,0 +1,106 @@
import json
import os
import re
import unicodedata


class HybridTokenizer:
    """Hybrid word/character tokenizer with vocab persistence."""

    def __init__(
        self,
        vocab_file,
        min_word_freq=5,
        max_vocab_size=10000
    ):
        self.vocab_file = vocab_file
        if os.path.exists(vocab_file):
            with open(vocab_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            self.word_to_id = data.get('word_to_id', {})
            self.char_to_id = data.get('char_to_id', {})
        else:
            self.word_to_id = {'<unk>': 0}
            self.char_to_id = {}
        self.min_word_freq = min_word_freq
        self.max_vocab_size = max_vocab_size

    @staticmethod
    def _clean_text(text):
        text = unicodedata.normalize('NFKC', text)
        text = re.sub(r'[\r\n\t]+', ' ', text)
        text = ''.join(ch for ch in text if ch.isprintable())
        return text

    def build_vocab(self, texts):
        """Build word and character vocabs from a list of texts."""
        word_freq = {}
        char_set = set()

        for text in texts:
            text = self._clean_text(text)
            for word in text.split():
                # Preserve Title-case words, lowercase everything else
                if word[0].isupper() and word[1:].islower():
                    norm = word
                else:
                    norm = word.lower()
                word_freq[norm] = word_freq.get(norm, 0) + 1
                char_set.update(norm)

        # Pick top words by freq
        words = [
            w for w, f in sorted(
                word_freq.items(),
                key=lambda x: x[1],
                reverse=True
            ) if f >= self.min_word_freq
        ]
        avail = self.max_vocab_size - len(self.word_to_id)
        for w in words[:avail]:
            if w not in self.word_to_id:
                self.word_to_id[w] = len(self.word_to_id)

        # Now assign chars after all words
        idx = len(self.word_to_id)
        for ch in sorted(char_set):
            if ch not in self.char_to_id:
                self.char_to_id[ch] = idx
                idx += 1

        os.makedirs(os.path.dirname(self.vocab_file), exist_ok=True)
        with open(self.vocab_file, 'w', encoding='utf-8') as f:
            json.dump({
                'word_to_id': self.word_to_id,
                'char_to_id': self.char_to_id
            }, f, ensure_ascii=False, indent=2)

    def encode(self, text):
        """Convert text into a list of token IDs."""
        text = self._clean_text(text)
        ids = []
        for word in text.split():
            if word[0].isupper() and word[1:].islower():
                norm = word
            else:
                norm = word.lower()
            if norm in self.word_to_id:
                ids.append(self.word_to_id[norm])
            else:
                for ch in norm:
                    ids.append(
                        self.char_to_id.get(ch, self.word_to_id['<unk>'])
                    )
        return ids

    def decode(self, ids):
        """Convert a list of token IDs back into text."""
        inv_word = {v: k for k, v in self.word_to_id.items()}
        inv_char = {v: k for k, v in self.char_to_id.items()}
        tokens = []
        for i in ids:
            if i in inv_word:
                tokens.append(inv_word[i])
            else:
                tokens.append(inv_char.get(i, '<unk>'))
        return ' '.join(tokens)
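A round-trip sketch of the tokenizer (the sample text is hypothetical); note that decode joins every token with spaces, so character-level fallbacks come back space-separated:

from utils.tokenizer import HybridTokenizer

tok = HybridTokenizer("vocab/vocab.json")
if not tok.char_to_id:                 # first run: build from sample text
    tok.build_vocab(["the cat sat on the mat " * 5])

ids = tok.encode("the cat sat")
print(ids)             # word IDs; OOV words would fall back to char IDs
print(tok.decode(ids)) # "the cat sat"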