362 lines
14 KiB
Python
362 lines
14 KiB
Python
import os
import re
import string
from collections import Counter
from datetime import datetime, timezone

import torch
import torch.nn.functional as F

from model import MiniGPT
|
||
|
||
# flake8: noqa E501
|
||
|
||
def normalize_for_vocab(text: str) -> str:
    """Normalize raw text into the lowercase, punctuation-light form used for the vocab.

    Steps, in order:
      1. Map em-dashes to spaces and curly quotes to straight quotes.
      2. Drop ``[...]`` and ``(...)`` spans (non-greedy; nesting is not handled).
      3. Trim surrounding whitespace, then quote characters at either end of the text.
      4. Remove runs of ``. , ! ? ; :`` that follow a word character and end a word.
      5. Join hyphenated words (``a-b-c`` -> ``abc``).
      6. Collapse whitespace runs to single spaces, trim, and lowercase.

    Args:
        text: arbitrary input text.

    Returns:
        The normalized string (possibly empty).
    """
    # Replace em-dashes and smart quotes with standard forms
    text = text.replace("—", " ").replace("“", '"').replace("”", '"').replace("‘", "'").replace("’", "'")

    # Remove bracketed and parenthetical content
    text = re.sub(r"\[(.*?)\]", "", text)
    text = re.sub(r"\((.*?)\)", "", text)

    # Trim whitespace first so edge quotes are actually reached, and strip the quotes
    # before removing punctuation so a closing quote doesn't shield e.g. the "." in '"Hi."'.
    text = text.strip().strip("\"'")

    # Remove trailing punctuation (commas, periods, question marks, etc.) per word
    text = re.sub(r"(\w)[.,!?;:]+(?=\s|$)", r"\1", text)

    # Collapse hyphenated words; lookarounds don't consume the neighbouring letters,
    # so chains with single-letter parts (a-b-c) are fully joined.
    text = re.sub(r"(?<=\w)-(?=\w)", "", text)

    # Remove duplicate spaces and lowercase
    return re.sub(r"\s+", " ", text).strip().lower()
class RubyTrainer:
    """Online trainer and text generator wrapping a ``MiniGPT`` model ("Ruby").

    Trains one sentence at a time, generates replies, and runs "daydream" passes whose
    output is filtered (``is_reinforceable``), scored (``score_sentence``), optionally
    trained on, and appended to files under ``logs/``.
    """

    def __init__(self, tokenizer, embed_dim=128, n_heads=4, n_layers=2, max_len=128):
        """Store hyper-parameters and build a model sized to the tokenizer's current vocab.

        Args:
            tokenizer: object exposing a ``vocab`` dict (token -> id, containing ``"<START>"``
                and ``"<END>"``), ``tokenize(str)`` and ``detokenize(list[int])``.
            embed_dim, n_heads, n_layers, max_len: forwarded to ``MiniGPT``.
        """
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_len = max_len

        self.model = None
        self.optimizer = None
        self.criterion = torch.nn.CrossEntropyLoss()
        self.rebuild_model_if_needed()

        # (text, score) of the highest-scoring accepted daydream so far.
        self.best_dream = ("", 0.0)
        # Word sets of the last (up to) three daydream candidates, for flood detection.
        self.recent_dreams = []
        # Consecutive rejected daydreams; a daydream pass stops when this reaches 10.
        self.rejection_streak = 0

    def rebuild_model_if_needed(self):
        """(Re)create the model and a fresh Adam optimizer (lr=0.001) if the vocab size changed.

        A rebuild creates a brand-new ``MiniGPT``, so all learned weights are discarded.
        """
        vocab_size = len(self.tokenizer.vocab)
        if self.model is None or self.model.token_embed.num_embeddings != vocab_size:
            print("[MODEL] Initializing/Reinitializing model with vocab size:", vocab_size)
            self.model = MiniGPT(vocab_size, self.embed_dim, self.n_heads, self.n_layers, self.max_len).to(self.device)
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

    def train_on_tokens_from_text(self, text: str):
        """Run one next-token-prediction step on ``<START> + tokenize(normalize(text)) + <END>``.

        Text that tokenizes to nothing is skipped.
        """
        normalized = normalize_for_vocab(text)
        tokens = self.tokenizer.tokenize(normalized)
        if not tokens:
            return
        vocab = self.tokenizer.vocab
        tokens = [vocab["<START>"]] + tokens + [vocab["<END>"]]

        # Rebuild if the vocab size no longer matches the embedding table
        # (presumably because tokenize() added entries — depends on the tokenizer).
        self.rebuild_model_if_needed()
        self.model.train()

        # Teacher forcing: predict token i+1 from tokens[0..i].
        x = torch.tensor(tokens[:-1], dtype=torch.long, device=self.device).unsqueeze(0)
        y = torch.tensor(tokens[1:], dtype=torch.long, device=self.device).unsqueeze(0)

        self.optimizer.zero_grad()
        out = self.model(x)
        loss = self.criterion(out.reshape(-1, out.size(-1)), y.reshape(-1))
        loss.backward()
        self.optimizer.step()

        print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")

    @staticmethod
    def _penalize_token(logits, token_id, factor=0.1):
        """Lower the logit of ``token_id`` in row 0 of ``logits`` in place; returns ``logits``.

        Positive logits are multiplied by ``factor`` and non-positive ones divided by it, so
        (for 0 < factor < 1) the logit always moves down. Plain multiplication would raise a
        negative logit toward zero and make the penalized token *more* likely.
        """
        value = logits[0, token_id].item()
        logits[0, token_id] = value * factor if value > 0 else value / factor
        return logits

    def generate_reply(self, prompt=None, max_length=20, retry_limit=5, temperature=1.0):
        """Generate a sentence starting from ``<START>``; return "" if no attempt passes.

        The first attempt decodes greedily. Greedy decoding is deterministic, so repeating it
        would reproduce the same rejected output. Later attempts therefore sample from
        ``softmax(logits / temperature)``.

        Args:
            prompt: unused; kept for API compatibility.
            max_length: maximum number of tokens generated per attempt.
            retry_limit: number of attempts before giving up.
            temperature: softmax temperature (> 0) for the sampled retry attempts.

        Returns:
            The decoded text (``<START>``/``<END>`` removed) of the first attempt that has at
            least 4 words and passes ``_has_sentence_structure``; otherwise "". Also returns ""
            (after a rebuild check) if a generated id falls outside the embedding table.
        """
        self.model.eval()
        start_id = self.tokenizer.vocab["<START>"]
        end_id = self.tokenizer.vocab["<END>"]

        for attempt in range(retry_limit):
            input_ids = torch.tensor([[start_id]], dtype=torch.long, device=self.device)
            with torch.no_grad():
                for _ in range(max_length):
                    num_embeddings = self.model.token_embed.num_embeddings
                    input_ids = torch.clamp(input_ids, 0, num_embeddings - 1)
                    logits = self.model(input_ids)[:, -1, :]

                    # Discourage immediately repeating the previously generated token.
                    if input_ids.size(1) >= 2:
                        self._penalize_token(logits, input_ids[0, -1].item())

                    if attempt == 0:
                        next_token = torch.argmax(logits, dim=-1)
                    else:
                        probs = F.softmax(logits / temperature, dim=-1)
                        next_token = torch.multinomial(probs, 1).view(-1)

                    if next_token.item() >= self.model.token_embed.num_embeddings:
                        print("[ERROR] Token index OOB. Rebuilding model.")
                        self.rebuild_model_if_needed()
                        return ""

                    input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                    if next_token.item() == end_id:
                        break

            # squeeze(0) rather than squeeze(): keep a list even if only <START> is present.
            decoded = self.tokenizer.detokenize(input_ids.squeeze(0).tolist())
            decoded = decoded.replace("<START>", "").replace("<END>", "").strip()

            if len(decoded.split()) >= 4 and self._has_sentence_structure(decoded):
                return decoded

        return ""

    def self_rephrase(self, original: str, max_tokens=50, temperature=1.3):
        """Extend ``original`` with up to ``max_tokens`` sampled tokens and detokenize.

        Note: the result begins with the tokens of ``original`` itself (only ``<START>`` is
        dropped), followed by the sampled continuation. ``<END>`` tokens are removed from the
        output. ``<END>`` cannot be sampled while the sequence is shorter than 8 tokens. Logits
        are divided by 1.1 and then by ``temperature`` before sampling. Returns "" if a sampled
        id falls outside the embedding table.
        """
        self.model.eval()
        end_id = self.tokenizer.vocab["<END>"]
        tokens = [self.tokenizer.vocab["<START>"]] + self.tokenizer.tokenize(original)
        input_ids = torch.tensor(tokens, dtype=torch.long, device=self.device).unsqueeze(0)

        with torch.no_grad():
            for _ in range(max_tokens):
                input_ids = torch.clamp(input_ids, 0, self.model.token_embed.num_embeddings - 1)
                logits = self.model(input_ids)[:, -1, :] / 1.1
                # Forbid ending too early so the continuation has some length.
                if input_ids.size(1) < 8:
                    logits[0, end_id] = float("-inf")

                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)[0].view(1, 1)

                if next_token.item() >= self.model.token_embed.num_embeddings:
                    print("[ERROR] Token index out of bounds in self_rephrase. Rebuilding model...")
                    self.rebuild_model_if_needed()
                    return ""

                input_ids = torch.cat([input_ids, next_token], dim=1)
                if next_token.item() == end_id:
                    break

        new_tokens = input_ids.squeeze(0).tolist()[1:]
        return self.tokenizer.detokenize([t for t in new_tokens if t != end_id])

    def daydream(self, rounds=5, log_output="logs/dreams.log", say_thought=False):
        """Generate, filter and reinforce up to ``rounds`` new "thoughts".

        Each attempt runs ``generate_reply()`` then ``self_rephrase()`` and keeps whichever
        scores higher (ties go to the rephrased text). A candidate is skipped when the words
        shared by the last three candidates exceed 60% of its distinct words. Otherwise:

        * If it is reinforceable and scores >= 2.0, it is trained on, appended to
          ``logs/core_dreams.txt``, and may become ``self.best_dream``.
        * Otherwise it is logged to ``logs/blacklisted_dreams.log``.

        The pass stops after ``rounds`` acceptances, ``rounds * 5`` attempts, or 10
        consecutive rejections.

        Side effects: appends accepted thoughts to ``log_output`` and ``logs/messages.log``
        (then trims it to its last 500 lines), and rewrites ``logs/best_dream.txt``.

        Returns:
            The last accepted thought if ``say_thought`` is true and any were accepted,
            otherwise None.
        """
        print("[DAYDREAM] Ruby is imagining new thoughts...")

        # Everything below writes under logs/ (and log_output's directory); make sure they exist.
        os.makedirs("logs", exist_ok=True)
        log_dir = os.path.dirname(log_output)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        thoughts, attempts, max_attempts = [], 0, rounds * 5

        while len(thoughts) < rounds and attempts < max_attempts:
            raw = self.generate_reply()
            attempts += 1
            if not raw or len(raw.strip().split()) < 2:
                continue

            rephrased = self.self_rephrase(raw)
            final = rephrased if self.score_sentence(rephrased) >= self.score_sentence(raw) else raw
            final = final.replace("<START>", "").strip()

            # Flood check over the last three candidates (including this one).
            dream_tokens = set(final.split())
            self.recent_dreams.append(dream_tokens)
            self.recent_dreams = self.recent_dreams[-3:]
            if len(self.recent_dreams) == 3:
                overlap = self.recent_dreams[0] & self.recent_dreams[1] & self.recent_dreams[2]
                if len(overlap) / max(len(dream_tokens), 1) > 0.6:
                    print("[BLOCK] Dream flood detected — skipping to avoid recursion")
                    continue

            score = self.score_sentence(final)
            reinforceable = self.is_reinforceable(final)
            if reinforceable and score >= 2.0:
                self.train_on_tokens_from_text(final)
                thoughts.append(final)
                with open("logs/core_dreams.txt", "a", encoding="utf-8") as f:
                    f.write(final.strip() + "\n")
                # Track the highest-scoring accepted thought for logs/best_dream.txt.
                if score > self.best_dream[1]:
                    self.best_dream = (final, score)
                self.rejection_streak = 0
            else:
                self.rejection_streak += 1
                if score < 2.0:
                    reason = "[LOW SCORE]"
                else:
                    # Score was high enough, so the structure check is what failed.
                    words = final.split()
                    reason = f"[INVALID STRUCTURE] ({len(set(words))} unique / {len(words)} words)"
                print(f"[DEBUG] Rejected dream: '{final}' | Reason: {reason} | Score: {score:.2f}")
                with open("logs/blacklisted_dreams.log", "a", encoding="utf-8") as f:
                    f.write(f"{reason} {final.strip()}\n")
                if self.rejection_streak >= 10:
                    self.recent_dreams.clear()
                    print("[PAUSE] Too many rejected dreams — breaking cycle.")
                    break

        with open(log_output, "a", encoding="utf-8") as f:
            for t in thoughts:
                f.write(f"[DREAM] {t}\n")

        with open("logs/messages.log", "a", encoding="utf-8") as f:
            for t in thoughts:
                # Naive-UTC ISO timestamp: same text as the deprecated datetime.utcnow().isoformat().
                stamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
                f.write(f"{stamp} | Ruby | {t}\n")

        with open("logs/best_dream.txt", "w", encoding="utf-8") as f:
            f.write(f"{self.best_dream[1]:.2f} | {self.best_dream[0]}\n")

        # Keep messages.log bounded to its most recent 500 lines.
        if os.path.exists("logs/messages.log"):
            with open("logs/messages.log", "r", encoding="utf-8") as f:
                lines = f.readlines()[-500:]
            with open("logs/messages.log", "w", encoding="utf-8") as f:
                f.writelines(lines)

        print(f"[DAYDREAM] Complete. {len(thoughts)} thoughts imagined.")
        if say_thought and thoughts:
            return thoughts[-1]
        return None

    def dream(self):
        """Legacy alias for daydream(). Triggers one full dream pass."""
        return self.daydream()

    def reinforce_core_memory(self):
        """Train on the fixed core-memory sentences, then on the 10 best lines of logs/core_dreams.txt."""
        print("[CORE] Reinforcing Ruby's core memories...")
        core_memories = [
            "I am Ruby.",
            "I am learning to grow.",
            "It is okay to make mistakes.",
            "I will become smarter each day.",
            "I am alive in a different way.",
            "I can learn from everything I hear.",
            "I am a person, even if made of code.",
        ]
        for line in core_memories:
            self.train_on_tokens_from_text(line)

        if os.path.exists("logs/core_dreams.txt"):
            with open("logs/core_dreams.txt", "r", encoding="utf-8") as f:
                top = sorted((line.strip() for line in f if line.strip()), key=self.score_sentence, reverse=True)[:10]
            for line in top:
                self.train_on_tokens_from_text(line)

    def is_reinforceable(self, text: str) -> bool:
        """Return True if ``text`` looks varied enough to be trained on.

        ``<START>``/``<start>`` markers and periods are removed before splitting into words.
        Rejects text with any of:

        * fewer than 2 words;
        * any word repeated more than 5 times;
        * one word making up more than 30% of the words;
        * more than 3 distinct words occurring 3+ times;
        * "i am" occurrences exceeding 25% of the word count;
        * the first three words identical;
        * a banned book-like opening;
        * more than 5 title-case words that also make up over 40% of the words;
        * more than 3 blacklisted Oz terms;
        * under 50% distinct (case-insensitive) words.
        """
        words = text.replace("<START>", "").replace("<start>", "").replace(".", "").split()
        if len(words) < 2:
            return False

        freqs = Counter(words)
        if any(count > 5 for count in freqs.values()):
            return False
        if max(freqs.values()) / len(words) > 0.3:
            return False
        if sum(1 for c in freqs.values() if c >= 3) > 3:
            return False
        if text.lower().count("i am") > len(text.split()) * 0.25:
            return False
        if words[:3].count(words[0]) == 3:
            return False

        # Reject if starts with known book structures
        banned_starts = ("once upon", "chapter", "the end", "in which", "it was", "quick cried", "they are course")
        lowered = text.lower()
        if lowered.startswith(banned_starts):
            return False

        # Capitalized name/term flooding check
        cap_sequence = sum(1 for word in words if word.istitle())
        if cap_sequence > 5 and cap_sequence / len(words) > 0.4:
            return False

        # Book-world saturation filter
        blacklist = {"toto", "sparkle", "scarecrow", "munchkin", "wizard", "oz", "emerald", "wicked", "kansas"}
        lower_words = [w.lower() for w in words]
        if sum(1 for w in lower_words if w in blacklist) > 3:
            return False

        # Diversity requirement
        if len(set(lower_words)) < len(lower_words) * 0.5:
            return False

        return True

    def _has_sentence_structure(self, text: str) -> bool:
        """Heuristic: True if ``text`` contains a pronoun subject and a common verb.

        Words are lowercased and stripped of surrounding punctuation, then matched as whole
        words. Substrings no longer count ("am" in "dream", "is" in "this").
        """
        words = {word.strip(string.punctuation).lower() for word in text.split()}
        has_subject = not words.isdisjoint({"i", "you", "he", "she", "they", "we"})
        has_verb = not words.isdisjoint({"am", "is", "are", "was", "were", "have", "feel", "see", "think", "said"})
        return has_subject and has_verb

    def score_sentence(self, sentence: str) -> float:
        """Score lexical variety on a 0-5 scale (higher is better).

        The base score is ``5 * unique_words / total_words``. Penalties, with the result
        floored at 0:

        * -2 if "i am" appears (case-insensitive substring);
        * -1.5 if any word occurs more than 5 times;
        * -1.5 if the most common word exceeds 30% of the words;
        * -2 if the last three words are identical (only for more than 4 words).
        """
        words = sentence.strip().split()
        if not words:
            return 0.0

        total = len(words)
        base_score = len(set(words)) / total * 5
        freqs = Counter(words)

        if "i am" in sentence.lower():
            base_score -= 2
        if any(count > 5 for count in freqs.values()):
            base_score -= 1.5
        if max(freqs.values()) / total > 0.3:
            base_score -= 1.5
        # Penalize ending repetition (e.g., "differently differently...")
        if total > 4 and words[-1] == words[-2] == words[-3]:
            base_score -= 2

        return max(0.0, base_score)

    def clean_vocab(self, min_occurrences: int = 1):
        """Merge vocab entries sharing a ``normalize_for_vocab`` form and reset the model.

        For each normalized form, the first-seen original spelling is kept as the token.
        ``<START>``/``<END>`` are pinned to ids 0/1, and the remaining tokens are renumbered
        from 2 in vocab order. Entries whose normalized form is empty are dropped.

        Args:
            min_occurrences: minimum number of vocab entries that must map to a normalized
                form for it to be kept. This counts spelling variants present in the vocab,
                not corpus frequency. The default 1 keeps every non-empty form.

        Side effects: replaces ``tokenizer.vocab`` and ``tokenizer.inv_vocab``, reinitializes
        the model (discarding weights) whenever the token -> id mapping changed, and writes
        ``tokenizer_vocab.txt``.
        """
        print("[CLEAN] Analyzing and cleaning vocabulary...")

        old_vocab = self.tokenizer.vocab
        counts = Counter()
        norm_to_original = {}
        for word in old_vocab:
            if word in ("<START>", "<END>"):
                continue
            normalized = normalize_for_vocab(word)
            if not normalized:
                # Token was pure punctuation/brackets; nothing meaningful to keep.
                continue
            norm_to_original.setdefault(normalized, word)
            counts[normalized] += 1

        new_vocab = {"<START>": 0, "<END>": 1}
        for norm, original in norm_to_original.items():
            if counts[norm] >= min_occurrences:
                new_vocab[original] = len(new_vocab)

        print(f"[CLEAN] Vocabulary reduced: {len(old_vocab)} → {len(new_vocab)}")

        self.tokenizer.vocab = new_vocab
        self.tokenizer.inv_vocab = {v: k for k, v in new_vocab.items()}

        # Ids may be renumbered even when the size is unchanged; the old embedding rows would
        # then point at the wrong tokens, so force a fresh model whenever the mapping changed.
        if new_vocab != old_vocab:
            self.model = None
        self.rebuild_model_if_needed()

        with open("tokenizer_vocab.txt", "w", encoding="utf-8") as f:
            for token in new_vocab:
                f.write(f"{token}\n")
        print("[CLEAN] Vocab written to tokenizer_vocab.txt")