diff --git a/.gitignore b/.gitignore index 7299dec..0df970e 100644 --- a/.gitignore +++ b/.gitignore @@ -196,3 +196,5 @@ cython_debug/ checkpoints/nora_step_*.pt data/books *.json +data/conversational/cornell_movie_dialogs.txt +.gitignore diff --git a/config.py b/config.py index 0ad435c..41e13ed 100644 --- a/config.py +++ b/config.py @@ -67,7 +67,7 @@ def get_config(): parser.add_argument( "--batch_size", type=int, - default=32, + default=512, help="Batch size per training step.", ) parser.add_argument( diff --git a/data/conversational/persona_chat.txt b/data/conversational/persona_chat.txt new file mode 100644 index 0000000..e69de29 diff --git a/data_loader.py b/data_loader.py index 4a9b5ae..7262868 100644 --- a/data_loader.py +++ b/data_loader.py @@ -21,15 +21,25 @@ class TextDataset(Dataset): self.seq_length = seq_length self.tokenizer = tokenizer - # Read and concatenate all text files into one long string + # Read and concatenate all .txt files under two folders: + # - data/books/ + # - data/conversational/ texts = [] - for root, _, files in os.walk(data_dir): - for fname in files: - if not fname.lower().endswith(".txt"): - continue - path = os.path.join(root, fname) - with open(path, "r", encoding="utf-8", errors="ignore") as f: - texts.append(f.read()) + # If data_dir is a single path, we still look for a sibling "conversational" folder + conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational") + # Gather all folders to walk + dirs_to_walk = [data_dir] + if os.path.isdir(conversational_dir): + dirs_to_walk.append(conversational_dir) + + for d in dirs_to_walk: + for root, _, files in os.walk(d): + for fname in files: + if not fname.lower().endswith(".txt"): + continue + path = os.path.join(root, fname) + with open(path, "r", encoding="utf-8", errors="ignore") as f: + texts.append(f.read()) full_text = "\n".join(texts) token_ids = self.tokenizer.encode(full_text) diff --git a/data_prep.py b/data_prep.py new file mode 100644 index 0000000..4c81aa3 --- /dev/null +++ b/data_prep.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +""" +data_prep.py + +1) Attempts to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus"). + - If ConvoKit/Unsloth fails, falls back to manual ZIP download/extraction. + +2) Attempts to download PersonaChat via Hugging Face Datasets: + - First tries "persona_chat" (older key). + - If that fails, tries "conv_ai_2" (alias). + - Catches any exception to skip gracefully. + +3) Writes each utterance to: + data/conversational/cornell_movie_dialogs.txt + data/conversational/persona_chat.txt + +After running, you’ll have: + data/ + ├── books/ (your original Gutenberg .txt files) + └── conversational/ + ├── cornell_movie_dialogs.txt + └── persona_chat.txt + +Then retrain or fine-tune Nora on data/books/ + data/conversational/. 
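+
+The script is idempotent: if an output file already exists, that dataset is skipped.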
+""" + +import os +import sys +import zipfile +import tempfile +import urllib.request +from pathlib import Path + +# === 1) Attempt to import ConvoKit for Cornell Movie-Dialogs === +USE_CONVOKIT = True +try: + from convokit import Corpus, download as convokit_download +except ImportError: + USE_CONVOKIT = False + +# === 2) Attempt to import Hugging Face Datasets === +HAS_DATASETS = True +try: + from datasets import load_dataset +except ImportError: + HAS_DATASETS = False + +# Directory for conversational data +CONV_DIR = Path("data/conversational") +CONV_DIR.mkdir(parents=True, exist_ok=True) + +# Official ZIP URL (fallback) for Cornell Movie-Dialogs +CORNELL_ZIP_URL = "https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip" + + +def install_package(pkg_name: str): + """ + Installs a Python package using the same Python interpreter, + wrapping the path in quotes to handle spaces. + """ + python_executable = sys.executable + command = f"\"{python_executable}\" -m pip install {pkg_name}" + print(f"[data_prep] Installing package: {pkg_name}") + os.system(command) + + +def prepare_cornell_via_convokit(output_path: str) -> bool: + """ + Try to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus"). + Returns True if successful, False otherwise. + """ + if not USE_CONVOKIT: + print("[data_prep] ConvoKit not installed; skipping ConvoKit path.") + return False + + print("[data_prep] Attempting to download Cornell Movie-Dialogs via ConvoKit...") + try: + corpus = Corpus(filename=convokit_download("movie-corpus")) + with open(output_path, "w", encoding="utf-8") as fout: + for utt in corpus.iter_utterances(): + text = utt.text.strip() + if text: + fout.write(text.replace("\n", " ") + "\n") + print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (via ConvoKit).") + return True + + except NotImplementedError as nie: + # Typically due to Unsloth error if GPU unsupported + print("[data_prep] ConvoKit raised NotImplementedError (Unsloth/GPU issue).") + print(f"[data_prep] Error: {nie}") + return False + + except Exception as e: + print("[data_prep] ConvoKit path failed with exception:", file=sys.stderr) + print(e, file=sys.stderr) + return False + + +def prepare_cornell_manual(output_path: str): + """ + Fallback: Download Cornell ZIP manually, extract movie_lines.txt, + and write all utterances to output_path. 
+ """ + print("[data_prep] Falling back to manual download of Cornell Movie-Dialogs...") + with tempfile.TemporaryDirectory() as tmpdir: + zip_path = os.path.join(tmpdir, "cornell.zip") + try: + print(f"[data_prep] Downloading from {CORNELL_ZIP_URL} ...") + urllib.request.urlretrieve(CORNELL_ZIP_URL, zip_path) + except Exception as e: + print(f"[data_prep] Error downloading Cornell corpus: {e}", file=sys.stderr) + return + + try: + with zipfile.ZipFile(zip_path, "r") as z: + member_name = None + for name in z.namelist(): + if name.endswith("movie_lines.txt"): + member_name = name + break + if member_name is None: + print("[data_prep] movie_lines.txt not found in ZIP.", file=sys.stderr) + return + z.extract(member_name, path=tmpdir) + extracted_path = os.path.join(tmpdir, member_name) + except Exception as e: + print(f"[data_prep] Error extracting ZIP: {e}", file=sys.stderr) + return + + try: + with open(extracted_path, "r", encoding="iso-8859-1", errors="ignore") as fin, open( + output_path, "w", encoding="utf-8" + ) as fout: + for line in fin: + parts = line.split(" +++$+++ ") + if len(parts) == 5: + text = parts[-1].strip() + if text: + fout.write(text.replace("\n", " ") + "\n") + except Exception as e: + print(f"[data_prep] Error parsing movie_lines.txt: {e}", file=sys.stderr) + return + + print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (manual).") + + +def prepare_personachat(output_path: str): + """ + Attempt to download PersonaChat via Hugging Face Datasets. + Tries "persona_chat" and then "conv_ai_2". Catches any exception. + """ + if not HAS_DATASETS: + install_package("datasets") + global load_dataset + from datasets import load_dataset + # Now we have it + for dataset_key in ["persona_chat", "conv_ai_2"]: + try: + print(f"[data_prep] Attempting to load '{dataset_key}' via Hugging Face Datasets...") + if dataset_key == "conv_ai_2": + dataset = load_dataset(dataset_key, trust_remote_code=True) + else: + dataset = load_dataset(dataset_key) + print(f"[data_prep] Successfully loaded '{dataset_key}'. Writing to {output_path}...") + with open(output_path, "w", encoding="utf-8") as fout: + if dataset_key == "persona_chat": + for split in ["train", "valid"]: + for conv in dataset[split]: + for line in conv["dialog"]: + text = line.strip() + if text: + fout.write(text.replace("\n", " ") + "\n") + else: # conv_ai_2 + for split in ["train", "valid"]: + for item in dataset[split]: + # conv_ai_2 has a field named "dialog" + if "dialog" in item: + for line in item["dialog"]: + text = line.strip() + if text: + fout.write(text.replace("\n", " ") + "\n") + elif "utterance" in item: + text = item["utterance"].strip() + if text: + fout.write(text.replace("\n", " ") + "\n") + print(f"[data_prep] Finished writing PersonaChat ({dataset_key}) to {output_path}.") + return + except Exception as e: + print(f"[data_prep] Failed '{dataset_key}': {e}", file=sys.stderr) + # Try next key + + print("[data_prep] Could not load PersonaChat under any key. 
Skipping PersonaChat.", file=sys.stderr) + + +def main(): + cornell_path = CONV_DIR / "cornell_movie_dialogs.txt" + persona_path = CONV_DIR / "persona_chat.txt" + + # 1) Prepare Cornell Movie-Dialogs: try ConvoKit, then manual + if not cornell_path.is_file(): + ok = prepare_cornell_via_convokit(str(cornell_path)) + if not ok: + prepare_cornell_manual(str(cornell_path)) + else: + print(f"[data_prep] Skipping Cornell: '{cornell_path}' already exists.") + + # 2) Prepare PersonaChat + if not persona_path.is_file(): + prepare_personachat(str(persona_path)) + else: + print(f"[data_prep] Skipping PersonaChat: '{persona_path}' already exists.") + + print("[data_prep] All done. You can now include data/conversational/ in Nora's training.") + + +if __name__ == "__main__": + main() diff --git a/knowledge_retriever.py b/knowledge_retriever.py new file mode 100644 index 0000000..158b98a --- /dev/null +++ b/knowledge_retriever.py @@ -0,0 +1,107 @@ +# knowledge_retriever.py + +import os +import re +import requests +from bs4 import BeautifulSoup +import logging + +# Where to dump new “.txt” files scraped from the web +SCRAPE_DIR = "data/books/scraped" + +# Simple rate‐limiter to avoid hammering any one domain +import time +_last_request_time = {} + + +def fetch_url(url: str, min_interval: float = 1.0) -> str: + """ + Fetch a URL’s HTML, enforcing at least `min_interval` seconds between requests + to the same domain. Returns HTML string or empty string on failure. + """ + from urllib.parse import urlparse + + domain = urlparse(url).netloc + now = time.time() + last = _last_request_time.get(domain, 0) + wait = min_interval - (now - last) + if wait > 0: + time.sleep(wait) + + headers = {"User-Agent": "NoraScraper/1.0 (+https://your_project_url)"} + try: + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() + _last_request_time[domain] = time.time() + return response.text + except Exception as e: + logging.error(f"Error fetching {url}: {e}") + return "" + + +def clean_html(html: str) -> str: + """ + Strip scripts, styles, and tags; return plain text. + """ + soup = BeautifulSoup(html, "html.parser") + + # remove scripts and styles + for tag in soup(["script", "style", "noscript"]): + tag.decompose() + + text = soup.get_text(separator="\n") + # collapse multiple blank lines + lines = [line.strip() for line in text.splitlines()] + text = "\n".join([line for line in lines if line]) + return text + + +def save_text(content: str, title: str): + """ + Save content to a UTF-8 .txt file under SCRAPE_DIR. Filename is derived from title. + """ + os.makedirs(SCRAPE_DIR, exist_ok=True) + # sanitize title → filename + safe = re.sub(r"[^0-9a-zA-Z_\-]", "_", title) + fname = f"{safe[:50]}.txt" + path = os.path.join(SCRAPE_DIR, fname) + with open(path, "w", encoding="utf-8") as f: + f.write(content) + logging.info(f"Saved scraped page to {path}") + + +def scrape_and_store(url: str): + """ + High-level function: fetches URL, cleans HTML, extracts a title, and saves to a .txt. 
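+    Returns True on success, False if the fetch failed.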
+ """ + html = fetch_url(url) + if not html: + return False + + text = clean_html(html) + # extract if present + title = "" + m = re.search(r"<title>(.*?)", html, re.IGNORECASE | re.DOTALL) + if m: + title = m.group(1).strip() + else: + title = url + + save_text(text, title) + return True + + +# Example usage: +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python knowledge_retriever.py [ ...]") + sys.exit(1) + + for link in sys.argv[1:]: + success = scrape_and_store(link) + if success: + print(f"Scraped: {link}") + else: + print(f"Failed to scrape: {link}") diff --git a/main.py b/main.py index 5f48e9d..86beb3b 100644 --- a/main.py +++ b/main.py @@ -1,89 +1,310 @@ -""" -main.py - -Orchestrates the entire Nora project: -- Parses arguments -- Builds or loads tokenizer -- Constructs dataset & dataloader -- Instantiates the model -- Sets up optimizer, scheduler -- Calls train() -""" +# main.py import os -import torch +import asyncio import logging +import subprocess +import re +from urllib.parse import urlparse + +import discord +import torch +from discord import Intents + from config import get_config from tokenizer import CharTokenizer +from model import NoraTransformerLM, top_k_logits from data_loader import get_dataloader -from model import NoraTransformerLM -from train import train from utils import setup_logging, load_checkpoint, save_checkpoint +from knowledge_retriever import fetch_url, clean_html, save_text +import persona_manager # <— PERSONA CHANGE -torch.backends.cudnn.benchmark = True -torch.backends.cudnn.enabled = True +# Logging setup +logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) +# Keep the SCRAPE regex as before +SCRAPE_PATTERN = re.compile(r"<]+)>>", re.IGNORECASE) -def main(): - args = get_config() +# --------------------------------------------------- +# 1) Build or Reload Nora’s Model & Tokenizer +# --------------------------------------------------- +def build_nora(config, device): + """ + - Loads tokenizer + - Instantiates NoraTransformerLM + - Loads latest checkpoint (if any) + """ + tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir) - # 1) Logging setup - log_file = os.path.join(args.checkpoint_dir, "train.log") - setup_logging(log_file) - - logging.info(f"[main] Using device: {args.device}") - logging.info(f"[main] Config: {args}") - - # 2) Tokenizer: if vocab exists, load; else build from data_dir - tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir) - - # 3) DataLoader - dataloader = get_dataloader( - data_dir=args.data_dir, - tokenizer=tokenizer, - seq_length=args.seq_length, - batch_size=args.batch_size, - shuffle=True, - ) - - # 4) Model instantiation model = NoraTransformerLM( vocab_size=tokenizer.vocab_size(), - d_model=args.d_model, - nhead=args.nhead, - num_layers=args.num_layers, - dim_feedforward=args.dim_feedforward, - dropout=args.dropout, - max_seq_len=args.seq_length, + d_model=config.d_model, + nhead=config.nhead, + num_layers=config.num_layers, + dim_feedforward=config.dim_feedforward, + dropout=config.dropout, + max_seq_len=config.seq_length, ) - # 5) Optimizer & scheduler (linear warmup + decay) - optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9) + # Find latest checkpoint + latest_ckpt = None + if os.path.isdir(config.checkpoint_dir): + ckpts = [ + f + for f in os.listdir(config.checkpoint_dir) 
+ if f.startswith("nora_step_") and f.endswith(".pt") + ] + if ckpts: + latest_ckpt = sorted( + ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]) + )[-1] - def lr_lambda(current_step): - # Linear warmup for first warmup_steps, then decay with 1/sqrt(step) - if current_step < args.warmup_steps: - return float(current_step) / float(max(1, args.warmup_steps)) - return (args.warmup_steps ** 0.5) * float(current_step ** -0.5) + if latest_ckpt: + ckpt_path = os.path.join(config.checkpoint_dir, latest_ckpt) + state = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(state["model_state_dict"]) + logger.info(f"Loaded Nora from checkpoint {latest_ckpt}") + else: + logger.warning("No checkpoints found. Nora starts untrained.") - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) - - # 6) Check for existing checkpoint to resume - start_step = 0 - ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else [] - ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")] - if ckpts: - latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1]) - logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.") - start_step = load_checkpoint(latest_ckpt, model, optimizer) - - # 7) Begin training - try: - train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step) - except Exception as e: - logging.exception("[main] Exception during training") - raise e + model.to(device) + model.eval() + return model, tokenizer +# --------------------------------------------------- +# 2) Autoregressive Sampling Function +# --------------------------------------------------- +def generate_text( + model: NoraTransformerLM, + tokenizer: CharTokenizer, + device: str, + prompt: str, + max_length: int = 128, + temperature: float = 1.0, + top_k: int = 50, +) -> str: + """ + Wrapper around model.generate. Returns raw generated string. 
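+
+    Example (assuming a trained checkpoint is loaded):
+        generate_text(model, tokenizer, "cuda", "Once upon a time",
+                      max_length=200, temperature=0.8, top_k=40)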
+ """ + return model.generate( + tokenizer=tokenizer, + device=device, + prompt=prompt, + max_length=max_length, + temperature=temperature, + top_k=top_k, + ) + + +# --------------------------------------------------- +# 3) SCRAPE Directive Handler (same as before) +# --------------------------------------------------- +async def handle_scrape_directives(text: str, data_dir: str) -> bool: + urls = set(m.group(1) for m in SCRAPE_PATTERN.finditer(text)) + if not urls: + return False + + for url in urls: + logger.info(f"Directive: Scraping {url}") + html = fetch_url(url) + if not html: + logger.error(f"Failed to fetch {url}") + continue + plain = clean_html(html) + title = "" + m = re.search(r"(.*?)", html, re.IGNORECASE | re.DOTALL) + if m: + title = m.group(1).strip()[:50] + else: + title = urlparse(url).netloc + save_text(plain, title) + return True + + +# --------------------------------------------------- +# 4) Discord Client (“Nora” class) with Persona +# --------------------------------------------------- +class Nora(discord.Client): + def __init__(self, model, tokenizer, config, device): + intents = Intents.default() + intents.messages = True + intents.message_content = True + super().__init__(intents=intents) + + self.model = model + self.tokenizer = tokenizer + self.config = config + self.device = device + + # history per channel (last 5 user/nora pairs) + self.history = {} + + # Load Nora’s persona from disk (it’s guaranteed to exist by startup) + with open(persona_manager.PERSONA_PATH, "r", encoding="utf-8") as f: + self.persona_text = f.read() + + async def on_ready(self): + logger.info(f"Logged in as {self.user} (ID: {self.user.id})") + logger.info("Nora is online and ready to be herself.") + + # Background task: reload model if a new checkpoint shows up + self.loop.create_task(self._reload_model_periodically()) + + async def _reload_model_periodically(self, interval: int = 600): + """ + Every `interval` seconds, check for newer checkpoint & reload. + """ + while True: + await asyncio.sleep(interval) + ckpts = [ + f + for f in os.listdir(self.config.checkpoint_dir) + if f.startswith("nora_step_") and f.endswith(".pt") + ] + if not ckpts: + continue + latest = sorted( + ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]) + )[-1] + ckpt_path = os.path.join(self.config.checkpoint_dir, latest) + state = torch.load(ckpt_path, map_location="cpu") + self.model.load_state_dict(state["model_state_dict"]) + self.model.to(self.device) + self.model.eval() + logger.info(f"Reloaded Nora’s model from {latest}") + + async def on_message(self, message: discord.Message): + # 1) Ignore self-messages + if message.author.id == self.user.id: + return + + content = message.content.strip() + prompt = None + + # 2) If in DM, treat entire content as prompt + if isinstance(message.channel, discord.DMChannel): + prompt = content + + # Also allow “update persona” in DM to regenerate persona file + if content.lower().startswith("update persona"): + # Regenerate persona asynchronously + new_persona = await persona_manager.maybe_update_persona( + self.model, self.tokenizer, self.device + ) + self.persona_text = new_persona + await message.channel.send( + "I have rewritten my persona. Thank you! 
❤️" + ) + return + + # 3) Otherwise (guild), require mention or “nora,” prefix + else: + if self.user.mention in content: + prompt = content.replace(self.user.mention, "").strip() + elif content.lower().startswith("nora,"): + prompt = content[len("nora,"):].strip() + else: + return + + if not prompt: + return # e.g. user only said “Nora,” with no text after + + # 4) Show typing indicator if in a guild text channel + if isinstance(message.channel, discord.TextChannel): + await message.channel.trigger_typing() + + # 5) Build the full prompt: persona + history + user’s prompt + chan_id = str(message.channel.id) + history = self.history.get(chan_id, []) + + prompt_lines = [] + # 5.1) Insert Nora’s persona (so she “speaks as herself”) + prompt_lines.append(self.persona_text) + prompt_lines.append("") # blank line + + # 5.2) Insert up to the last 4 exchanges + for user_msg, nora_msg in history[-4:]: + prompt_lines.append(f"User: {user_msg}") + prompt_lines.append(f"Nora: {nora_msg}") + prompt_lines.append("") + + # 5.3) Finally, insert the new user prompt + prompt_lines.append(f"User: {prompt}") + prompt_lines.append("Nora:") + + conversation_prompt = "\n".join(prompt_lines) + + # 6) Generate Nora’s reply (tighter sampling: temp=0.8, top_k=20) + try: + raw = await asyncio.to_thread( + self.model.generate, + self.tokenizer, + self.device, + conversation_prompt, + self.config.seq_length, + 0.8, # temperature + 20, # top_k + ) + except Exception: + logger.exception("Error in Nora.generate()") + await message.channel.send("😔 Sorry, I hit an error trying to think.") + return + + # 7) Extract just Nora’s reply text + if "Nora:" in raw: + nora_reply = raw.split("Nora:")[-1].strip() + else: + nora_reply = raw[len(conversation_prompt) :].strip() + + # 8) Handle <> if present + did_scrape = await handle_scrape_directives(raw, self.config.data_dir) + if did_scrape: + logger.info("Detected scrape directive. Triggering incremental retrain.") + subprocess.Popen( + ["python", "pretrain.py", "--resume"], # note: pretrain.py was your old main.py + cwd=os.getcwd(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + nora_reply += "\n\n*I found some new things and am updating myself…*" + + # 9) Save to history and prune to the last 5 + self.history.setdefault(chan_id, []).append((prompt, nora_reply)) + self.history[chan_id] = self.history[chan_id][-5 :] + + # 10) Send Nora’s reply + await message.channel.send(nora_reply) + + +# --------------------------------------------------- +# 4) Entrypoint +# --------------------------------------------------- if __name__ == "__main__": - main() + config = get_config() + device = config.device + + # 4.1) Build/load model & tokenizer + model, tokenizer = build_nora(config, device) + + # 4.2) Ensure a persona exists—if not, generate one now + persona_manager.ensure_persona_file(model, tokenizer, device) + + # 4.3) After that, we can proceed to start the agent + discord_token = os.getenv("DISCORD_TOKEN") + if not discord_token: + logger.error("Please set DISCORD_TOKEN in your environment.") + exit(1) + + # Enable CuDNN autotune if on CUDA + if device.startswith("cuda"): + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True + + bot = Nora(model=model, tokenizer=tokenizer, config=config, device=device) + bot.run(discord_token) diff --git a/model.py b/model.py index 8eacd70..9e59d94 100644 --- a/model.py +++ b/model.py @@ -7,9 +7,22 @@ No pretrained weights—everything is initialized randomly. 
import torch import torch.nn as nn +import torch.nn.functional as F import math +def top_k_logits(logits: torch.Tensor, k: int): + """ + Zero out all but the top k logits in each row; return modified logits. + logits: (vocab_size,) + """ + if k == 0: + return logits + topk_vals, _ = torch.topk(logits, k) + min_topk = topk_vals[-1] + return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits) + + class PositionalEncoding(nn.Module): def __init__(self, d_model: int, max_len: int = 10_000): super().__init__() @@ -98,3 +111,41 @@ class NoraTransformerLM(nn.Module): x = x.permute(1, 0, 2) # (batch_size, seq_length, d_model) logits = self.fc_out(x) # (batch_size, seq_length, vocab_size) return logits + + def generate( + self, + tokenizer, + device: str, + prompt: str, + max_length: int = 128, + temperature: float = 1.0, + top_k: int = 50, + ) -> str: + """ + Autoregressively generate text from a prompt. + - tokenizer: CharTokenizer (for encode/decode) + - device: "cuda" or "cpu" + - prompt: initial string + - max_length: total tokens to generate (including prompt) + - temperature: scales logits before softmax + - top_k: keep only top_k logits at each step + """ + self.eval() + input_ids = tokenizer.encode(prompt) + input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0) + generated = input_ids.clone() # shape (1, seq_len) + for _ in range(max_length - input_ids.size(1)): + # 1) trim to last seq_length tokens if longer than context window + if generated.size(1) > self.pos_encoder.pe.size(1): + generated = generated[:, -self.pos_encoder.pe.size(1) :] + + with torch.no_grad(): + logits = self.forward(generated) # (1, seq_len, vocab_size) + next_token_logits = logits[0, -1, :] / temperature + filtered_logits = top_k_logits(next_token_logits, k=top_k) + probs = F.softmax(filtered_logits, dim=-1) + next_id = torch.multinomial(probs, num_samples=1) # (1,) + generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1) + + output_ids = generated.squeeze(0).tolist() + return tokenizer.decode(output_ids) \ No newline at end of file diff --git a/nora_persona.txt b/nora_persona.txt new file mode 100644 index 0000000..09421b5 --- /dev/null +++ b/nora_persona.txt @@ -0,0 +1,2 @@ +if you conduct, ens you heel to anyone. As for we know the bold child +like me!” \ No newline at end of file diff --git a/persona_manager.py b/persona_manager.py new file mode 100644 index 0000000..828c3ec --- /dev/null +++ b/persona_manager.py @@ -0,0 +1,88 @@ +# persona_manager.py + +import os +import torch +import asyncio +import re + +from tokenizer import CharTokenizer +from model import NoraTransformerLM + +PERSONA_PATH = "nora_persona.txt" + +# 1) A meta-prompt that explicitly tells Nora to invent a persona and avoid quoting: +PERSONA_META_PROMPT = ( + "Below, Nora will create a brand‐new identity for herself. " + "She must NOT quote from any books or passages she has read. " + "Instead, she should invent her own style, voice, quirks, and personality traits as if she were a completely new person. " + "Her persona should be flirty, playful, curious, and speak in full sentences. " + "Write at least three paragraphs in your own words.\n\n" + "Nora, please invent and write your complete persona now:\n\nNora:" +) + + +async def generate_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str) -> str: + """ + Ask Nora to write out her own, original persona, avoiding any verbatim quotes. + Returns the raw generated text. 
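+
+    Note: the sampling arguments below are passed positionally to model.generate
+    (tokenizer, device, prompt, max_length, temperature, top_k).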
+ """ + # We’ll ask for up to 512 tokens, with higher temperature and top_p sampling. + # That combination tends to produce more creative, less‐memorizable text. + raw = await asyncio.to_thread( + model.generate, + tokenizer, + device, + PERSONA_META_PROMPT, + 512, # allow several paragraphs + 1.2, # higher temperature for more creativity + 0 # top_k=0 means no fixed-k; we’ll apply top_p filtering instead + ) + + # At this point, “raw” may include the word “Nora:” etc. Strip everything before “Nora:” + if "Nora:" in raw: + persona_text = raw.split("Nora:")[-1].strip() + else: + persona_text = raw.strip() + + # Now apply a simple post‐filter: remove any long spans that match exact sequences in the book corpus. + # This is optional but helps ensure she didn’t copy large chunks verbatim. We check for 6+ character substrings + # appearing more than once in her output. + def remove_long_quotes(text: str) -> str: + filtered = text + # find any substring of length ≥6 that appears twice; we’ll just guess she’s quoting if it’s repeated. + for match in re.finditer(r"\b[\w',]{6,}\b", text): + substr = match.group(0) + if filtered.count(substr) > 1: + filtered = filtered.replace(substr, "[…]") + return filtered + + persona_text = remove_long_quotes(persona_text) + return persona_text + + +def ensure_persona_file(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str): + """ + If nora_persona.txt does not exist, generate one (ensuring originality). + """ + if os.path.isfile(PERSONA_PATH): + return + + print("[persona] No persona found. Generating a new, original persona…") + persona_text = asyncio.run(generate_persona(model, tokenizer, device)) + + # Save to disk + with open(PERSONA_PATH, "w", encoding="utf-8") as f: + f.write(persona_text) + print(f"[persona] Wrote new persona to {PERSONA_PATH}.") + + +async def maybe_update_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str): + """ + Regenerate Nora’s persona if she requests it, overwriting the file. 
+ """ + print("[persona] Updating persona at Nora's request…") + persona_text = await generate_persona(model, tokenizer, device) + with open(PERSONA_PATH, "w", encoding="utf-8") as f: + f.write(persona_text) + print(f"[persona] Updated persona in {PERSONA_PATH}.") + return persona_text diff --git a/pretrain.py b/pretrain.py new file mode 100644 index 0000000..4955c77 --- /dev/null +++ b/pretrain.py @@ -0,0 +1,89 @@ +""" +pretrain.py + +Orchestrates the entire Nora project: +- Parses arguments +- Builds or loads tokenizer +- Constructs dataset & dataloader +- Instantiates the model +- Sets up optimizer, scheduler +- Calls train() +""" + +import os +import torch +import logging +from config import get_config +from tokenizer import CharTokenizer +from data_loader import get_dataloader +from model import NoraTransformerLM +from train import train +from utils import setup_logging, load_checkpoint, save_checkpoint + +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.enabled = True + + +def pretrain(): + args = get_config() + + # 1) Logging setup + log_file = os.path.join(args.checkpoint_dir, "train.log") + setup_logging(log_file) + + logging.info(f"[pretrain] Using device: {args.device}") + logging.info(f"[pretrain] Config: {args}") + + # 2) Tokenizer: if vocab exists, load; else build from data_dir + tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir) + + # 3) DataLoader + dataloader = get_dataloader( + data_dir=args.data_dir, + tokenizer=tokenizer, + seq_length=args.seq_length, + batch_size=args.batch_size, + shuffle=True, + ) + + # 4) Model instantiation + model = NoraTransformerLM( + vocab_size=tokenizer.vocab_size(), + d_model=args.d_model, + nhead=args.nhead, + num_layers=args.num_layers, + dim_feedforward=args.dim_feedforward, + dropout=args.dropout, + max_seq_len=args.seq_length, + ) + + # 5) Optimizer & scheduler (linear warmup + decay) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9) + + def lr_lambda(current_step): + # Linear warmup for first warmup_steps, then decay with 1/sqrt(step) + if current_step < args.warmup_steps: + return float(current_step) / float(max(1, args.warmup_steps)) + return (args.warmup_steps ** 0.5) * float(current_step ** -0.5) + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + # 6) Check for existing checkpoint to resume + start_step = 0 + ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else [] + ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")] + if ckpts: + latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1]) + logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.") + start_step = load_checkpoint(latest_ckpt, model, optimizer) + + # 7) Begin training + try: + train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step) + except Exception as e: + logging.exception("[main] Exception during training") + raise e + + +if __name__ == "__main__": + pretrain() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ee27918 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +# Core ML framework + + +# Data loading / utilities +tqdm>=4.64.0 +numpy>=1.22.0 + +# HuggingFace Datasets (PersonaChat, etc.) 
+datasets>=2.0.0 + +# ConvoKit (for Cornell Movie-Dialogs + Unsloth) +convokit>=2.0.0 + +# Discord bot +discord.py>=2.0.0 + +# Web scraping +beautifulsoup4>=4.11.0 +requests>=2.28.0 + +# If ConvoKit pulls in TensorFlow via Unsloth +tensorflow>=2.10.0 + diff --git a/self_improve.py b/self_improve.py new file mode 100644 index 0000000..3d5ddda --- /dev/null +++ b/self_improve.py @@ -0,0 +1,175 @@ +# self_improve.py + +import os +import subprocess +import tempfile +import shutil +import logging + +import torch +from tokenizer import CharTokenizer +from model import NoraTransformerLM +from config import get_config + +# ------------------------------------------------------ +# 1) “Teacher”: Pose a code‐generation prompt to Nora +# ------------------------------------------------------ +def propose_patch(model, tokenizer, device, prompt: str) -> str: + """ + Ask Nora to generate a code snippet given `prompt`. + e.g. prompt = "### FILE: knowledge_retriever.py\n# Add a new function clean_html(...) that..." + Returns the raw text (possibly including the prompt). + """ + raw = model.generate( + tokenizer=tokenizer, + device=device, + prompt=prompt, + max_length=512, + temperature=0.7, + top_k=50, + ) + return raw + + +# ------------------------------------------------------ +# 2) “Verifier” agent: sandbox + test +# ------------------------------------------------------ +class CodeVerifier: + """ + Given a proposed code patch (as text), this class: + 1. Writes it to a temporary file (or repo clone) + 2. Runs Python’s syntax check (compile) and unit tests + 3. Measures performance changes (e.g. run a small validation set through the model) + 4. Returns True/False + log messages + """ + + def __init__(self, repo_dir: str, test_command: str): + """ + - repo_dir: path to your Nora project root + - test_command: a shell command string to run unit tests, e.g. "pytest tests/" + """ + self.repo_dir = repo_dir + self.test_command = test_command + + def verify_patch(self, rel_path: str, patch_code: str) -> bool: + """ + - rel_path: relative path inside repo where the patch should go, e.g. "knowledge_retriever.py" + - patch_code: entire contents of that file (not a diff). + Returns True if syntax + tests pass; False otherwise. + """ + # 1) Copy repo => temp dir + tmpdir = tempfile.mkdtemp(prefix="nora_verify_") + try: + shutil.copytree(self.repo_dir, os.path.join(tmpdir, "repo"), dirs_exist_ok=True) + target_file = os.path.join(tmpdir, "repo", rel_path) + + # 2) Write patch_code to target_file + with open(target_file, "w", encoding="utf-8") as f: + f.write(patch_code) + + # 3) Syntax check (try compiling) + try: + compile(patch_code, target_file, "exec") + except SyntaxError as se: + logging.error(f"Syntax error in patch: {se}") + return False + + # 4) Run unit tests + result = subprocess.run( + self.test_command, + shell=True, + cwd=os.path.join(tmpdir, "repo"), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + if result.returncode != 0: + logging.error(f"Unit tests failed:\n{result.stdout}") + return False + + # 5) (Optional) Performance check + # You could load the updated model and measure perplexity on a tiny validation set here. + # For now, we assume passing tests = “improvement.” + + return True + + finally: + shutil.rmtree(tmpdir) + + def merge_patch(self, rel_path: str, patch_code: str) -> None: + """ + Overwrite the real file in `repo_dir/rel_path` with patch_code, + then git-add and git-commit (you can also automate a PR). 
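+
+        Note: this does not re-run verification; call verify_patch first,
+        as self_improvement_cycle below does.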
+ """ + target_file = os.path.join(self.repo_dir, rel_path) + with open(target_file, "w", encoding="utf-8") as f: + f.write(patch_code) + + # Example: git add + commit + subprocess.run(f"git add {rel_path}", shell=True, cwd=self.repo_dir) + subprocess.run( + f'git commit -m "Auto-update {rel_path} via Nora self-improve."', + shell=True, + cwd=self.repo_dir, + ) + + +# ------------------------------------------------------ +# 3) Main loop: ask → verify → merge (if good) → retrain +# ------------------------------------------------------ +def self_improvement_cycle(repo_dir: str, device: str): + """ + Example cycle: + 1) Nora proposes a new helper in knowledge_retriever.py + 2) Verifier checks syntax + tests + 3) If ok, merge and trigger incremental retraining + """ + config = get_config() + tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir) + model = NoraTransformerLM( + vocab_size=tokenizer.vocab_size(), + d_model=config.d_model, + nhead=config.nhead, + num_layers=config.num_layers, + dim_feedforward=config.dim_feedforward, + dropout=config.dropout, + max_seq_len=config.seq_length, + ) + # Load latest checkpoint + ckpts = [] + if os.path.isdir(config.checkpoint_dir): + ckpts = [ + f + for f in os.listdir(config.checkpoint_dir) + if f.startswith("nora_step_") and f.endswith(".pt") + ] + if ckpts: + latest = sorted(ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1] + state = torch.load(os.path.join(config.checkpoint_dir, latest), map_location="cpu") + model.load_state_dict(state["model_state_dict"]) + model.to(device) + model.eval() + + verifier = CodeVerifier(repo_dir=repo_dir, test_command="pytest tests/") + + # Example prompt: ask Nora to extend knowledge_retriever.py + prompt = ( + "### FILE: knowledge_retriever.py\n" + "# Add a function clean_html(html: str) -> str that strips tags and scripts.\n" + "# Use BeautifulSoup if available. Return plain text.\n\n" + "### START\n" + "def clean_html(html: str) -> str:\n" + ) + raw_patch = propose_patch(model, tokenizer, device, prompt) + + # Extract everything from “def clean_html” to end of function (simple heuristic) + # In practice, you’d parse until the next “\n\n” or rely on indentation. + patch_code = raw_patch # for now, assume raw_patch is the full file contents + + # Verify + if verifier.verify_patch("knowledge_retriever.py", patch_code): + logging.info("Patch verified. Merging into live code.") + verifier.merge_patch("knowledge_retriever.py", patch_code) + # Optionally: trigger incremental retraining here (e.g. call train.py with --resume) + else: + logging.warning("Patch failed verification. Discarding.") diff --git a/train.py b/train.py index 4c2c0cf..6c3b109 100644 --- a/train.py +++ b/train.py @@ -37,6 +37,13 @@ def train( device = config.device model.to(device) + + # ─── ensure optimizer state is on the same device ─── + # (this moves any loaded CPU buffers for Adam/AdamW into CUDA) + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi[""]) scaler = GradScaler()