From 3a77b5db3255d421177745b44869015d71dfc1b7 Mon Sep 17 00:00:00 2001
From: Dani
Date: Sun, 27 Apr 2025 15:38:58 -0400
Subject: [PATCH] Reverted some changes due to the unicode cleaner being moved to the tokenizer.

---
 model/rehearsal.py  | 18 +++++++++++-------
 model/tokenizer.py  |  2 ++
 model/trainer.py    |  4 ----
 reader/reader.py    |  7 ++-----
 utils/unicleaner.py |  1 +
 5 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/model/rehearsal.py b/model/rehearsal.py
index 17eddc3..b18f6b9 100644
--- a/model/rehearsal.py
+++ b/model/rehearsal.py
@@ -2,25 +2,29 @@ import torch
 from model.brain import model, tokenizer, DEVICE
 from model.trainer import train_on_message
 from model.dynamic_expand import expand_model_if_needed
-from utils.unicleaner import clean_unicode


 def simulate_conversation():
     expand_model_if_needed()

     model.eval()
-    seed = torch.randint(0, tokenizer.next_id, (1, 5), device=DEVICE)
-    seed = seed[:, -128:]
+
+    max_token_id = model.head.out_features - 1
+    if max_token_id < 1:
+        return  # Safeguard if model is still too small
+
+    seed = torch.randint(0, max_token_id + 1, (1, 5), device=DEVICE)
+    seed = seed[:, -128:]  # Clamp sequence length
+
     output = model(seed)

     preds = torch.argmax(output, dim=-1).squeeze().tolist()
     if isinstance(preds, int):
         preds = [preds]

+    # 🛡 Clamp predictions too
+    preds = [min(max(p, 0), max_token_id) for p in preds]
+
     text = tokenizer.detokenize(preds)
-
-    # 🧹 Clean the generated text too
-    text = clean_unicode(text)
-
     if text and len(text.split()) >= 3:
         train_on_message(text)
diff --git a/model/tokenizer.py b/model/tokenizer.py
index 0765d68..6359a79 100644
--- a/model/tokenizer.py
+++ b/model/tokenizer.py
@@ -1,6 +1,7 @@
 import re
 import os
 import json
+from utils.unicleaner import clean_unicode

 VOCAB_PATH = "data/memory/vocab.json"

@@ -24,6 +25,7 @@ class Tokenizer:
         self.next_id = 4

     def tokenize(self, text):
+        text = clean_unicode(text)  # 🚨 Always clean incoming text
         words = re.findall(r"\b\w+\b", text.lower())
         tokens = []
         for word in words:
diff --git a/model/trainer.py b/model/trainer.py
index 71ca696..2b3df70 100644
--- a/model/trainer.py
+++ b/model/trainer.py
@@ -4,7 +4,6 @@ from model.dynamic_expand import expand_model_if_needed, _last_expansion_time, g
 from model.brain_state import model, tokenizer, DEVICE, loss_fn
 from model.brainmap import update_brainmap
 from context.context import add_to_context, get_recent_context
-from utils.unicleaner import clean_unicode

 LOSS_FILE = "data/logs/loss.log"
 VOCAB_GROWTH_FILE = "data/logs/vocab_growth.log"

@@ -36,9 +35,6 @@ def train_on_message(text: str, source: str = "user"):
     try:
         model.train()

-        # 🧹 Clean up the incoming text
-        text = clean_unicode(text)
-
         context_texts = get_recent_context(10)
         augmented_text = " " + " ".join(context_texts + [text]) + " "
diff --git a/reader/reader.py b/reader/reader.py
index 89ebe44..9783ea0 100644
--- a/reader/reader.py
+++ b/reader/reader.py
@@ -3,7 +3,6 @@ import asyncio
 from model.trainer import train_on_message
 from model.scheduler import set_next_action
 from reader.filter import is_valid_line
-from utils.unicleaner import clean_unicode
 import json

 BOOK_DIR = "data/books"

@@ -49,8 +48,7 @@ async def read_books_forever():
             if not line:
                 if len(paragraph) > PARAGRAPH_MIN_LENGTH:
-                    cleaned_paragraph = clean_unicode(paragraph.strip())
-                    train_on_message(cleaned_paragraph, source="book")
+                    train_on_message(paragraph.strip(), source="book")
                     paragraph = ""
                     await asyncio.sleep(READ_DELAY)
                     set_next_action(READ_DELAY, "Reading")
@@ -62,7 +60,6 @@ async def read_books_forever():
     # train last paragraph if any
     if paragraph and len(paragraph) > PARAGRAPH_MIN_LENGTH:
-        cleaned_paragraph = clean_unicode(paragraph.strip())
-        train_on_message(cleaned_paragraph, source="book")
+        train_on_message(paragraph.strip(), source="book")
         await asyncio.sleep(READ_DELAY)
         set_next_action(READ_DELAY, "Reading")
diff --git a/utils/unicleaner.py b/utils/unicleaner.py
index 1fab524..f4002ce 100644
--- a/utils/unicleaner.py
+++ b/utils/unicleaner.py
@@ -17,6 +17,7 @@ RE_DASHES = {
     '\u2014': '-',  # Em dash
 }

+
 def clean_unicode(text: str) -> str:
     # 1. Replace fancy quotes
     for bad, good in RE_QUOTES.items():
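
Note on the change (illustrative, not part of the patch): after this commit, Unicode cleanup happens once inside Tokenizer.tokenize(), so the trainer, reader, and rehearsal loop no longer call clean_unicode() themselves. The sketch below shows that call order with a minimal stand-in Tokenizer; the quote mapping is assumed (only the em-dash entry of RE_DASHES is visible in utils/unicleaner.py), and the real Tokenizer carries vocab state not shown here.

    # sketch.py -- minimal illustration of the new call order, not repo code
    import re

    QUOTES = {'\u201c': '"', '\u201d': '"', '\u2018': "'", '\u2019': "'"}  # assumed mapping
    DASHES = {'\u2014': '-'}  # em dash, as in utils/unicleaner.py

    def clean_unicode(text: str) -> str:
        # Replace fancy quotes and dashes with ASCII equivalents
        for bad, good in {**QUOTES, **DASHES}.items():
            text = text.replace(bad, good)
        return text

    class Tokenizer:
        def tokenize(self, text):
            text = clean_unicode(text)  # cleaning is centralized here now
            return re.findall(r"\b\w+\b", text.lower())

    # Callers (trainer, reader, rehearsal) pass raw text straight through:
    print(Tokenizer().tokenize("\u201cHello\u2014world\u201d"))  # ['hello', 'world']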