Added lesson planning

Dani 2025-04-24 12:54:30 -04:00
parent 7823ce1d5e
commit c73091c8c5
4 changed files with 120 additions and 32 deletions

.gitignore vendored

@@ -173,4 +173,6 @@ cython_debug/
 /logs/best_dream.txt
 /.vscode/launch.json
 /books
 /readstate.txt
+/lessonstate.txt
+/data/lessons.txt

lesson.py Normal file

@@ -0,0 +1,70 @@
import os
import asyncio
from datetime import datetime


class LessonModule:
    def __init__(self, trainer, lesson_path="data/lessons.txt", log_path="logs/lesson.log", state_path="lessonstate.txt"):
        self.trainer = trainer
        self.lesson_path = lesson_path
        self.log_path = log_path
        self.state_path = state_path
        self.current_index = 0
        self.lessons = []

        os.makedirs(os.path.dirname(self.log_path), exist_ok=True)

        if os.path.exists(self.lesson_path):
            with open(self.lesson_path, "r", encoding="utf-8") as f:
                self.lessons = [line.strip() for line in f if line.strip()]

        if os.path.exists(self.state_path):
            try:
                with open(self.state_path, "r", encoding="utf-8") as f:
                    self.current_index = int(f.read().strip())
            except Exception:
                self.current_index = 0

    def _save_state(self):
        with open(self.state_path, "w", encoding="utf-8") as f:
            f.write(str(self.current_index))

    def _log_lesson(self, text: str, score: float):
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(f"[{datetime.utcnow().isoformat()}] {score:.2f} | {text.strip()}\n")

    async def start_lessons(self, interval=10):
        print("[LESSON] Starting lesson loop...")
        while self.current_index < len(self.lessons):
            line = self.lessons[self.current_index]
            if len(line.split()) >= 3 and self._is_valid(line):
                score = self.trainer.score_sentence(line)
                if self.trainer.is_reinforceable(line) and score >= 2.0:
                    self.trainer.train_on_tokens_from_text(line)
                    self._log_lesson(line, score)
            self.current_index += 1
            self._save_state()
            await asyncio.sleep(interval)
        print("[LESSON] All lessons completed.")

    def _is_valid(self, text: str) -> bool:
        return all(c.isprintable() or c.isspace() for c in text)

    def reset(self):
        self.current_index = 0
        self._save_state()

    def add_lesson(self, text: str):
        self.lessons.append(text)
        with open(self.lesson_path, "a", encoding="utf-8") as f:
            f.write(text.strip() + "\n")

    def progress(self):
        total = len(self.lessons)
        return {
            "current": self.current_index,
            "total": total,
            "percent": round(100 * self.current_index / total, 2) if total else 0.0
        }
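For a quick sanity check, the module can be driven without the Discord bot. The StubTrainer below is a hypothetical stand-in (not part of this commit) that fakes the three trainer methods LessonModule actually calls; the paths follow the defaults above.

import asyncio
import os

from lesson import LessonModule

# Hypothetical stand-in for RubyTrainer; only the three methods
# LessonModule calls are stubbed out.
class StubTrainer:
    def score_sentence(self, text: str) -> float:
        return 3.0  # pretend every lesson clears the 2.0 threshold

    def is_reinforceable(self, text: str) -> bool:
        return True

    def train_on_tokens_from_text(self, text: str) -> None:
        print(f"[TRAIN] {text}")

os.makedirs("data", exist_ok=True)  # lesson_path lives under data/
lessons = LessonModule(StubTrainer())
lessons.add_lesson("I am learning to read.")
asyncio.run(lessons.start_lessons(interval=0))
print(lessons.progress())  # e.g. {'current': 1, 'total': 1, 'percent': 100.0}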


@@ -8,6 +8,7 @@ import dashboard
 from tokenizer import Tokenizer
 from trainer import RubyTrainer
 from reader import BookReader
+from lesson import LessonModule
 import logging

 # Setup logging
@@ -41,6 +42,7 @@ class Ruby(discord.Client):
         super().__init__(intents=intents)
         self.tokenizer = Tokenizer()
         self.trainer = RubyTrainer(self.tokenizer)
+        self.lessons = LessonModule(self.trainer)
         self.reader = BookReader(trainer=self.trainer,
                                  book_path="books//wizard_of_oz.txt",  # or whatever book you want
                                  interval=180  # read every 3 minutes (adjust if needed)
@@ -51,6 +53,7 @@ class Ruby(discord.Client):
         os.makedirs("logs", exist_ok=True)

     async def setup_hook(self):
+        self.loop.create_task(self.lessons.start_lessons())
         self.loop.create_task(self.reader.start_reading())
         self.loop.create_task(self.idle_dream_loop())
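setup_hook now schedules the lesson loop as a third background task beside reading and dreaming, so all three advance concurrently on the bot's event loop. A minimal sketch of that create_task pattern outside discord.py (the coroutine names here are illustrative, not the bot's real ones):

import asyncio

async def lesson_loop():
    for i in range(3):
        print(f"[LESSON] step {i}")
        await asyncio.sleep(0.1)

async def reader_loop():
    for i in range(3):
        print(f"[READ] step {i}")
        await asyncio.sleep(0.15)

async def main():
    # create_task() starts each loop immediately; neither blocks the other.
    tasks = [asyncio.create_task(lesson_loop()), asyncio.create_task(reader_loop())]
    await asyncio.gather(*tasks)

asyncio.run(main())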

trainer.py

@@ -79,38 +79,40 @@ class RubyTrainer:
         print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")

-    def generate_reply(self, prompt=None, max_length=20, temperature=1.3):
-        self.model.eval()
-        input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
-
-        with torch.no_grad():
-            for _ in range(max_length):
-                max_id = self.model.token_embed.num_embeddings
-                input_ids = torch.clamp(input_ids, 0, max_id - 1)
-                output = self.model(input_ids)
-                logits = output[:, -1, :]
-
-                # Apply repeat penalty
-                if input_ids.size(1) >= 2:
-                    last_token = input_ids[0, -1].item()
-                    logits[0, last_token] *= 0.1
-
-                # 🔥 Temperature sampling
-                probs = F.softmax(logits / temperature, dim=-1)
-                next_token = torch.multinomial(probs, 1)[0].view(1)
-
-                if next_token.item() >= self.model.token_embed.num_embeddings:
-                    print("[ERROR] Token index out of bounds. Rebuilding model...")
-                    self.rebuild_model_if_needed()
-                    return ""
-
-                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
-
-                if next_token.item() == self.tokenizer.vocab["<END>"]:
-                    break
-
-        output = self.tokenizer.detokenize(input_ids.squeeze().tolist())
-        return output.replace("<START>", "").replace("<END>", "").strip()
+    def generate_reply(self, prompt=None, max_length=20, retry_limit=5):
+        self.model.eval()
+        for _ in range(retry_limit):
+            input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
+            with torch.no_grad():
+                for _ in range(max_length):
+                    max_id = self.model.token_embed.num_embeddings
+                    input_ids = torch.clamp(input_ids, 0, max_id - 1)
+                    output = self.model(input_ids)
+                    logits = output[:, -1, :]
+
+                    if input_ids.size(1) >= 2:
+                        last_token = input_ids[0, -1].item()
+                        logits[0, last_token] *= 0.1
+
+                    next_token = torch.argmax(logits, dim=-1)
+
+                    if next_token.item() >= self.model.token_embed.num_embeddings:
+                        print("[ERROR] Token index OOB. Rebuilding model.")
+                        self.rebuild_model_if_needed()
+                        return ""
+
+                    input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
+
+                    if next_token.item() == self.tokenizer.vocab["<END>"]:
+                        break
+
+            decoded = self.tokenizer.detokenize(input_ids.squeeze().tolist())
+            decoded = decoded.replace("<START>", "").replace("<END>", "").strip()
+            if len(decoded.split()) >= 4 and self._has_sentence_structure(decoded):
+                return decoded
+        return ""
     def self_rephrase(self, original: str, max_tokens=50, temperature=1.3):
@@ -247,39 +249,50 @@ class RubyTrainer:
         freqs = Counter(words)

-        # Reject if any token appears more than 5 times
         if any(count > 5 for count in freqs.values()):
             return False

-        # Reject if most common word is > 30% of sentence
         if max(freqs.values()) / len(words) > 0.3:
             return False

-        # Reject if >3 tokens occur 3+ times
         if sum(1 for c in freqs.values() if c >= 3) > 3:
             return False

-        # Reject if "I am" occurs more than 25% of the time
         if text.lower().count("i am") > len(text.split()) * 0.25:
             return False

-        # Reject if first word repeats 3 times ("you you you")
         if words[:3].count(words[0]) == 3:
             return False

-        # 🧠 NEW: Reject if starts with common book phrases
+        # 🧠 NEW: Reject if starts with known book structures
         banned_starts = ("once upon", "chapter", "the end", "in which", "it was", "quick cried", "they are course")
         lowered = text.lower()
         if any(lowered.startswith(phrase) for phrase in banned_starts):
             return False

-        # 🧠 NEW: Reject if too many capitalized words in a row (e.g., names, places from a book)
+        # 🧠 NEW: Capitalized name/term flooding check
         cap_sequence = sum(1 for word in words if word.istitle())
         if cap_sequence > 5 and cap_sequence / len(words) > 0.4:
             return False

+        # 🧠 NEW: Book-world saturation filter
+        blacklist = {"toto", "sparkle", "scarecrow", "munchkin", "wizard", "oz", "emerald", "wicked", "kansas"}
+        lower_words = [w.lower() for w in words]
+        if sum(1 for w in lower_words if w in blacklist) > 3:
+            return False
+
+        # 🧠 NEW: Diversity requirement
+        if len(set(lower_words)) < len(lower_words) * 0.5:
+            return False
+
         return True

+    def _has_sentence_structure(self, text: str) -> bool:
+        # Simple heuristic: must contain a known subject and verb
+        has_subject = any(word.lower() in ["i", "you", "he", "she", "they", "we"] for word in text.split())
+        has_verb = any(text.lower().count(verb) for verb in ["am", "is", "are", "was", "were", "have", "feel", "see", "think", "said"])
+        return has_subject and has_verb
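_has_sentence_structure is the heuristic generate_reply's retry loop checks drafts against. A standalone copy (for illustration only) shows its behavior, including a substring caveat: count() matches inside longer words, so "think" also fires on "thinking".

def has_sentence_structure(text: str) -> bool:
    # Standalone copy of the method above for quick testing.
    has_subject = any(word.lower() in ["i", "you", "he", "she", "they", "we"] for word in text.split())
    has_verb = any(text.lower().count(verb) for verb in ["am", "is", "are", "was", "were", "have", "feel", "see", "think", "said"])
    return has_subject and has_verb

print(has_sentence_structure("I am learning to read"))  # True: subject "i", verb "am"
print(has_sentence_structure("Emerald City the end"))   # False: no subject pronoun
print(has_sentence_structure("they thinking hard"))     # True: "think" matches inside "thinking"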
     def score_sentence(self, sentence: str) -> float:
         words = sentence.strip().split()
         if not words: