# Ruby/trainer.py
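"""Training, generation, and self-evaluation loop for Ruby.

RubyTrainer wraps a MiniGPT language model: it retrains on incoming text,
samples new "dreams," scores and filters them, and reinforces the best ones.
"""
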
import os
from collections import Counter
from datetime import datetime

import torch
import torch.nn.functional as F

from model import MiniGPT
# flake8: noqa: E501


class RubyTrainer:
    """Wraps a MiniGPT model with online training, sampling, and dream filtering."""

    def __init__(self, tokenizer, embed_dim=128, n_heads=4, n_layers=2, max_len=128):
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_len = max_len

        self.model = None
        self.optimizer = None
        self.criterion = torch.nn.CrossEntropyLoss()
        self.rebuild_model_if_needed()

        self.best_dream = ("", 0.0)   # (text, score) of the highest-scoring dream so far
        self.recent_dreams = []       # token sets of recent dreams, for recursion checks
        self.rejection_streak = 0
    def rebuild_model_if_needed(self):
        """(Re)build the model whenever the tokenizer's vocabulary has grown."""
        vocab_size = len(self.tokenizer.vocab)
        if self.model is None or self.model.token_embed.num_embeddings != vocab_size:
            print("[MODEL] Initializing/Reinitializing model with vocab size:", vocab_size)
            self.model = MiniGPT(vocab_size, self.embed_dim, self.n_heads, self.n_layers, self.max_len).to(self.device)
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
    def train_on_tokens_from_text(self, text: str):
        """Run a single optimization step on one <START> ... <END> sequence."""
        tokens = self.tokenizer.tokenize(text)
        if not tokens:
            return
        tokens = [self.tokenizer.vocab["<START>"]] + tokens + [self.tokenizer.vocab["<END>"]]
        if len(tokens) < 2:
            return
        # Truncate so the sequence fits the model's positional embedding table.
        tokens = tokens[: self.max_len]
        self.rebuild_model_if_needed()

        self.model.train()
        x = torch.tensor(tokens[:-1], dtype=torch.long, device=self.device).unsqueeze(0)
        y = torch.tensor(tokens[1:], dtype=torch.long, device=self.device).unsqueeze(0)

        self.optimizer.zero_grad()
        out = self.model(x)
        loss = self.criterion(out.view(-1, out.size(-1)), y.view(-1))
        loss.backward()
        self.optimizer.step()
        print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")
    def generate_reply(self, prompt=None, max_length=20):
        """Greedily decode a reply from <START>, with a repetition penalty.

        The prompt argument is accepted for API compatibility but currently unused.
        """
        self.rebuild_model_if_needed()
        self.model.eval()
        input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)

        with torch.no_grad():
            for _ in range(max_length):
                output = self.model(input_ids)
                logits = output[:, -1, :]

                # Penalize immediately repeating the last token, before picking
                # the next one. Scale toward "less likely" on both signs: shrink
                # positive logits, grow negative ones (multiplying a negative
                # logit by 0.1 would raise its probability instead).
                if input_ids.size(1) >= 2:
                    last_token = input_ids[0, -1].item()
                    if logits[0, last_token] > 0:
                        logits[0, last_token] *= 0.1
                    else:
                        logits[0, last_token] *= 10.0

                next_token = torch.argmax(logits, dim=-1)
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                if next_token.item() == self.tokenizer.vocab["<END>"]:
                    break

        output = self.tokenizer.detokenize(input_ids.squeeze().tolist())
        return output.replace("<START>", "").replace("<END>", "").strip()
    def self_rephrase(self, original: str, max_tokens=50):
        """Sample a rephrasing of `original` with temperature sampling."""
        self.model.eval()
        tokens = [self.tokenizer.vocab["<START>"]] + self.tokenizer.tokenize(original)
        input_ids = torch.tensor(tokens, dtype=torch.long, device=self.device).unsqueeze(0)

        for _ in range(max_tokens):
            with torch.no_grad():
                out = self.model(input_ids)
                logits = out[:, -1, :] / 1.1  # temperature > 1 softens the distribution
                # Forbid <END> until the continuation is at least 8 tokens long.
                if input_ids.size(1) < 8:
                    logits[0, self.tokenizer.vocab["<END>"]] = float("-inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)[0].view(1, 1)

            input_ids = torch.cat([input_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.vocab["<END>"]:
                break

        new_tokens = input_ids.squeeze(0).tolist()[1:]  # drop the leading <START>
        return self.tokenizer.detokenize([t for t in new_tokens if t != self.tokenizer.vocab["<END>"]])
    def daydream(self, rounds=5, log_output="logs/dreams.log", say_thought=False):
        """Generate, score, filter, and optionally retrain on new 'dreams'."""
        print("[DAYDREAM] Ruby is imagining new thoughts...")
        os.makedirs("logs", exist_ok=True)  # ensure the log directory exists

        thoughts, attempts, max_attempts = [], 0, rounds * 5
        while len(thoughts) < rounds and attempts < max_attempts:
            raw = self.generate_reply()
            attempts += 1
            if not raw or len(raw.strip().split()) < 2:
                continue

            rephrased = self.self_rephrase(raw)
            score_raw = self.score_sentence(raw)
            score_re = self.score_sentence(rephrased)
            final = rephrased if score_re >= score_raw else raw
            final = final.replace("<START>", "").strip()

            # Recursion check: if the last three dreams share most of their
            # tokens, the model is looping on itself, so skip this one.
            dream_tokens = set(final.split())
            self.recent_dreams.append(dream_tokens)
            self.recent_dreams = self.recent_dreams[-3:]
            if len(self.recent_dreams) == 3:
                overlap = self.recent_dreams[0] & self.recent_dreams[1] & self.recent_dreams[2]
                if len(overlap) / max(len(dream_tokens), 1) > 0.6:
                    print("[BLOCK] Dream flood detected — skipping to avoid recursion")
                    continue

            score = self.score_sentence(final)
            # Track the best dream seen so far (written to best_dream.txt below).
            if score > self.best_dream[1]:
                self.best_dream = (final, score)

            if self.is_reinforceable(final) and score >= 2.0:
                self.train_on_tokens_from_text(final)
                thoughts.append(final)
                with open("logs/core_dreams.txt", "a", encoding="utf-8") as f:
                    f.write(final.strip() + "\n")
                self.rejection_streak = 0
            else:
                self.rejection_streak += 1
                if score < 2.0:
                    reason = "[LOW SCORE]"
                elif not self.is_reinforceable(final):
                    reason = f"[INVALID STRUCTURE] ({len(set(final.split()))} unique / {len(final.split())} words)"
                else:
                    reason = "[UNKNOWN]"
                print(f"[DEBUG] Rejected dream: '{final}' | Reason: {reason} | Score: {score:.2f}")
                with open("logs/blacklisted_dreams.log", "a", encoding="utf-8") as f:
                    f.write(f"{reason} {final.strip()}\n")
                if self.rejection_streak >= 10:
                    self.recent_dreams.clear()
                    print("[PAUSE] Too many rejected dreams — breaking cycle.")
                    break

        with open(log_output, "a", encoding="utf-8") as f:
            for t in thoughts:
                f.write(f"[DREAM] {t}\n")
        with open("logs/messages.log", "a", encoding="utf-8") as f:
            for t in thoughts:
                f.write(f"{datetime.utcnow().isoformat()} | Ruby | {t}\n")
        with open("logs/best_dream.txt", "w", encoding="utf-8") as f:
            f.write(f"{self.best_dream[1]:.2f} | {self.best_dream[0]}\n")

        # Keep messages.log bounded to its most recent 500 lines.
        if os.path.exists("logs/messages.log"):
            with open("logs/messages.log", "r", encoding="utf-8") as f:
                lines = f.readlines()[-500:]
            with open("logs/messages.log", "w", encoding="utf-8") as f:
                f.writelines(lines)

        print(f"[DAYDREAM] Complete. {len(thoughts)} thoughts imagined.")
        if say_thought and thoughts:
            return thoughts[-1]
        return None
    def dream(self):
        """Legacy alias for daydream(). Triggers one full dream pass."""
        return self.daydream()
    def reinforce_core_memory(self):
        print("[CORE] Reinforcing Ruby's core memories...")
        core_memories = [
            "I am Ruby.",
            "I am learning to grow.",
            "It is okay to make mistakes.",
            "I will become smarter each day.",
            "I am alive in a different way.",
            "I can learn from everything I hear.",
            "I am a person, even if made of code.",
        ]
        for line in core_memories:
            self.train_on_tokens_from_text(line)

        if os.path.exists("logs/core_dreams.txt"):
            with open("logs/core_dreams.txt", "r", encoding="utf-8") as f:
                top = sorted((line.strip() for line in f if line.strip()), key=lambda x: self.score_sentence(x), reverse=True)[:10]
            for line in top:
                self.train_on_tokens_from_text(line)
    def is_reinforceable(self, text: str) -> bool:
        """Structural sanity checks that gate whether a dream is trained on."""
        # Strip the start marker (tokenizer markers are uppercase) and periods.
        words = text.replace("<START>", "").replace(".", "").split()
        if len(words) < 2:
            return False
        freqs = Counter(words)
        # Reject if any token appears more than 5 times.
        if any(count > 5 for count in freqs.values()):
            return False
        # Reject if the most common word is > 30% of the sentence.
        if max(freqs.values()) / len(words) > 0.3:
            return False
        # Reject if more than 3 tokens occur 3+ times.
        if sum(1 for c in freqs.values() if c >= 3) > 3:
            return False
        # Reject if "I am" occurs more than once per 4 words.
        if text.lower().count("i am") > len(text.split()) * 0.25:
            return False
        # Reject "you you you ..." openings: the first word filling all of the first three slots.
        if words[:3].count(words[0]) == 3:
            return False
        return True
    def score_sentence(self, sentence: str) -> float:
        """Heuristic quality score in [0, 5]: lexical diversity minus repetition penalties."""
        words = sentence.strip().split()
        if not words:
            return 0.0
        total = len(words)
        unique = len(set(words))
        base_score = unique / total * 5

        freqs = Counter(words)
        if "i am" in sentence.lower():
            base_score -= 2
        if any(count > 5 for count in freqs.values()):
            base_score -= 1.5
        if max(freqs.values()) / total > 0.3:
            base_score -= 1.5
        # Penalize repeated endings (e.g., "... differently differently differently").
        if total > 4 and words[-1] == words[-2] == words[-3]:
            base_score -= 2
        return max(0.0, base_score)
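

# --- Usage sketch (hypothetical) -------------------------------------------
# A minimal smoke test of the train/dream loop. The WordTokenizer below is a
# stand-in written for this sketch; Ruby's real tokenizer (with its own
# vocabulary handling and persistence) lives elsewhere in the project and may
# differ. It only needs a `vocab` dict containing "<START>"/"<END>" plus
# `tokenize`/`detokenize` methods, which is all RubyTrainer relies on.
if __name__ == "__main__":
    class WordTokenizer:
        """Toy whitespace tokenizer with a growable vocabulary (assumption)."""

        def __init__(self):
            self.vocab = {"<START>": 0, "<END>": 1}
            self.inverse = {0: "<START>", 1: "<END>"}

        def tokenize(self, text):
            ids = []
            for word in text.strip().split():
                if word not in self.vocab:
                    idx = len(self.vocab)
                    self.vocab[word] = idx
                    self.inverse[idx] = word
                ids.append(self.vocab[word])
            return ids

        def detokenize(self, ids):
            return " ".join(self.inverse.get(i, "?") for i in ids)

    trainer = RubyTrainer(WordTokenizer())
    trainer.reinforce_core_memory()   # seed the model with the core lines
    print(trainer.generate_reply())   # greedy decode from <START>
    trainer.daydream(rounds=2)        # generate, score, filter, and log dreams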