From c73091c8c5f207719938ec3a5e0ec47b3a7a57b5 Mon Sep 17 00:00:00 2001 From: Dani Date: Thu, 24 Apr 2025 12:54:30 -0400 Subject: [PATCH] Added lesson planning --- .gitignore | 4 ++- lesson.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 3 +++ trainer.py | 75 ++++++++++++++++++++++++++++++++---------------------- 4 files changed, 120 insertions(+), 32 deletions(-) create mode 100644 lesson.py diff --git a/.gitignore b/.gitignore index c912757..3b123c8 100644 --- a/.gitignore +++ b/.gitignore @@ -173,4 +173,6 @@ cython_debug/ /logs/best_dream.txt /.vscode/launch.json /books -/readstate.txt \ No newline at end of file +/readstate.txt +/lessonstate.txt +/data/lessons.txt \ No newline at end of file diff --git a/lesson.py b/lesson.py new file mode 100644 index 0000000..0c04c67 --- /dev/null +++ b/lesson.py @@ -0,0 +1,70 @@ +import os +import asyncio +from datetime import datetime + + +class LessonModule: + def __init__(self, trainer, lesson_path="data/lessons.txt", log_path="logs/lesson.log", state_path="lessonstate.txt"): + self.trainer = trainer + self.lesson_path = lesson_path + self.log_path = log_path + self.state_path = state_path + self.current_index = 0 + self.lessons = [] + + os.makedirs(os.path.dirname(self.log_path), exist_ok=True) + + if os.path.exists(self.lesson_path): + with open(self.lesson_path, "r", encoding="utf-8") as f: + self.lessons = [line.strip() for line in f if line.strip()] + + if os.path.exists(self.state_path): + try: + with open(self.state_path, "r", encoding="utf-8") as f: + self.current_index = int(f.read().strip()) + except Exception: + self.current_index = 0 + + def _save_state(self): + with open(self.state_path, "w", encoding="utf-8") as f: + f.write(str(self.current_index)) + + def _log_lesson(self, text: str, score: float): + with open(self.log_path, "a", encoding="utf-8") as f: + f.write(f"[{datetime.utcnow().isoformat()}] {score:.2f} | {text.strip()}\n") + + async def start_lessons(self, interval=10): + print("[LESSON] Starting lesson loop...") + while self.current_index < len(self.lessons): + line = self.lessons[self.current_index] + if len(line.split()) >= 3 and self._is_valid(line): + score = self.trainer.score_sentence(line) + if self.trainer.is_reinforceable(line) and score >= 2.0: + self.trainer.train_on_tokens_from_text(line) + self._log_lesson(line, score) + + self.current_index += 1 + self._save_state() + await asyncio.sleep(interval) + + print("[LESSON] All lessons completed.") + + def _is_valid(self, text: str) -> bool: + return all(c.isprintable() or c.isspace() for c in text) + + def reset(self): + self.current_index = 0 + self._save_state() + + def add_lesson(self, text: str): + self.lessons.append(text) + with open(self.lesson_path, "a", encoding="utf-8") as f: + f.write(text.strip() + "\n") + + def progress(self): + total = len(self.lessons) + return { + "current": self.current_index, + "total": total, + "percent": round(100 * self.current_index / total, 2) if total else 0.0 + } diff --git a/main.py b/main.py index 68895f5..9a340f5 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ import dashboard from tokenizer import Tokenizer from trainer import RubyTrainer from reader import BookReader +from lesson import LessonModule import logging # Setup logging @@ -41,6 +42,7 @@ class Ruby(discord.Client): super().__init__(intents=intents) self.tokenizer = Tokenizer() self.trainer = RubyTrainer(self.tokenizer) + self.lessons = LessonModule(self.trainer) self.reader = BookReader(trainer=self.trainer, book_path="books//wizard_of_oz.txt", # or whatever book you want interval=180 # read every 3 minutes (adjust if needed) @@ -51,6 +53,7 @@ class Ruby(discord.Client): os.makedirs("logs", exist_ok=True) async def setup_hook(self): + self.loop.create_task(self.lessons.start_lessons()) self.loop.create_task(self.reader.start_reading()) self.loop.create_task(self.idle_dream_loop()) diff --git a/trainer.py b/trainer.py index cdc8dda..55e7ea7 100644 --- a/trainer.py +++ b/trainer.py @@ -79,38 +79,40 @@ class RubyTrainer: print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}") - def generate_reply(self, prompt=None, max_length=20, temperature=1.3): + def generate_reply(self, prompt=None, max_length=20, retry_limit=5): self.model.eval() - input_ids = torch.tensor([[self.tokenizer.vocab[""]]], device=self.device) + for _ in range(retry_limit): + input_ids = torch.tensor([[self.tokenizer.vocab[""]]], device=self.device) + with torch.no_grad(): + for _ in range(max_length): + max_id = self.model.token_embed.num_embeddings + input_ids = torch.clamp(input_ids, 0, max_id - 1) + output = self.model(input_ids) + logits = output[:, -1, :] - with torch.no_grad(): - for _ in range(max_length): - max_id = self.model.token_embed.num_embeddings - input_ids = torch.clamp(input_ids, 0, max_id - 1) - output = self.model(input_ids) - logits = output[:, -1, :] + if input_ids.size(1) >= 2: + last_token = input_ids[0, -1].item() + logits[0, last_token] *= 0.1 - # Apply repeat penalty - if input_ids.size(1) >= 2: - last_token = input_ids[0, -1].item() - logits[0, last_token] *= 0.1 + next_token = torch.argmax(logits, dim=-1) - # 🔥 Temperature sampling - probs = F.softmax(logits / temperature, dim=-1) - next_token = torch.multinomial(probs, 1)[0].view(1) + if next_token.item() >= self.model.token_embed.num_embeddings: + print("[ERROR] Token index OOB. Rebuilding model.") + self.rebuild_model_if_needed() + return "" - if next_token.item() >= self.model.token_embed.num_embeddings: - print("[ERROR] Token index out of bounds. Rebuilding model...") - self.rebuild_model_if_needed() - return "" + input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1) - input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1) + if next_token.item() == self.tokenizer.vocab[""]: + break - if next_token.item() == self.tokenizer.vocab[""]: - break + decoded = self.tokenizer.detokenize(input_ids.squeeze().tolist()) + decoded = decoded.replace("", "").replace("", "").strip() - output = self.tokenizer.detokenize(input_ids.squeeze().tolist()) - return output.replace("", "").replace("", "").strip() + if len(decoded.split()) >= 4 and self._has_sentence_structure(decoded): + return decoded + + return "" def self_rephrase(self, original: str, max_tokens=50, temperature=1.3): @@ -247,39 +249,50 @@ class RubyTrainer: freqs = Counter(words) - # Reject if any token appears more than 5 times if any(count > 5 for count in freqs.values()): return False - # Reject if most common word is > 30% of sentence if max(freqs.values()) / len(words) > 0.3: return False - # Reject if >3 tokens occur 3+ times if sum(1 for c in freqs.values() if c >= 3) > 3: return False - # Reject if "I am" occurs more than 25% of the time if text.lower().count("i am") > len(text.split()) * 0.25: return False - # Reject if first word repeats 3 times ("you you you") if words[:3].count(words[0]) == 3: return False - # 🧠 NEW: Reject if starts with common book phrases + # 🧠 NEW: Reject if starts with known book structures banned_starts = ("once upon", "chapter", "the end", "in which", "it was", "quick cried", "they are course") lowered = text.lower() if any(lowered.startswith(phrase) for phrase in banned_starts): return False - # 🧠 NEW: Reject if too many capitalized words in a row (e.g., names, places from a book) + # 🧠 NEW: Capitalized name/term flooding check cap_sequence = sum(1 for word in words if word.istitle()) if cap_sequence > 5 and cap_sequence / len(words) > 0.4: return False + # 🧠 NEW: Book-world saturation filter + blacklist = {"toto", "sparkle", "scarecrow", "munchkin", "wizard", "oz", "emerald", "wicked", "kansas"} + lower_words = [w.lower() for w in words] + if sum(1 for w in lower_words if w in blacklist) > 3: + return False + + # 🧠 NEW: Diversity requirement + if len(set(lower_words)) < len(lower_words) * 0.5: + return False + return True + def _has_sentence_structure(self, text: str) -> bool: + # Simple heuristic: must contain a known subject and verb + has_subject = any(word.lower() in ["i", "you", "he", "she", "they", "we"] for word in text.split()) + has_verb = any(text.lower().count(verb) for verb in ["am", "is", "are", "was", "were", "have", "feel", "see", "think", "said"]) + return has_subject and has_verb + def score_sentence(self, sentence: str) -> float: words = sentence.strip().split() if not words: