diff --git a/.gitignore b/.gitignore
index c8986bd..04d8c76 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,3 +169,5 @@ cython_debug/
 #.idea/
 
 .vscode/launch.json
+/data/books/alice_in_wonderland.txt
+/data/books/wizard_of_oz.txt
\ No newline at end of file
diff --git a/main.py b/main.py
index 67cf9ac..f559b05 100644
--- a/main.py
+++ b/main.py
@@ -35,5 +35,13 @@ async def on_message(message):
 # Launch Flask in background
 threading.Thread(target=run_dashboard, daemon=True).start()
 
+# reader/reader.py is new in this change, so its entry point is imported here.
+from reader.reader import read_books_forever  # noqa: E402
+
+# NOTE(review): the task below only runs if client.run() drives this same
+# loop -- confirm against the installed discord.py version.
+loop = asyncio.get_event_loop()
+loop.create_task(read_books_forever())  # Book reader task
+
 # Launch Discord bot (blocking)
 client.run(TOKEN)
diff --git a/model/brain.py b/model/brain.py
index 8feb07f..af78cd0 100644
--- a/model/brain.py
+++ b/model/brain.py
@@ -55,7 +55,7 @@ def daydream():
     seed = torch.tensor([random.randint(0, tokenizer.next_id - 1)], device=DEVICE).unsqueeze(0)
     dream = []
 
-    for _ in range(12):  # generate 12-word thought
+    for _ in range(12):
         out = model(seed)
         logits = out[:, -1, :]
         probs = F.softmax(logits, dim=-1)
@@ -68,6 +68,7 @@ def daydream():
 
     if score > 0.3:
         save_dream(sentence, score)
+        train_on_message(sentence)
         recent_dreams.append((score, sentence))
         if len(recent_dreams) > 10:
             recent_dreams.pop(0)
diff --git a/reader/filter.py b/reader/filter.py
index e69de29..b5eee4d 100644
--- a/reader/filter.py
+++ b/reader/filter.py
@@ -0,0 +1,21 @@
+"""Line filter used by the book reader."""
+
+import re
+
+_LETTER_RE = re.compile(r"[a-zA-Z]")
+
+
+def is_valid_line(text: str) -> bool:
+    """Return True if *text* looks like a usable line of prose.
+
+    After stripping surrounding whitespace, a line is rejected when it is
+    shorter than 10 characters, contains no ASCII letter, or contains the
+    Unicode replacement character (U+FFFD).
+    """
+    text = text.strip()
+    if len(text) < 10:
+        return False
+    if not _LETTER_RE.search(text):
+        return False
+    # U+FFFD marks bytes that could not be decoded.
+    return "\ufffd" not in text
diff --git a/reader/reader.py b/reader/reader.py
index e69de29..8b12448 100644
--- a/reader/reader.py
+++ b/reader/reader.py
@@ -0,0 +1,81 @@
+"""Background task that feeds book text to the model one line at a time."""
+
+import asyncio
+import json
+import os
+
+from model.train import train_on_message
+from reader.filter import is_valid_line
+
+BOOK_DIR = "data/books"
+PROGRESS_FILE = "data/memory/book_progress.json"
+READ_DELAY = 10  # seconds between lines
+
+
+def get_books():
+    """Return the sorted names of the ``.txt`` files in ``BOOK_DIR``.
+
+    Returns an empty list when the directory does not exist.
+    """
+    try:
+        names = os.listdir(BOOK_DIR)
+    except FileNotFoundError:
+        return []
+    return sorted(name for name in names if name.endswith(".txt"))
+
+
+def load_progress():
+    """Load the progress map from ``PROGRESS_FILE``.
+
+    Returns:
+        A dict mapping book file name to the index of the next line to
+        read, or an empty dict when the file is missing or is not a valid
+        JSON object.
+    """
+    try:
+        with open(PROGRESS_FILE, "r", encoding="utf-8") as f:
+            data = json.load(f)
+    except (FileNotFoundError, json.JSONDecodeError):
+        return {}
+    return data if isinstance(data, dict) else {}
+
+
+def save_progress(prog):
+    """Write the progress map *prog* to ``PROGRESS_FILE`` as JSON."""
+    os.makedirs(os.path.dirname(PROGRESS_FILE), exist_ok=True)
+    with open(PROGRESS_FILE, "w", encoding="utf-8") as f:
+        json.dump(prog, f, indent=2)
+
+
+async def read_books_forever():
+    """Train on unread book lines, one every ``READ_DELAY`` seconds.
+
+    The position in every book is saved as each line is consumed, so a
+    restart resumes where the previous run stopped.  Never returns.
+    """
+    progress = load_progress()
+    while True:
+        for book in get_books():
+            path = os.path.join(BOOK_DIR, book)
+            try:
+                # Undecodable bytes become U+FFFD (which is_valid_line
+                # rejects) instead of killing the task with a decode error.
+                with open(path, "r", encoding="utf-8", errors="replace") as f:
+                    lines = f.readlines()
+            except FileNotFoundError:
+                continue
+
+            idx = progress.get(book, 0)
+            while idx < len(lines):
+                line = lines[idx].strip()
+                idx += 1
+                progress[book] = idx
+                save_progress(progress)
+
+                if is_valid_line(line):
+                    train_on_message(line)
+                await asyncio.sleep(READ_DELAY)
+
+        # Nothing above awaits once every book is finished (or none exist),
+        # so yield here; otherwise this loop would starve the event loop.
+        await asyncio.sleep(READ_DELAY)