Doing some updates to the reader part of Ruby.

parent 287eee0d34
commit 7823ce1d5e
.gitignore

@@ -172,4 +172,5 @@ cython_debug/
 /logs/core_dreams.txt
 /logs/best_dream.txt
 /.vscode/launch.json
 /books
+/readstate.txt
dashboard.py

@@ -1,8 +1,8 @@
 from flask import Flask, render_template_string
-from datetime import datetime
 import os
 
 app = Flask(__name__)
+ruby_client = None  # This will be set externally
 
 
 def tail(filepath, num_lines=10):
@@ -31,6 +31,17 @@ def home():
     errors = [line.strip() for line in tail("logs/error.log", 15)]
     best_dream = get_best_dream()
 
+    # Handle book progress if Ruby has a reader
+    book = {
+        "book": "Not reading",
+        "line": 0,
+        "total": 0,
+        "percent": 0.0,
+        "last_sentence": ""
+    }
+    if ruby_client and hasattr(ruby_client, "reader"):
+        book = ruby_client.reader.progress()
+
     return render_template_string("""
     <!DOCTYPE html>
     <html>
@@ -47,6 +58,11 @@ def home():
     <body>
         <h1>🌸 Ruby's Dashboard</h1>
         <p><b>Vocabulary Size:</b> {{ vocab_size }}</p>
 
+        <h3>📖 Book Progress</h3>
+        <p><b>{{ book.book }}</b> – Line {{ book.line }} of {{ book.total }} ({{ book.percent }}%)</p>
+        <p><i>{{ book.last_sentence }}</i></p>
+
         <h3>🏆 Highest Scoring Dream</h3>
         <p><b>{{ best_dream }}</b></p>
@@ -73,8 +89,11 @@ def home():
 
     </body>
     </html>
-    """, best_dream=best_dream, dreams=dreams[::-1], messages=messages[::-1], errors=errors[::-1], vocab_size=vocab_size)
+    """, best_dream=best_dream, dreams=dreams[::-1], messages=messages[::-1], errors=errors[::-1], vocab_size=vocab_size, book=book)
 
 
-def start_dashboard():
-    app.run(debug=False, host="0.0.0.0", port=5000)
+def start_dashboard_background():
+    import threading
+    thread = threading.Thread(target=lambda: app.run(debug=False, host="0.0.0.0", port=5000))
+    thread.daemon = True
+    thread.start()
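Note: start_dashboard_background() now launches Flask's development server on a daemon thread, so the call returns immediately and the server exits with the main process instead of blocking it. A minimal wiring sketch, mirroring what main.py does below; the "/" route for home() is assumed here, since the decorator sits outside these hunks:

    import dashboard

    class FakeClient:                           # hypothetical stand-in for the Ruby Discord client
        pass

    dashboard.ruby_client = FakeClient()        # no .reader yet, so the page falls back to "Not reading"
    dashboard.start_dashboard_background()      # non-blocking; Flask serves on 0.0.0.0:5000
    # The page is then reachable at http://localhost:5000/ while the main program keeps running.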
main.py

@@ -2,10 +2,9 @@ import discord
 import asyncio
 import atexit
 import os
-import threading
 from dotenv import load_dotenv
 from datetime import datetime, timedelta
-from dashboard import start_dashboard
+import dashboard
 from tokenizer import Tokenizer
 from trainer import RubyTrainer
 from reader import BookReader
@@ -66,6 +65,8 @@ class Ruby(discord.Client):
         print(f"[READY] Logged in as {self.user} (ID: {self.user.id})")
         await self.set_activity("you...")
         self.trainer.reinforce_core_memory()
+        # self.trainer.clean_vocab()
+        # self.trainer.rebuild_model_if_needed()
 
     async def idle_dream_loop(self):
         await self.wait_until_ready()
@@ -85,7 +86,7 @@ class Ruby(discord.Client):
             speak = random() < 0.5
             thought = self.trainer.daydream(say_thought=speak)
 
-            if speak and thought and len(thought.split()) >=4:
+            if speak and thought and len(thought.split()) >= 4:
                 for guild in self.guilds:
                     for channel in guild.text_channels:
                         if channel.permissions_for(guild.me).send_messages:
@@ -125,15 +126,14 @@ class Ruby(discord.Client):
     def train_on_message(self, message: discord.Message):
         text = message.content.strip()
         self.trainer.train_on_tokens_from_text(text)
-        token_tensor = torch.tensor(tokens, dtype=torch.long)
-        loss = train_on_tokens(self.model, tokens, self.optimizer, self.criterion, device="cpu")
-        print(f"[TRAIN] Tokens: {tokens} | Loss: {loss:.4f}")
 
 
 # Run Ruby
 client = None
 try:
     client = Ruby()
+    dashboard.ruby_client = client
+    dashboard.start_dashboard_background()
 
     def on_exit():
         if client:
@@ -142,8 +142,7 @@ try:
             client.trainer.daydream(rounds=10)
 
     atexit.register(on_exit)
-    dashboard_thread = threading.Thread(target=start_dashboard, daemon=True)
-    dashboard_thread.start()
+    dashboard.start_dashboard_background()
     client.run(TOKEN)
 finally:
     if client is not None:
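The hunks above import BookReader and hand the client to the dashboard, but the construction of the reader and the scheduling of start_reading() fall outside the shown context. A sketch of one plausible wiring, assuming the reader hangs off the client as .reader (the attribute name dashboard.py probes with hasattr) and is launched from on_ready(); the helper name, book path and scheduling are assumptions, not part of this diff:

    import asyncio
    from reader import BookReader

    async def attach_reader(bot, trainer):
        # Hypothetical helper: attach the reader under the attribute dashboard.py expects.
        bot.reader = BookReader(
            trainer=trainer,
            book_path="books/example.txt",  # assumed path; the /books folder is git-ignored
            interval=15,
        )
        # Run the read loop alongside the Discord event loop (e.g. called from on_ready()).
        asyncio.create_task(bot.reader.start_reading())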
reader.py

@@ -2,14 +2,17 @@ import os
 import asyncio
 from datetime import datetime
 
 
 class BookReader:
-    def __init__(self, trainer, book_path, state_path="readstate.txt", log_path="logs/read.log", interval=180):
+    def __init__(self, trainer, book_path, state_path="readstate.txt", log_path="logs/read.log", interval=15):
         self.trainer = trainer
         self.book_path = book_path
         self.state_path = state_path
         self.log_path = log_path
-        self.interval = interval  # seconds between reading cycles
+        self.interval = interval
         self.current_line = 0
+        self.last_sentence = ""
+        self.total_lines = 0
         os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
 
         if os.path.exists(self.state_path):
@@ -19,35 +22,53 @@ class BookReader:
             except Exception:
                 self.current_line = 0
 
+        if os.path.exists(self.book_path):
+            with open(self.book_path, "r", encoding="utf-8", errors="ignore") as f:
+                self.total_lines = len(f.readlines())
+
     def _save_state(self):
         with open(self.state_path, "w", encoding="utf-8") as f:
             f.write(str(self.current_line))
 
     def _log_read(self, text: str, score: float, tag: str = "Book"):
         with open(self.log_path, "a", encoding="utf-8") as f:
-            f.write(f"[{datetime.utcnow().isoformat()}] ({tag}) {score:.2f} | {text.strip()}\\n")
+            f.write(f"[{datetime.utcnow().isoformat()}] ({tag}) {score:.2f} | {text.strip()}\n")
 
     async def start_reading(self):
         if not os.path.exists(self.book_path):
             print(f"[BOOK] File not found: {self.book_path}")
             return
 
-        with open(self.book_path, "r", encoding="utf-8") as f:
+        with open(self.book_path, "r", encoding="utf-8", errors="ignore") as f:
             lines = f.readlines()
+        self.total_lines = len(lines)
 
         print(f"[BOOK] Starting to read {self.book_path} from line {self.current_line}...")
 
-        while self.current_line < len(lines):
+        while self.current_line < self.total_lines:
             passage = lines[self.current_line].strip()
 
-            if len(passage.split()) >= 5:
+            if len(passage.split()) >= 5 and self._is_valid(passage):
                 score = self.trainer.score_sentence(passage)
                 if self.trainer.is_reinforceable(passage) and score >= 2.5:
                     self.trainer.train_on_tokens_from_text(passage)
                     self._log_read(passage, score)
+                    self.last_sentence = passage
 
             self.current_line += 1
             self._save_state()
             await asyncio.sleep(self.interval)
 
         print("[BOOK] Finished reading the book.")
+
+    def _is_valid(self, text: str) -> bool:
+        return all(c.isprintable() or c.isspace() for c in text)
+
+    def progress(self) -> dict:
+        return {
+            "book": os.path.basename(self.book_path),
+            "line": self.current_line,
+            "total": self.total_lines,
+            "percent": round(100 * self.current_line / self.total_lines, 2) if self.total_lines else 0.0,
+            "last_sentence": self.last_sentence
+        }
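progress() is the contract the dashboard consumes: file name, position, percentage and the last sentence that was reinforced. An illustrative call with assumed state (the path, counts and sentence are made up):

    from reader import BookReader

    # Illustrative only: assumed state, showing the shape progress() returns.
    reader = BookReader(trainer=None, book_path="books/example.txt")  # trainer is not used by progress()
    reader.current_line = 120
    reader.total_lines = 480
    reader.last_sentence = "an example of the last reinforced passage"
    print(reader.progress())
    # {'book': 'example.txt', 'line': 120, 'total': 480, 'percent': 25.0,
    #  'last_sentence': 'an example of the last reinforced passage'}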
tokenizer.py

@@ -1,4 +1,5 @@
 import os
+from trainer import normalize_for_vocab
 
 
 class Tokenizer:
@@ -13,21 +14,21 @@ class Tokenizer:
             return
         with open(self.vocab_path, "r", encoding="utf-8") as f:
             for line in f:
-                token, idx = line.strip().split("\t")
-                self.vocab[token] = int(idx)
-                if token not in self.vocab:
+                token = line.strip()
+                if token and token not in self.vocab:
+                    idx = len(self.vocab)
                     self.vocab[token] = idx
                     self.inv_vocab[idx] = token
-        self.inv_vocab = {v: k for k, v in self.vocab.items()}
 
     def save_vocab(self):
         with open(self.vocab_path, "w", encoding="utf-8") as f:
             for token, idx in self.vocab.items():
-                f.write(f"{token}\t{idx}\n")
+                f.write(f"{token}\n")
 
     def tokenize(self, text):
         tokens = []
         for word in text.strip().split():
+            word = normalize_for_vocab(word)
            if word not in self.vocab:
                 self.vocab[word] = len(self.vocab)
                 self.inv_vocab[self.vocab[word]] = word
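The vocab file format changes here from tab-separated "token\tindex" pairs to one token per line, with the index implied by line order. A small round-trip sketch of the new format (the file name is an assumption, matching the tokenizer_vocab.txt that clean_vocab() in trainer.py writes):

    # One token per line; index = order of first appearance.
    # <START> and <END> are assumed to sit on the first two lines, as clean_vocab() writes them.
    with open("tokenizer_vocab.txt", "w", encoding="utf-8") as f:
        for token in ["<START>", "<END>", "hello", "world"]:
            f.write(f"{token}\n")

    vocab = {}
    with open("tokenizer_vocab.txt", "r", encoding="utf-8") as f:
        for line in f:
            token = line.strip()
            if token and token not in vocab:
                vocab[token] = len(vocab)

    print(vocab)  # {'<START>': 0, '<END>': 1, 'hello': 2, 'world': 3}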
trainer.py

@@ -3,10 +3,34 @@ import torch.nn.functional as F
 from datetime import datetime
 from collections import Counter
 import os
+import re
+import string
 from model import MiniGPT
 
 # flake8: noqa E501
 
+def normalize_for_vocab(text: str) -> str:
+    # Replace em-dashes and smart quotes with standard forms
+    text = text.replace("—", " ").replace("“", '"').replace("”", '"').replace("‘", "'").replace("’", "'")
+
+    # Remove parenthetical and bracket content
+    text = re.sub(r"\[(.*?)\]", "", text)
+    text = re.sub(r"\((.*?)\)", "", text)
+
+    # Remove trailing punctuation (commas, periods, question marks, etc.) per word
+    text = re.sub(r"(\w)[.,!?;:]+(?=\s|$)", r"\1", text)
+
+    # Remove quotes at start or end of lines
+    text = text.strip("\"'")
+
+    # Normalize hyphenated words by collapsing to a single word
+    text = re.sub(r"(\w)-(\w)", r"\1\2", text)
+
+    # Remove duplicate spaces and lowercase
+    text = re.sub(r"\s+", " ", text).strip().lower()
+
+    return text
+
 
 class RubyTrainer:
     def __init__(self, tokenizer, embed_dim=128, n_heads=4, n_layers=2, max_len=128):
@@ -34,7 +58,8 @@ class RubyTrainer:
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
 
     def train_on_tokens_from_text(self, text: str):
-        tokens = self.tokenizer.tokenize(text)
+        normalized = normalize_for_vocab(text)
+        tokens = self.tokenizer.tokenize(normalized)
         if not tokens:
             return
         tokens = [self.tokenizer.vocab["<START>"]] + tokens + [self.tokenizer.vocab["<END>"]]
@@ -54,44 +79,62 @@ class RubyTrainer:
 
         print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")
 
-    def generate_reply(self, prompt=None, max_length=20):
+    def generate_reply(self, prompt=None, max_length=20, temperature=1.3):
         self.model.eval()
         input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
 
         with torch.no_grad():
             for _ in range(max_length):
+                max_id = self.model.token_embed.num_embeddings
+                input_ids = torch.clamp(input_ids, 0, max_id - 1)
                 output = self.model(input_ids)
                 logits = output[:, -1, :]
 
-                # Apply repeat penalty BEFORE sampling
+                # Apply repeat penalty
                 if input_ids.size(1) >= 2:
                     last_token = input_ids[0, -1].item()
-                    logits[0, last_token] *= 0.1  # Penalize repeating same token again
+                    logits[0, last_token] *= 0.1
 
-                next_token = torch.argmax(logits, dim=-1)
+                # 🔥 Temperature sampling
+                probs = F.softmax(logits / temperature, dim=-1)
+                next_token = torch.multinomial(probs, 1)[0].view(1)
+
+                if next_token.item() >= self.model.token_embed.num_embeddings:
+                    print("[ERROR] Token index out of bounds. Rebuilding model...")
+                    self.rebuild_model_if_needed()
+                    return ""
+
                 input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
 
                 if next_token.item() == self.tokenizer.vocab["<END>"]:
                     break
 
         output = self.tokenizer.detokenize(input_ids.squeeze().tolist())
-        output = output.replace("<START>", "").replace("<END>", "").strip()
-        return output
+        return output.replace("<START>", "").replace("<END>", "").strip()
 
-    def self_rephrase(self, original: str, max_tokens=50):
+    def self_rephrase(self, original: str, max_tokens=50, temperature=1.3):
         self.model.eval()
         tokens = [self.tokenizer.vocab["<START>"]] + self.tokenizer.tokenize(original)
         input_ids = torch.tensor(tokens, dtype=torch.long, device=self.device).unsqueeze(0)
 
         for _ in range(max_tokens):
             with torch.no_grad():
+                input_ids = torch.clamp(input_ids, 0, self.model.token_embed.num_embeddings - 1)
                 out = self.model(input_ids)
                 logits = out[:, -1, :] / 1.1
                 if input_ids.size(1) < 8:
                     logits[0, self.tokenizer.vocab["<END>"]] = float("-inf")
 
-                probs = F.softmax(logits, dim=-1)
+                probs = F.softmax(logits / temperature, dim=-1)
                 next_token = torch.multinomial(probs, 1)[0].view(1, 1)
 
+                # ✅ Ensure next_token is valid
+                if next_token.item() >= self.model.token_embed.num_embeddings:
+                    print("[ERROR] Token index out of bounds in self_rephrase. Rebuilding model...")
+                    self.rebuild_model_if_needed()
+                    return ""
+
             input_ids = torch.cat([input_ids, next_token], dim=1)
 
             if next_token.item() == self.tokenizer.vocab["<END>"]:
@@ -220,9 +263,20 @@ class RubyTrainer:
         if text.lower().count("i am") > len(text.split()) * 0.25:
             return False
 
-        # Reject if the first word is repeated 3+ times
+        # Reject if first word repeats 3 times ("you you you")
         if words[:3].count(words[0]) == 3:
-            return False  # "you you you" type
+            return False
 
+        # 🧠 NEW: Reject if starts with common book phrases
+        banned_starts = ("once upon", "chapter", "the end", "in which", "it was", "quick cried", "they are course")
+        lowered = text.lower()
+        if any(lowered.startswith(phrase) for phrase in banned_starts):
+            return False
+
+        # 🧠 NEW: Reject if too many capitalized words in a row (e.g., names, places from a book)
+        cap_sequence = sum(1 for word in words if word.istitle())
+        if cap_sequence > 5 and cap_sequence / len(words) > 0.4:
+            return False
+
         return True
 
@@ -249,3 +303,46 @@ class RubyTrainer:
             base_score -= 2
 
         return max(0.0, base_score)
+
+    def clean_vocab(self, min_occurrences: int = 1):
+        print("[CLEAN] Analyzing and cleaning vocabulary...")
+
+        # Count normalized forms
+        counts = Counter()
+        norm_to_original = {}
+
+        for word in self.tokenizer.vocab:
+            if word in ("<START>", "<END>"):
+                continue
+            normalized = normalize_for_vocab(word)
+            if normalized not in norm_to_original:
+                norm_to_original[normalized] = word
+            counts[normalized] += 1
+
+        # Rebuild new vocab
+        new_vocab = {"<START>": 0, "<END>": 1}
+        reverse = dict()
+
+        idx = 2
+        for norm, original in norm_to_original.items():
+            if counts[norm] >= min_occurrences:
+                new_vocab[original] = idx
+                reverse[norm] = original
+                idx += 1
+
+        old_size = len(self.tokenizer.vocab)
+        new_size = len(new_vocab)
+        print(f"[CLEAN] Vocabulary reduced: {old_size} → {new_size}")
+
+        # Replace tokenizer vocab
+        self.tokenizer.vocab = new_vocab
+        self.tokenizer.inv_vocab = {v: k for k, v in new_vocab.items()}
+
+        # Reinitialize the model to reflect the new vocab
+        self.rebuild_model_if_needed()
+
+        # Optionally: Save cleaned vocab
+        with open("tokenizer_vocab.txt", "w", encoding="utf-8") as f:
+            for token in new_vocab:
+                f.write(f"{token}\n")
+        print("[CLEAN] Vocab written to tokenizer_vocab.txt")
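normalize_for_vocab() is the shared cleanup step now applied both when tokenizing and when rebuilding the vocabulary. A quick illustration of the rules above on a book-like sentence (the input is made up; the output follows from the listed replacements):

    from trainer import normalize_for_vocab

    print(normalize_for_vocab("The rabbit-hole went straight on, like a tunnel."))
    # -> 'the rabbithole went straight on like a tunnel'
    # Trailing punctuation is stripped per word, the hyphen is collapsed, and everything is lowercased.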