Added lesson planning

This commit is contained in:
Dani 2025-04-24 12:54:30 -04:00
parent 7823ce1d5e
commit c73091c8c5
4 changed files with 120 additions and 32 deletions

2
.gitignore vendored
View File

@ -174,3 +174,5 @@ cython_debug/
/.vscode/launch.json
/books
/readstate.txt
/lessonstate.txt
/data/lessons.txt

70
lesson.py Normal file
View File

@ -0,0 +1,70 @@
import os
import asyncio
from datetime import datetime
class LessonModule:
def __init__(self, trainer, lesson_path="data/lessons.txt", log_path="logs/lesson.log", state_path="lessonstate.txt"):
self.trainer = trainer
self.lesson_path = lesson_path
self.log_path = log_path
self.state_path = state_path
self.current_index = 0
self.lessons = []
os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
if os.path.exists(self.lesson_path):
with open(self.lesson_path, "r", encoding="utf-8") as f:
self.lessons = [line.strip() for line in f if line.strip()]
if os.path.exists(self.state_path):
try:
with open(self.state_path, "r", encoding="utf-8") as f:
self.current_index = int(f.read().strip())
except Exception:
self.current_index = 0
def _save_state(self):
with open(self.state_path, "w", encoding="utf-8") as f:
f.write(str(self.current_index))
def _log_lesson(self, text: str, score: float):
with open(self.log_path, "a", encoding="utf-8") as f:
f.write(f"[{datetime.utcnow().isoformat()}] {score:.2f} | {text.strip()}\n")
async def start_lessons(self, interval=10):
print("[LESSON] Starting lesson loop...")
while self.current_index < len(self.lessons):
line = self.lessons[self.current_index]
if len(line.split()) >= 3 and self._is_valid(line):
score = self.trainer.score_sentence(line)
if self.trainer.is_reinforceable(line) and score >= 2.0:
self.trainer.train_on_tokens_from_text(line)
self._log_lesson(line, score)
self.current_index += 1
self._save_state()
await asyncio.sleep(interval)
print("[LESSON] All lessons completed.")
def _is_valid(self, text: str) -> bool:
return all(c.isprintable() or c.isspace() for c in text)
def reset(self):
self.current_index = 0
self._save_state()
def add_lesson(self, text: str):
self.lessons.append(text)
with open(self.lesson_path, "a", encoding="utf-8") as f:
f.write(text.strip() + "\n")
def progress(self):
total = len(self.lessons)
return {
"current": self.current_index,
"total": total,
"percent": round(100 * self.current_index / total, 2) if total else 0.0
}

View File

@ -8,6 +8,7 @@ import dashboard
from tokenizer import Tokenizer
from trainer import RubyTrainer
from reader import BookReader
from lesson import LessonModule
import logging
# Setup logging
@ -41,6 +42,7 @@ class Ruby(discord.Client):
super().__init__(intents=intents)
self.tokenizer = Tokenizer()
self.trainer = RubyTrainer(self.tokenizer)
self.lessons = LessonModule(self.trainer)
self.reader = BookReader(trainer=self.trainer,
book_path="books//wizard_of_oz.txt", # or whatever book you want
interval=180 # read every 3 minutes (adjust if needed)
@ -51,6 +53,7 @@ class Ruby(discord.Client):
os.makedirs("logs", exist_ok=True)
async def setup_hook(self):
self.loop.create_task(self.lessons.start_lessons())
self.loop.create_task(self.reader.start_reading())
self.loop.create_task(self.idle_dream_loop())

View File

@ -79,38 +79,40 @@ class RubyTrainer:
print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")
def generate_reply(self, prompt=None, max_length=20, temperature=1.3):
def generate_reply(self, prompt=None, max_length=20, retry_limit=5):
self.model.eval()
input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
for _ in range(retry_limit):
input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
with torch.no_grad():
for _ in range(max_length):
max_id = self.model.token_embed.num_embeddings
input_ids = torch.clamp(input_ids, 0, max_id - 1)
output = self.model(input_ids)
logits = output[:, -1, :]
with torch.no_grad():
for _ in range(max_length):
max_id = self.model.token_embed.num_embeddings
input_ids = torch.clamp(input_ids, 0, max_id - 1)
output = self.model(input_ids)
logits = output[:, -1, :]
if input_ids.size(1) >= 2:
last_token = input_ids[0, -1].item()
logits[0, last_token] *= 0.1
# Apply repeat penalty
if input_ids.size(1) >= 2:
last_token = input_ids[0, -1].item()
logits[0, last_token] *= 0.1
next_token = torch.argmax(logits, dim=-1)
# 🔥 Temperature sampling
probs = F.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, 1)[0].view(1)
if next_token.item() >= self.model.token_embed.num_embeddings:
print("[ERROR] Token index OOB. Rebuilding model.")
self.rebuild_model_if_needed()
return ""
if next_token.item() >= self.model.token_embed.num_embeddings:
print("[ERROR] Token index out of bounds. Rebuilding model...")
self.rebuild_model_if_needed()
return ""
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
if next_token.item() == self.tokenizer.vocab["<END>"]:
break
if next_token.item() == self.tokenizer.vocab["<END>"]:
break
decoded = self.tokenizer.detokenize(input_ids.squeeze().tolist())
decoded = decoded.replace("<START>", "").replace("<END>", "").strip()
output = self.tokenizer.detokenize(input_ids.squeeze().tolist())
return output.replace("<START>", "").replace("<END>", "").strip()
if len(decoded.split()) >= 4 and self._has_sentence_structure(decoded):
return decoded
return ""
def self_rephrase(self, original: str, max_tokens=50, temperature=1.3):
@ -247,39 +249,50 @@ class RubyTrainer:
freqs = Counter(words)
# Reject if any token appears more than 5 times
if any(count > 5 for count in freqs.values()):
return False
# Reject if most common word is > 30% of sentence
if max(freqs.values()) / len(words) > 0.3:
return False
# Reject if >3 tokens occur 3+ times
if sum(1 for c in freqs.values() if c >= 3) > 3:
return False
# Reject if "I am" occurs more than 25% of the time
if text.lower().count("i am") > len(text.split()) * 0.25:
return False
# Reject if first word repeats 3 times ("you you you")
if words[:3].count(words[0]) == 3:
return False
# 🧠 NEW: Reject if starts with common book phrases
# 🧠 NEW: Reject if starts with known book structures
banned_starts = ("once upon", "chapter", "the end", "in which", "it was", "quick cried", "they are course")
lowered = text.lower()
if any(lowered.startswith(phrase) for phrase in banned_starts):
return False
# 🧠 NEW: Reject if too many capitalized words in a row (e.g., names, places from a book)
# 🧠 NEW: Capitalized name/term flooding check
cap_sequence = sum(1 for word in words if word.istitle())
if cap_sequence > 5 and cap_sequence / len(words) > 0.4:
return False
# 🧠 NEW: Book-world saturation filter
blacklist = {"toto", "sparkle", "scarecrow", "munchkin", "wizard", "oz", "emerald", "wicked", "kansas"}
lower_words = [w.lower() for w in words]
if sum(1 for w in lower_words if w in blacklist) > 3:
return False
# 🧠 NEW: Diversity requirement
if len(set(lower_words)) < len(lower_words) * 0.5:
return False
return True
def _has_sentence_structure(self, text: str) -> bool:
# Simple heuristic: must contain a known subject and verb
has_subject = any(word.lower() in ["i", "you", "he", "she", "they", "we"] for word in text.split())
has_verb = any(text.lower().count(verb) for verb in ["am", "is", "are", "was", "were", "have", "feel", "see", "think", "said"])
return has_subject and has_verb
def score_sentence(self, sentence: str) -> float:
words = sentence.strip().split()
if not words: