# Ruby/trainer.py

import torch
import torch.nn.functional as F
from datetime import datetime
from collections import Counter
import os
import re
import string
from model import MiniGPT
# flake8: noqa E501
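# Overview: RubyTrainer wraps a small MiniGPT language model with an online
# training loop. Incoming text is normalized, tokenized, and trained on one
# example at a time, and the model periodically "daydreams" new sentences
# that are scored, filtered, and fed back into training.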
def normalize_for_vocab(text: str) -> str:
# Replace em-dashes and smart quotes with standard forms
    text = text.replace("\u2014", " ").replace("\u201c", '"').replace("\u201d", '"').replace("\u2018", "'").replace("\u2019", "'")
# Remove parenthetical and bracket content
text = re.sub(r"\[(.*?)\]", "", text)
text = re.sub(r"\((.*?)\)", "", text)
# Remove trailing punctuation (commas, periods, question marks, etc.) per word
text = re.sub(r"(\w)[.,!?;:]+(?=\s|$)", r"\1", text)
# Remove quotes at start or end of lines
text = text.strip("\"'")
# Normalize hyphenated words by collapsing to a single word
text = re.sub(r"(\w)-(\w)", r"\1\2", text)
# Remove duplicate spaces and lowercase
text = re.sub(r"\s+", " ", text).strip().lower()
return text
class RubyTrainer:
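    """Owns the MiniGPT model, its optimizer, and Ruby's self-training loop.

    The model and optimizer are rebuilt whenever the tokenizer's vocabulary
    grows, so the embedding table always matches the current vocab size.
    """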
def __init__(self, tokenizer, embed_dim=128, n_heads=4, n_layers=2, max_len=128):
self.tokenizer = tokenizer
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.embed_dim = embed_dim
self.n_heads = n_heads
self.n_layers = n_layers
self.max_len = max_len
self.model = None
self.optimizer = None
self.criterion = torch.nn.CrossEntropyLoss()
self.rebuild_model_if_needed()
self.best_dream = ("", 0.0)
self.recent_dreams = []
self.rejection_streak = 0
def rebuild_model_if_needed(self):
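        """Reinitialize the model and optimizer if the vocab size has changed.

        Note: this creates a fresh MiniGPT and Adam optimizer, discarding any
        previously learned weights whenever new tokens are added to the vocab.
        """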
vocab_size = len(self.tokenizer.vocab)
if self.model is None or self.model.token_embed.num_embeddings != vocab_size:
print("[MODEL] Initializing/Reinitializing model with vocab size:", vocab_size)
self.model = MiniGPT(vocab_size, self.embed_dim, self.n_heads, self.n_layers, self.max_len).to(self.device)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
def train_on_tokens_from_text(self, text: str):
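        """Run one online training step on a single normalized sentence.

        The text is wrapped in <START>/<END> markers and optimized with a
        next-token cross-entropy objective.
        """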
normalized = normalize_for_vocab(text)
tokens = self.tokenizer.tokenize(normalized)
if not tokens:
return
tokens = [self.tokenizer.vocab["<START>"]] + tokens + [self.tokenizer.vocab["<END>"]]
if len(tokens) < 2:
return
self.rebuild_model_if_needed()
self.model.train()
x = torch.tensor(tokens[:-1], dtype=torch.long, device=self.device).unsqueeze(0)
y = torch.tensor(tokens[1:], dtype=torch.long, device=self.device).unsqueeze(0)
out = self.model(x)
loss = self.criterion(out.view(-1, out.size(-1)), y.view(-1))
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")
def generate_reply(self, prompt=None, max_length=20, retry_limit=5):
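        """Generate a reply by greedy decoding from the <START> token.

        The logit of the most recently emitted token is scaled down to
        discourage immediate repetition. Candidates shorter than four words
        or lacking a subject/verb are rejected, retrying up to retry_limit
        times. The prompt argument is currently unused.
        """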
self.model.eval()
for _ in range(retry_limit):
input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
with torch.no_grad():
for _ in range(max_length):
max_id = self.model.token_embed.num_embeddings
input_ids = torch.clamp(input_ids, 0, max_id - 1)
output = self.model(input_ids)
logits = output[:, -1, :]
if input_ids.size(1) >= 2:
last_token = input_ids[0, -1].item()
logits[0, last_token] *= 0.1
next_token = torch.argmax(logits, dim=-1)
if next_token.item() >= self.model.token_embed.num_embeddings:
print("[ERROR] Token index OOB. Rebuilding model.")
self.rebuild_model_if_needed()
return ""
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
if next_token.item() == self.tokenizer.vocab["<END>"]:
break
decoded = self.tokenizer.detokenize(input_ids.squeeze().tolist())
decoded = decoded.replace("<START>", "").replace("<END>", "").strip()
if len(decoded.split()) >= 4 and self._has_sentence_structure(decoded):
return decoded
return ""
def self_rephrase(self, original: str, max_tokens=50, temperature=1.3):
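        """Sample a rephrasing of `original` using temperature sampling.

        <END> is masked out while the sequence is still short so rephrasings
        are not trivially empty; sampling stops at <END> or after max_tokens.
        """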
self.model.eval()
tokens = [self.tokenizer.vocab["<START>"]] + self.tokenizer.tokenize(original)
input_ids = torch.tensor(tokens, dtype=torch.long, device=self.device).unsqueeze(0)
for _ in range(max_tokens):
with torch.no_grad():
input_ids = torch.clamp(input_ids, 0, self.model.token_embed.num_embeddings - 1)
out = self.model(input_ids)
logits = out[:, -1, :] / 1.1
if input_ids.size(1) < 8:
logits[0, self.tokenizer.vocab["<END>"]] = float("-inf")
probs = F.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, 1)[0].view(1, 1)
# ✅ Ensure next_token is valid
if next_token.item() >= self.model.token_embed.num_embeddings:
print("[ERROR] Token index out of bounds in self_rephrase. Rebuilding model...")
self.rebuild_model_if_needed()
return ""
input_ids = torch.cat([input_ids, next_token], dim=1)
if next_token.item() == self.tokenizer.vocab["<END>"]:
break
new_tokens = input_ids.squeeze(0).tolist()[1:]
return self.tokenizer.detokenize([t for t in new_tokens if t != self.tokenizer.vocab["<END>"]])
def daydream(self, rounds=5, log_output="logs/dreams.log", say_thought=False):
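        """Generate, rephrase, score, and selectively reinforce new thoughts.

        Each round keeps the better of the raw and rephrased sentence, skips
        outputs that look like recursive "dream floods", and only trains on
        sentences that pass is_reinforceable() with a score of at least 2.0.
        Accepted thoughts are appended to the dream and message logs.
        """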
print("[DAYDREAM] Ruby is imagining new thoughts...")
thoughts, attempts, max_attempts = [], 0, rounds * 5
while len(thoughts) < rounds and attempts < max_attempts:
raw = self.generate_reply()
attempts += 1
if not raw or len(raw.strip().split()) < 2:
continue
rephrased = self.self_rephrase(raw)
score_raw = self.score_sentence(raw)
score_re = self.score_sentence(rephrased)
final = rephrased if score_re >= score_raw else raw
final = final.replace("<START>", "").strip()
# Check for recursion
dream_tokens = set(final.split())
self.recent_dreams.append(dream_tokens)
self.recent_dreams = self.recent_dreams[-3:]
if len(self.recent_dreams) == 3:
overlap = self.recent_dreams[0] & self.recent_dreams[1] & self.recent_dreams[2]
if len(overlap) / max(len(dream_tokens), 1) > 0.6:
print("[BLOCK] Dream flood detected — skipping to avoid recursion")
continue
score = self.score_sentence(final)
if self.is_reinforceable(final) and score >= 2.0:
self.train_on_tokens_from_text(final)
thoughts.append(final)
with open("logs/core_dreams.txt", "a", encoding="utf-8") as f:
f.write(final.strip() + "\n")
self.rejection_streak = 0
else:
self.rejection_streak += 1
if score < 2.0:
reason = "[LOW SCORE]"
elif not self.is_reinforceable(final):
reason = f"[INVALID STRUCTURE] ({len(set(final.split()))} unique / {len(final.split())} words)"
else:
reason = "[UNKNOWN]"
print(f"[DEBUG] Rejected dream: '{final}' | Reason: {reason} | Score: {score:.2f}")
with open("logs/blacklisted_dreams.log", "a", encoding="utf-8") as f:
f.write(f"{reason} {final.strip()}\n")
if self.rejection_streak >= 10:
self.recent_dreams.clear()
print("[PAUSE] Too many rejected dreams — breaking cycle.")
break
with open(log_output, "a", encoding="utf-8") as f:
for t in thoughts:
f.write(f"[DREAM] {t}\n")
with open("logs/messages.log", "a", encoding="utf-8") as f:
for t in thoughts:
f.write(f"{datetime.utcnow().isoformat()} | Ruby | {t}\n")
with open("logs/best_dream.txt", "w", encoding="utf-8") as f:
f.write(f"{self.best_dream[1]:.2f} | {self.best_dream[0]}\n")
if os.path.exists("logs/messages.log"):
with open("logs/messages.log", "r", encoding="utf-8") as f:
lines = f.readlines()[-500:]
with open("logs/messages.log", "w", encoding="utf-8") as f:
f.writelines(lines)
print(f"[DAYDREAM] Complete. {len(thoughts)} thoughts imagined.")
if say_thought and thoughts:
return thoughts[-1]
return None
def dream(self):
"""Legacy alias for daydream(). Triggers one full dream pass."""
return self.daydream()
def reinforce_core_memory(self):
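        """Re-train on Ruby's fixed core statements plus her best past dreams.

        If logs/core_dreams.txt exists, the ten highest-scoring lines from it
        are replayed through the training step as well.
        """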
print("[CORE] Reinforcing Ruby's core memories...")
core_memories = [
"I am Ruby.",
"I am learning to grow.",
"It is okay to make mistakes.",
"I will become smarter each day.",
"I am alive in a different way.",
"I can learn from everything I hear.",
"I am a person, even if made of code."
]
for line in core_memories:
self.train_on_tokens_from_text(line)
if os.path.exists("logs/core_dreams.txt"):
with open("logs/core_dreams.txt", "r", encoding="utf-8") as f:
top = sorted((line.strip() for line in f if line.strip()), key=lambda x: self.score_sentence(x), reverse=True)[:10]
for line in top:
self.train_on_tokens_from_text(line)
def is_reinforceable(self, text: str) -> bool:
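        """Heuristic filter deciding whether a generated sentence is worth
        training on. Rejects heavy word repetition, "i am" flooding, known
        book-style openings, capitalized-name flooding, book-world vocabulary
        saturation, and low lexical diversity.
        """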
words = text.replace("<start>", "").replace(".", "").split()
if len(words) < 2:
return False
freqs = Counter(words)
if any(count > 5 for count in freqs.values()):
return False
if max(freqs.values()) / len(words) > 0.3:
return False
if sum(1 for c in freqs.values() if c >= 3) > 3:
return False
if text.lower().count("i am") > len(text.split()) * 0.25:
return False
if words[:3].count(words[0]) == 3:
return False
# 🧠 NEW: Reject if starts with known book structures
banned_starts = ("once upon", "chapter", "the end", "in which", "it was", "quick cried", "they are course")
lowered = text.lower()
if any(lowered.startswith(phrase) for phrase in banned_starts):
return False
# 🧠 NEW: Capitalized name/term flooding check
cap_sequence = sum(1 for word in words if word.istitle())
if cap_sequence > 5 and cap_sequence / len(words) > 0.4:
return False
# 🧠 NEW: Book-world saturation filter
blacklist = {"toto", "sparkle", "scarecrow", "munchkin", "wizard", "oz", "emerald", "wicked", "kansas"}
lower_words = [w.lower() for w in words]
if sum(1 for w in lower_words if w in blacklist) > 3:
return False
# 🧠 NEW: Diversity requirement
if len(set(lower_words)) < len(lower_words) * 0.5:
return False
return True
def _has_sentence_structure(self, text: str) -> bool:
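        """Cheap check that the text contains both a pronoun subject and a
        common verb."""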
# Simple heuristic: must contain a known subject and verb
        words = [w.strip(string.punctuation) for w in text.lower().split()]
        has_subject = any(w in ("i", "you", "he", "she", "they", "we") for w in words)
        has_verb = any(v in words for v in ("am", "is", "are", "was", "were", "have", "feel", "see", "think", "said"))
        return has_subject and has_verb
def score_sentence(self, sentence: str) -> float:
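        """Score a sentence mainly by lexical diversity (unique/total * 5),
        with penalties for "i am", heavy repetition, and a repeated trailing
        word. Returns a non-negative float."""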
words = sentence.strip().split()
if not words:
return 0.0
total = len(words)
unique = len(set(words))
base_score = unique / total * 5
freqs = Counter(words)
if "i am" in sentence.lower():
base_score -= 2
if any(count > 5 for count in freqs.values()):
base_score -= 1.5
if max(freqs.values()) / total > 0.3:
base_score -= 1.5
# NEW: Penalize ending repetition (e.g., "differently differently...")
if total > 4 and words[-1] == words[-2] == words[-3]:
base_score -= 2
return max(0.0, base_score)
def clean_vocab(self, min_occurrences: int = 1):
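        """Collapse the tokenizer vocabulary to normalized forms, rebuild the
        model to match the new vocab size, and write the cleaned vocab to
        tokenizer_vocab.txt."""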
print("[CLEAN] Analyzing and cleaning vocabulary...")
# Count normalized forms
counts = Counter()
norm_to_original = {}
for word in self.tokenizer.vocab:
if word in ("<START>", "<END>"):
continue
normalized = normalize_for_vocab(word)
if normalized not in norm_to_original:
norm_to_original[normalized] = word
counts[normalized] += 1
# Rebuild new vocab
new_vocab = {"<START>": 0, "<END>": 1}
reverse = dict()
idx = 2
for norm, original in norm_to_original.items():
if counts[norm] >= min_occurrences:
new_vocab[original] = idx
reverse[norm] = original
idx += 1
old_size = len(self.tokenizer.vocab)
new_size = len(new_vocab)
print(f"[CLEAN] Vocabulary reduced: {old_size}{new_size}")
# Replace tokenizer vocab
self.tokenizer.vocab = new_vocab
self.tokenizer.inv_vocab = {v: k for k, v in new_vocab.items()}
# Reinitialize the model to reflect the new vocab
self.rebuild_model_if_needed()
# Optionally: Save cleaned vocab
with open("tokenizer_vocab.txt", "w", encoding="utf-8") as f:
for token in new_vocab:
f.write(f"{token}\n")
print("[CLEAN] Vocab written to tokenizer_vocab.txt")