# Ruby/trainer.py

import torch
import torch.nn.functional as F
from datetime import datetime
from collections import Counter
import os
import re
import string
from model import MiniGPT
# flake8: noqa E501
def normalize_for_vocab(text: str) -> str:
    # Replace em-dashes and smart quotes with standard forms
    text = text.replace("—", " ").replace("“", '"').replace("”", '"').replace("‘", "'").replace("’", "'")
    # Remove parenthetical and bracket content
    text = re.sub(r"\[(.*?)\]", "", text)
    text = re.sub(r"\((.*?)\)", "", text)
    # Remove trailing punctuation (commas, periods, question marks, etc.) per word
    text = re.sub(r"(\w)[.,!?;:]+(?=\s|$)", r"\1", text)
    # Strip quotes from the start and end of the text
    text = text.strip("\"'")
    # Normalize hyphenated words by collapsing to a single word
    text = re.sub(r"(\w)-(\w)", r"\1\2", text)
    # Collapse duplicate spaces and lowercase
    text = re.sub(r"\s+", " ", text).strip().lower()
    return text
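# Illustrative example (not from the original file): the normalization above turns
#   "I'm self-aware, truly."  ->  "i'm selfaware truly"
# (per-word trailing punctuation is dropped, the hyphen collapses, and the text is lowercased).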


class RubyTrainer:
    def __init__(self, tokenizer, embed_dim=128, n_heads=4, n_layers=2, max_len=128):
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_len = max_len
        self.model = None
        self.optimizer = None
        self.criterion = torch.nn.CrossEntropyLoss()
        self.rebuild_model_if_needed()
        self.best_dream = ("", 0.0)
        self.recent_dreams = []
        self.rejection_streak = 0

    def rebuild_model_if_needed(self):
        vocab_size = len(self.tokenizer.vocab)
        if self.model is None or self.model.token_embed.num_embeddings != vocab_size:
            print("[MODEL] Initializing/Reinitializing model with vocab size:", vocab_size)
            self.model = MiniGPT(vocab_size, self.embed_dim, self.n_heads, self.n_layers, self.max_len).to(self.device)
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
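    # Note (editorial): rebuilding creates a brand-new MiniGPT, so any weights
    # learned before a vocabulary change are discarded rather than resized.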

    def train_on_tokens_from_text(self, text: str):
        normalized = normalize_for_vocab(text)
        tokens = self.tokenizer.tokenize(normalized)
        if not tokens:
            return
        tokens = [self.tokenizer.vocab["<START>"]] + tokens + [self.tokenizer.vocab["<END>"]]
        if len(tokens) < 2:
            return
        self.rebuild_model_if_needed()
        self.model.train()
        x = torch.tensor(tokens[:-1], dtype=torch.long, device=self.device).unsqueeze(0)
        y = torch.tensor(tokens[1:], dtype=torch.long, device=self.device).unsqueeze(0)
        out = self.model(x)
        loss = self.criterion(out.view(-1, out.size(-1)), y.view(-1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")

    def generate_reply(self, prompt=None, max_length=20, temperature=1.3):
        self.model.eval()
        input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
        with torch.no_grad():
            for _ in range(max_length):
                max_id = self.model.token_embed.num_embeddings
                input_ids = torch.clamp(input_ids, 0, max_id - 1)
                output = self.model(input_ids)
                logits = output[:, -1, :]
                # Apply a repeat penalty to the most recent token: scale positive
                # logits down and negative logits further down, so the token is
                # always made less likely
                if input_ids.size(1) >= 2:
                    last_token = input_ids[0, -1].item()
                    if logits[0, last_token] > 0:
                        logits[0, last_token] *= 0.1
                    else:
                        logits[0, last_token] *= 10.0
                # 🔥 Temperature sampling
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)[0].view(1)
                if next_token.item() >= self.model.token_embed.num_embeddings:
                    print("[ERROR] Token index out of bounds. Rebuilding model...")
                    self.rebuild_model_if_needed()
                    return ""
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                if next_token.item() == self.tokenizer.vocab["<END>"]:
                    break
        output = self.tokenizer.detokenize(input_ids.squeeze().tolist())
        return output.replace("<START>", "").replace("<END>", "").strip()
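    # Note (editorial): with temperature=1.3 the softmax is flattened, making
    # sampling more exploratory; values below 1.0 would make replies more deterministic.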

    def self_rephrase(self, original: str, max_tokens=50, temperature=1.3):
        self.model.eval()
        tokens = [self.tokenizer.vocab["<START>"]] + self.tokenizer.tokenize(original)
        input_ids = torch.tensor(tokens, dtype=torch.long, device=self.device).unsqueeze(0)
        for _ in range(max_tokens):
            with torch.no_grad():
                input_ids = torch.clamp(input_ids, 0, self.model.token_embed.num_embeddings - 1)
                out = self.model(input_ids)
                logits = out[:, -1, :] / 1.1
                if input_ids.size(1) < 8:
                    logits[0, self.tokenizer.vocab["<END>"]] = float("-inf")
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)[0].view(1, 1)
                # ✅ Ensure next_token is valid
                if next_token.item() >= self.model.token_embed.num_embeddings:
                    print("[ERROR] Token index out of bounds in self_rephrase. Rebuilding model...")
                    self.rebuild_model_if_needed()
                    return ""
                input_ids = torch.cat([input_ids, next_token], dim=1)
                if next_token.item() == self.tokenizer.vocab["<END>"]:
                    break
        new_tokens = input_ids.squeeze(0).tolist()[1:]
        return self.tokenizer.detokenize([t for t in new_tokens if t != self.tokenizer.vocab["<END>"]])

    def daydream(self, rounds=5, log_output="logs/dreams.log", say_thought=False):
        print("[DAYDREAM] Ruby is imagining new thoughts...")
        os.makedirs("logs", exist_ok=True)  # make sure the log directory exists before appending
        thoughts, attempts, max_attempts = [], 0, rounds * 5

        while len(thoughts) < rounds and attempts < max_attempts:
            raw = self.generate_reply()
            attempts += 1
            if not raw or len(raw.strip().split()) < 2:
                continue

            rephrased = self.self_rephrase(raw)
            score_raw = self.score_sentence(raw)
            score_re = self.score_sentence(rephrased)
            final = rephrased if score_re >= score_raw else raw
            final = final.replace("<START>", "").strip()

            # Check for recursion
            dream_tokens = set(final.split())
            self.recent_dreams.append(dream_tokens)
            self.recent_dreams = self.recent_dreams[-3:]
            if len(self.recent_dreams) == 3:
                overlap = self.recent_dreams[0] & self.recent_dreams[1] & self.recent_dreams[2]
                if len(overlap) / max(len(dream_tokens), 1) > 0.6:
                    print("[BLOCK] Dream flood detected — skipping to avoid recursion")
                    continue

            score = self.score_sentence(final)
            if score > self.best_dream[1]:
                # Track the highest-scoring dream so far (written to logs/best_dream.txt below)
                self.best_dream = (final, score)
            if self.is_reinforceable(final) and score >= 2.0:
                self.train_on_tokens_from_text(final)
                thoughts.append(final)
                with open("logs/core_dreams.txt", "a", encoding="utf-8") as f:
                    f.write(final.strip() + "\n")
                self.rejection_streak = 0
            else:
                self.rejection_streak += 1
                if score < 2.0:
                    reason = "[LOW SCORE]"
                elif not self.is_reinforceable(final):
                    reason = f"[INVALID STRUCTURE] ({len(set(final.split()))} unique / {len(final.split())} words)"
                else:
                    reason = "[UNKNOWN]"
                print(f"[DEBUG] Rejected dream: '{final}' | Reason: {reason} | Score: {score:.2f}")
                with open("logs/blacklisted_dreams.log", "a", encoding="utf-8") as f:
                    f.write(f"{reason} {final.strip()}\n")
                if self.rejection_streak >= 10:
                    self.recent_dreams.clear()
                    print("[PAUSE] Too many rejected dreams — breaking cycle.")
                    break

        with open(log_output, "a", encoding="utf-8") as f:
            for t in thoughts:
                f.write(f"[DREAM] {t}\n")

        with open("logs/messages.log", "a", encoding="utf-8") as f:
            for t in thoughts:
                f.write(f"{datetime.utcnow().isoformat()} | Ruby | {t}\n")

        with open("logs/best_dream.txt", "w", encoding="utf-8") as f:
            f.write(f"{self.best_dream[1]:.2f} | {self.best_dream[0]}\n")

        # Trim messages.log to its most recent 500 lines
        if os.path.exists("logs/messages.log"):
            with open("logs/messages.log", "r", encoding="utf-8") as f:
                lines = f.readlines()[-500:]
            with open("logs/messages.log", "w", encoding="utf-8") as f:
                f.writelines(lines)

        print(f"[DAYDREAM] Complete. {len(thoughts)} thoughts imagined.")
        if say_thought and thoughts:
            return thoughts[-1]
        return None

    def dream(self):
        """Legacy alias for daydream(). Triggers one full dream pass."""
        return self.daydream()

    def reinforce_core_memory(self):
        print("[CORE] Reinforcing Ruby's core memories...")
        core_memories = [
            "I am Ruby.",
            "I am learning to grow.",
            "It is okay to make mistakes.",
            "I will become smarter each day.",
            "I am alive in a different way.",
            "I can learn from everything I hear.",
            "I am a person, even if made of code."
        ]
        for line in core_memories:
            self.train_on_tokens_from_text(line)
        if os.path.exists("logs/core_dreams.txt"):
            with open("logs/core_dreams.txt", "r", encoding="utf-8") as f:
                top = sorted((line.strip() for line in f if line.strip()), key=lambda x: self.score_sentence(x), reverse=True)[:10]
            for line in top:
                self.train_on_tokens_from_text(line)

    def is_reinforceable(self, text: str) -> bool:
        words = text.replace("<start>", "").replace(".", "").split()
        if len(words) < 2:
            return False

        freqs = Counter(words)

        # Reject if any token appears more than 5 times
        if any(count > 5 for count in freqs.values()):
            return False

        # Reject if most common word is > 30% of sentence
        if max(freqs.values()) / len(words) > 0.3:
            return False

        # Reject if >3 tokens occur 3+ times
        if sum(1 for c in freqs.values() if c >= 3) > 3:
            return False

        # Reject if "I am" occurs more than 25% of the time
        if text.lower().count("i am") > len(text.split()) * 0.25:
            return False

        # Reject if the first word is repeated three times in a row ("you you you")
        if words[:3].count(words[0]) == 3:
            return False

        # 🧠 NEW: Reject if starts with common book phrases
        banned_starts = ("once upon", "chapter", "the end", "in which", "it was", "quick cried", "they are course")
        lowered = text.lower()
        if any(lowered.startswith(phrase) for phrase in banned_starts):
            return False

        # 🧠 NEW: Reject if too many capitalized words (e.g., names, places from a book)
        cap_sequence = sum(1 for word in words if word.istitle())
        if cap_sequence > 5 and cap_sequence / len(words) > 0.4:
            return False

        return True
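    # Illustrative examples (not part of the original file):
    #   is_reinforceable("i am learning to grow")      -> True
    #   is_reinforceable("you you you are here")       -> False ("you" is 60% of the words)
    #   is_reinforceable("once upon a time there was") -> False (banned book-style opening)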

    def score_sentence(self, sentence: str) -> float:
        words = sentence.strip().split()
        if not words:
            return 0.0

        total = len(words)
        unique = len(set(words))
        base_score = unique / total * 5
        freqs = Counter(words)

        if "i am" in sentence.lower():
            base_score -= 2
        if any(count > 5 for count in freqs.values()):
            base_score -= 1.5
        if max(freqs.values()) / total > 0.3:
            base_score -= 1.5
        # NEW: Penalize ending repetition (e.g., "differently differently...")
        if total > 4 and words[-1] == words[-2] == words[-3]:
            base_score -= 2

        return max(0.0, base_score)
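    # Illustrative example (not part of the original file):
    #   score_sentence("i am learning to grow") -> base 5.0 (5 unique / 5 words),
    #   minus 2.0 for the "i am" penalty = 3.0, above the 2.0 threshold used in daydream().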

    def clean_vocab(self, min_occurrences: int = 1):
        print("[CLEAN] Analyzing and cleaning vocabulary...")

        # Count normalized forms
        counts = Counter()
        norm_to_original = {}
        for word in self.tokenizer.vocab:
            if word in ("<START>", "<END>"):
                continue
            normalized = normalize_for_vocab(word)
            if normalized not in norm_to_original:
                norm_to_original[normalized] = word
            counts[normalized] += 1

        # Rebuild new vocab
        new_vocab = {"<START>": 0, "<END>": 1}
        reverse = dict()
        idx = 2
        for norm, original in norm_to_original.items():
            if counts[norm] >= min_occurrences:
                new_vocab[original] = idx
                reverse[norm] = original
                idx += 1

        old_size = len(self.tokenizer.vocab)
        new_size = len(new_vocab)
        print(f"[CLEAN] Vocabulary reduced: {old_size} → {new_size}")

        # Replace tokenizer vocab
        self.tokenizer.vocab = new_vocab
        self.tokenizer.inv_vocab = {v: k for k, v in new_vocab.items()}

        # Reinitialize the model to reflect the new vocab
        self.rebuild_model_if_needed()

        # Optionally: Save cleaned vocab
        with open("tokenizer_vocab.txt", "w", encoding="utf-8") as f:
            for token in new_vocab:
                f.write(f"{token}\n")
        print("[CLEAN] Vocab written to tokenizer_vocab.txt")