Added lesson planning

parent 7823ce1d5e
commit c73091c8c5
.gitignore (vendored): 2 additions
@@ -174,3 +174,5 @@ cython_debug/
 /.vscode/launch.json
 /books
 /readstate.txt
+/lessonstate.txt
+/data/lessons.txt
lesson.py (new file): 70 additions

@@ -0,0 +1,70 @@
+import os
+import asyncio
+from datetime import datetime
+
+
+class LessonModule:
+    def __init__(self, trainer, lesson_path="data/lessons.txt", log_path="logs/lesson.log", state_path="lessonstate.txt"):
+        self.trainer = trainer
+        self.lesson_path = lesson_path
+        self.log_path = log_path
+        self.state_path = state_path
+        self.current_index = 0
+        self.lessons = []
+
+        os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
+
+        if os.path.exists(self.lesson_path):
+            with open(self.lesson_path, "r", encoding="utf-8") as f:
+                self.lessons = [line.strip() for line in f if line.strip()]
+
+        if os.path.exists(self.state_path):
+            try:
+                with open(self.state_path, "r", encoding="utf-8") as f:
+                    self.current_index = int(f.read().strip())
+            except Exception:
+                self.current_index = 0
+
+    def _save_state(self):
+        with open(self.state_path, "w", encoding="utf-8") as f:
+            f.write(str(self.current_index))
+
+    def _log_lesson(self, text: str, score: float):
+        with open(self.log_path, "a", encoding="utf-8") as f:
+            f.write(f"[{datetime.utcnow().isoformat()}] {score:.2f} | {text.strip()}\n")
+
+    async def start_lessons(self, interval=10):
+        print("[LESSON] Starting lesson loop...")
+        while self.current_index < len(self.lessons):
+            line = self.lessons[self.current_index]
+            if len(line.split()) >= 3 and self._is_valid(line):
+                score = self.trainer.score_sentence(line)
+                if self.trainer.is_reinforceable(line) and score >= 2.0:
+                    self.trainer.train_on_tokens_from_text(line)
+                    self._log_lesson(line, score)
+
+            self.current_index += 1
+            self._save_state()
+            await asyncio.sleep(interval)
+
+        print("[LESSON] All lessons completed.")
+
+    def _is_valid(self, text: str) -> bool:
+        return all(c.isprintable() or c.isspace() for c in text)
+
+    def reset(self):
+        self.current_index = 0
+        self._save_state()
+
+    def add_lesson(self, text: str):
+        self.lessons.append(text)
+        with open(self.lesson_path, "a", encoding="utf-8") as f:
+            f.write(text.strip() + "\n")
+
+    def progress(self):
+        total = len(self.lessons)
+        return {
+            "current": self.current_index,
+            "total": total,
+            "percent": round(100 * self.current_index / total, 2) if total else 0.0
+        }
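For reference, a minimal sketch (not part of the commit) of driving LessonModule on its own. The StubTrainer below is hypothetical; its method names simply mirror the score_sentence, is_reinforceable, and train_on_tokens_from_text calls that start_lessons makes on the real trainer:

    # Hypothetical stand-alone driver for LessonModule (not from this repo).
    import asyncio
    import os

    from lesson import LessonModule

    class StubTrainer:
        # Stub methods mirroring the calls start_lessons() makes above.
        def score_sentence(self, line: str) -> float:
            return 3.0  # pretend every lesson clears the 2.0 threshold

        def is_reinforceable(self, line: str) -> bool:
            return True

        def train_on_tokens_from_text(self, line: str) -> None:
            print(f"[STUB] training on: {line}")

    async def main():
        os.makedirs("data", exist_ok=True)  # lesson_path's parent must exist
        lessons = LessonModule(StubTrainer())
        lessons.add_lesson("The sky is blue and the grass is green.")
        await lessons.start_lessons(interval=0)  # skip the 10 s pacing
        print(lessons.progress())  # e.g. {'current': 1, 'total': 1, 'percent': 100.0}

    asyncio.run(main())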
main.py: 3 additions

@@ -8,6 +8,7 @@ import dashboard
 from tokenizer import Tokenizer
 from trainer import RubyTrainer
 from reader import BookReader
+from lesson import LessonModule
 import logging

 # Setup logging
@@ -41,6 +42,7 @@ class Ruby(discord.Client):
         super().__init__(intents=intents)
         self.tokenizer = Tokenizer()
         self.trainer = RubyTrainer(self.tokenizer)
+        self.lessons = LessonModule(self.trainer)
         self.reader = BookReader(trainer=self.trainer,
                                  book_path="books//wizard_of_oz.txt",  # or whatever book you want
                                  interval=180  # read every 3 minutes (adjust if needed)
@@ -51,6 +53,7 @@ class Ruby(discord.Client):
         os.makedirs("logs", exist_ok=True)

     async def setup_hook(self):
+        self.loop.create_task(self.lessons.start_lessons())
         self.loop.create_task(self.reader.start_reading())
         self.loop.create_task(self.idle_dream_loop())
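setup_hook now schedules the lesson loop alongside the reader and idle-dream loops, so all three run concurrently on discord.py's event loop. The same fan-out pattern, sketched with plain asyncio and placeholder loops (nothing below is from the repo):

    import asyncio

    async def lesson_loop(interval: int = 10):
        while True:
            print("lesson tick")
            await asyncio.sleep(interval)

    async def reader_loop(interval: int = 180):
        while True:
            print("reader tick")
            await asyncio.sleep(interval)

    async def main():
        # create_task() is the plain-asyncio analogue of
        # self.loop.create_task(...) in setup_hook above.
        tasks = [asyncio.create_task(lesson_loop()),
                 asyncio.create_task(reader_loop())]
        await asyncio.sleep(30)  # let the loops run briefly
        for t in tasks:
            t.cancel()

    asyncio.run(main())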
trainer.py: 75 changes

@@ -79,38 +79,40 @@ class RubyTrainer:

         print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")

-    def generate_reply(self, prompt=None, max_length=20, temperature=1.3):
+    def generate_reply(self, prompt=None, max_length=20, retry_limit=5):
         self.model.eval()
-        input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
-
-        with torch.no_grad():
-            for _ in range(max_length):
-                max_id = self.model.token_embed.num_embeddings
-                input_ids = torch.clamp(input_ids, 0, max_id - 1)
-                output = self.model(input_ids)
-                logits = output[:, -1, :]
-
-                # Apply repeat penalty
-                if input_ids.size(1) >= 2:
-                    last_token = input_ids[0, -1].item()
-                    logits[0, last_token] *= 0.1
-
-                # 🔥 Temperature sampling
-                probs = F.softmax(logits / temperature, dim=-1)
-                next_token = torch.multinomial(probs, 1)[0].view(1)
-
-                if next_token.item() >= self.model.token_embed.num_embeddings:
-                    print("[ERROR] Token index out of bounds. Rebuilding model...")
-                    self.rebuild_model_if_needed()
-                    return ""
-
-                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
-
-                if next_token.item() == self.tokenizer.vocab["<END>"]:
-                    break
-
-        output = self.tokenizer.detokenize(input_ids.squeeze().tolist())
-        return output.replace("<START>", "").replace("<END>", "").strip()
+        for _ in range(retry_limit):
+            input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], device=self.device)
+
+            with torch.no_grad():
+                for _ in range(max_length):
+                    max_id = self.model.token_embed.num_embeddings
+                    input_ids = torch.clamp(input_ids, 0, max_id - 1)
+                    output = self.model(input_ids)
+                    logits = output[:, -1, :]
+
+                    if input_ids.size(1) >= 2:
+                        last_token = input_ids[0, -1].item()
+                        logits[0, last_token] *= 0.1
+
+                    next_token = torch.argmax(logits, dim=-1)
+                    if next_token.item() >= self.model.token_embed.num_embeddings:
+                        print("[ERROR] Token index OOB. Rebuilding model.")
+                        self.rebuild_model_if_needed()
+                        return ""
+
+                    input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
+
+                    if next_token.item() == self.tokenizer.vocab["<END>"]:
+                        break
+
+            decoded = self.tokenizer.detokenize(input_ids.squeeze().tolist())
+            decoded = decoded.replace("<START>", "").replace("<END>", "").strip()
+
+            if len(decoded.split()) >= 4 and self._has_sentence_structure(decoded):
+                return decoded
+
+        return ""


     def self_rephrase(self, original: str, max_tokens=50, temperature=1.3):
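The rewrite swaps one-shot temperature sampling for greedy argmax decoding wrapped in a retry loop: each candidate must be at least four words long and pass _has_sentence_structure, otherwise generation restarts, up to retry_limit times. One caveat worth noting: multiplying a logit by 0.1 only suppresses a token whose logit is positive; a negative logit becomes less negative, i.e. more likely. A toy illustration of the repeat penalty plus greedy pick (toy tensors, not repo code):

    import torch

    logits = torch.tensor([[2.0, 0.5, 1.8]])   # toy 3-token vocabulary
    last_token = 2                             # token chosen on the previous step
    logits[0, last_token] *= 0.1               # same *= 0.1 damping as above -> 0.18
    next_token = torch.argmax(logits, dim=-1)  # greedy pick, as in the rewrite
    print(next_token.item())                   # 0: token 2 was damped below 2.0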
@@ -247,39 +249,50 @@ class RubyTrainer:

         freqs = Counter(words)

         # Reject if any token appears more than 5 times
         if any(count > 5 for count in freqs.values()):
             return False

         # Reject if most common word is > 30% of sentence
         if max(freqs.values()) / len(words) > 0.3:
             return False

         # Reject if >3 tokens occur 3+ times
         if sum(1 for c in freqs.values() if c >= 3) > 3:
             return False

         # Reject if "I am" occurs more than 25% of the time
         if text.lower().count("i am") > len(text.split()) * 0.25:
             return False

         # Reject if first word repeats 3 times ("you you you")
         if words[:3].count(words[0]) == 3:
             return False

-        # 🧠 NEW: Reject if starts with common book phrases
+        # 🧠 NEW: Reject if starts with known book structures
         banned_starts = ("once upon", "chapter", "the end", "in which", "it was", "quick cried", "they are course")
         lowered = text.lower()
         if any(lowered.startswith(phrase) for phrase in banned_starts):
             return False

-        # 🧠 NEW: Reject if too many capitalized words in a row (e.g., names, places from a book)
+        # 🧠 NEW: Capitalized name/term flooding check
         cap_sequence = sum(1 for word in words if word.istitle())
         if cap_sequence > 5 and cap_sequence / len(words) > 0.4:
             return False

+        # 🧠 NEW: Book-world saturation filter
+        blacklist = {"toto", "sparkle", "scarecrow", "munchkin", "wizard", "oz", "emerald", "wicked", "kansas"}
+        lower_words = [w.lower() for w in words]
+        if sum(1 for w in lower_words if w in blacklist) > 3:
+            return False
+
+        # 🧠 NEW: Diversity requirement
+        if len(set(lower_words)) < len(lower_words) * 0.5:
+            return False
+
         return True

     def _has_sentence_structure(self, text: str) -> bool:
         # Simple heuristic: must contain a known subject and verb
         has_subject = any(word.lower() in ["i", "you", "he", "she", "they", "we"] for word in text.split())
         has_verb = any(text.lower().count(verb) for verb in ["am", "is", "are", "was", "were", "have", "feel", "see", "think", "said"])
         return has_subject and has_verb

     def score_sentence(self, sentence: str) -> float:
         words = sentence.strip().split()
         if not words:
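The filters that appear to be this commit's additions to _is_reinforceable, book-world saturation and lexical diversity, can be exercised in isolation. A small sketch (the helper name is made up; the thresholds copy the diff):

    # Hypothetical harness for the two new _is_reinforceable checks.
    blacklist = {"toto", "sparkle", "scarecrow", "munchkin", "wizard", "oz",
                 "emerald", "wicked", "kansas"}

    def passes_new_filters(text: str) -> bool:
        lower_words = [w.lower() for w in text.split()]
        if sum(1 for w in lower_words if w in blacklist) > 3:
            return False  # book-world saturation: more than 3 Oz-specific words
        if len(set(lower_words)) < len(lower_words) * 0.5:
            return False  # diversity: under half the words are unique
        return True

    print(passes_new_filters("the wizard of oz met toto in the emerald city"))  # False
    print(passes_new_filters("ruby wants to learn a new sentence today"))       # True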