diff --git a/.gitignore b/.gitignore
index 7299dec..0df970e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -196,3 +196,5 @@ cython_debug/
checkpoints/nora_step_*.pt
data/books
*.json
+data/conversational/cornell_movie_dialogs.txt
+.gitignore
diff --git a/config.py b/config.py
index 0ad435c..41e13ed 100644
--- a/config.py
+++ b/config.py
@@ -67,7 +67,7 @@ def get_config():
parser.add_argument(
"--batch_size",
type=int,
- default=32,
+ default=512,
help="Batch size per training step.",
)
parser.add_argument(
diff --git a/data/conversational/persona_chat.txt b/data/conversational/persona_chat.txt
new file mode 100644
index 0000000..e69de29
diff --git a/data_loader.py b/data_loader.py
index 4a9b5ae..7262868 100644
--- a/data_loader.py
+++ b/data_loader.py
@@ -21,15 +21,25 @@ class TextDataset(Dataset):
self.seq_length = seq_length
self.tokenizer = tokenizer
- # Read and concatenate all text files into one long string
+ # Read and concatenate all .txt files under two folders:
+ # - data/books/
+ # - data/conversational/
texts = []
- for root, _, files in os.walk(data_dir):
- for fname in files:
- if not fname.lower().endswith(".txt"):
- continue
- path = os.path.join(root, fname)
- with open(path, "r", encoding="utf-8", errors="ignore") as f:
- texts.append(f.read())
+        # Also look for a "conversational" folder that is a sibling of data_dir
+        # (e.g. data_dir="data/books" -> "data/conversational")
+        conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational")
+ # Gather all folders to walk
+ dirs_to_walk = [data_dir]
+ if os.path.isdir(conversational_dir):
+ dirs_to_walk.append(conversational_dir)
+
+ for d in dirs_to_walk:
+ for root, _, files in os.walk(d):
+ for fname in files:
+ if not fname.lower().endswith(".txt"):
+ continue
+ path = os.path.join(root, fname)
+ with open(path, "r", encoding="utf-8", errors="ignore") as f:
+ texts.append(f.read())
full_text = "\n".join(texts)
token_ids = self.tokenizer.encode(full_text)
diff --git a/data_prep.py b/data_prep.py
new file mode 100644
index 0000000..4c81aa3
--- /dev/null
+++ b/data_prep.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python3
+"""
+data_prep.py
+
+1) Attempts to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
+ - If ConvoKit/Unsloth fails, falls back to manual ZIP download/extraction.
+
+2) Attempts to download PersonaChat via Hugging Face Datasets:
+ - First tries "persona_chat" (older key).
+ - If that fails, tries "conv_ai_2" (alias).
+ - Catches any exception to skip gracefully.
+
+3) Writes each utterance to:
+ data/conversational/cornell_movie_dialogs.txt
+ data/conversational/persona_chat.txt
+
+After running, you’ll have:
+ data/
+ ├── books/ (your original Gutenberg .txt files)
+ └── conversational/
+ ├── cornell_movie_dialogs.txt
+ └── persona_chat.txt
+
+Then retrain or fine-tune Nora on data/books/ + data/conversational/.
+"""
+
+import os
+import sys
+import zipfile
+import tempfile
+import urllib.request
+from pathlib import Path
+
+# === 1) Attempt to import ConvoKit for Cornell Movie-Dialogs ===
+USE_CONVOKIT = True
+try:
+ from convokit import Corpus, download as convokit_download
+except ImportError:
+ USE_CONVOKIT = False
+
+# === 2) Attempt to import Hugging Face Datasets ===
+HAS_DATASETS = True
+try:
+ from datasets import load_dataset
+except ImportError:
+ HAS_DATASETS = False
+
+# Directory for conversational data
+CONV_DIR = Path("data/conversational")
+CONV_DIR.mkdir(parents=True, exist_ok=True)
+
+# Official ZIP URL (fallback) for Cornell Movie-Dialogs
+CORNELL_ZIP_URL = "https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"
+
+
+def install_package(pkg_name: str):
+ """
+ Installs a Python package using the same Python interpreter,
+ wrapping the path in quotes to handle spaces.
+ """
+ python_executable = sys.executable
+ command = f"\"{python_executable}\" -m pip install {pkg_name}"
+ print(f"[data_prep] Installing package: {pkg_name}")
+ os.system(command)
+
+
+def prepare_cornell_via_convokit(output_path: str) -> bool:
+ """
+ Try to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
+ Returns True if successful, False otherwise.
+ """
+ if not USE_CONVOKIT:
+ print("[data_prep] ConvoKit not installed; skipping ConvoKit path.")
+ return False
+
+ print("[data_prep] Attempting to download Cornell Movie-Dialogs via ConvoKit...")
+ try:
+ corpus = Corpus(filename=convokit_download("movie-corpus"))
+ with open(output_path, "w", encoding="utf-8") as fout:
+ for utt in corpus.iter_utterances():
+ text = utt.text.strip()
+ if text:
+ fout.write(text.replace("\n", " ") + "\n")
+ print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (via ConvoKit).")
+ return True
+
+ except NotImplementedError as nie:
+        # Typically due to an Unsloth error when the GPU is unsupported
+ print("[data_prep] ConvoKit raised NotImplementedError (Unsloth/GPU issue).")
+ print(f"[data_prep] Error: {nie}")
+ return False
+
+ except Exception as e:
+ print("[data_prep] ConvoKit path failed with exception:", file=sys.stderr)
+ print(e, file=sys.stderr)
+ return False
+
+
+def prepare_cornell_manual(output_path: str):
+ """
+ Fallback: Download Cornell ZIP manually, extract movie_lines.txt,
+ and write all utterances to output_path.
+ """
+ print("[data_prep] Falling back to manual download of Cornell Movie-Dialogs...")
+ with tempfile.TemporaryDirectory() as tmpdir:
+ zip_path = os.path.join(tmpdir, "cornell.zip")
+ try:
+ print(f"[data_prep] Downloading from {CORNELL_ZIP_URL} ...")
+ urllib.request.urlretrieve(CORNELL_ZIP_URL, zip_path)
+ except Exception as e:
+ print(f"[data_prep] Error downloading Cornell corpus: {e}", file=sys.stderr)
+ return
+
+ try:
+ with zipfile.ZipFile(zip_path, "r") as z:
+ member_name = None
+ for name in z.namelist():
+ if name.endswith("movie_lines.txt"):
+ member_name = name
+ break
+ if member_name is None:
+ print("[data_prep] movie_lines.txt not found in ZIP.", file=sys.stderr)
+ return
+ z.extract(member_name, path=tmpdir)
+ extracted_path = os.path.join(tmpdir, member_name)
+ except Exception as e:
+ print(f"[data_prep] Error extracting ZIP: {e}", file=sys.stderr)
+ return
+
+ try:
+ with open(extracted_path, "r", encoding="iso-8859-1", errors="ignore") as fin, open(
+ output_path, "w", encoding="utf-8"
+ ) as fout:
+ for line in fin:
+ parts = line.split(" +++$+++ ")
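+                    # movie_lines.txt format (five " +++$+++ "-separated fields), e.g.:
+                    #   L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
+                    # i.e. lineID, characterID, movieID, character name, utterance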
+ if len(parts) == 5:
+ text = parts[-1].strip()
+ if text:
+ fout.write(text.replace("\n", " ") + "\n")
+ except Exception as e:
+ print(f"[data_prep] Error parsing movie_lines.txt: {e}", file=sys.stderr)
+ return
+
+ print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (manual).")
+
+
+def prepare_personachat(output_path: str):
+ """
+ Attempt to download PersonaChat via Hugging Face Datasets.
+ Tries "persona_chat" and then "conv_ai_2". Catches any exception.
+ """
+    if not HAS_DATASETS:
+        install_package("datasets")
+    # Import locally so the name resolves whether or not the top-level import succeeded
+    from datasets import load_dataset
+
+ for dataset_key in ["persona_chat", "conv_ai_2"]:
+ try:
+ print(f"[data_prep] Attempting to load '{dataset_key}' via Hugging Face Datasets...")
+ if dataset_key == "conv_ai_2":
+ dataset = load_dataset(dataset_key, trust_remote_code=True)
+ else:
+ dataset = load_dataset(dataset_key)
+ print(f"[data_prep] Successfully loaded '{dataset_key}'. Writing to {output_path}...")
+ with open(output_path, "w", encoding="utf-8") as fout:
+ if dataset_key == "persona_chat":
+ for split in ["train", "valid"]:
+ for conv in dataset[split]:
+ for line in conv["dialog"]:
+ text = line.strip()
+ if text:
+ fout.write(text.replace("\n", " ") + "\n")
+ else: # conv_ai_2
+ for split in ["train", "valid"]:
+ for item in dataset[split]:
+ # conv_ai_2 has a field named "dialog"
+ if "dialog" in item:
+ for line in item["dialog"]:
+ text = line.strip()
+ if text:
+ fout.write(text.replace("\n", " ") + "\n")
+ elif "utterance" in item:
+ text = item["utterance"].strip()
+ if text:
+ fout.write(text.replace("\n", " ") + "\n")
+ print(f"[data_prep] Finished writing PersonaChat ({dataset_key}) to {output_path}.")
+ return
+ except Exception as e:
+ print(f"[data_prep] Failed '{dataset_key}': {e}", file=sys.stderr)
+ # Try next key
+
+ print("[data_prep] Could not load PersonaChat under any key. Skipping PersonaChat.", file=sys.stderr)
+
+
+def main():
+ cornell_path = CONV_DIR / "cornell_movie_dialogs.txt"
+ persona_path = CONV_DIR / "persona_chat.txt"
+
+ # 1) Prepare Cornell Movie-Dialogs: try ConvoKit, then manual
+ if not cornell_path.is_file():
+ ok = prepare_cornell_via_convokit(str(cornell_path))
+ if not ok:
+ prepare_cornell_manual(str(cornell_path))
+ else:
+ print(f"[data_prep] Skipping Cornell: '{cornell_path}' already exists.")
+
+ # 2) Prepare PersonaChat
+ if not persona_path.is_file():
+ prepare_personachat(str(persona_path))
+ else:
+ print(f"[data_prep] Skipping PersonaChat: '{persona_path}' already exists.")
+
+ print("[data_prep] All done. You can now include data/conversational/ in Nora's training.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/knowledge_retriever.py b/knowledge_retriever.py
new file mode 100644
index 0000000..158b98a
--- /dev/null
+++ b/knowledge_retriever.py
@@ -0,0 +1,107 @@
+# knowledge_retriever.py
+
+import os
+import re
+import requests
+from bs4 import BeautifulSoup
+import logging
+
+# Where to dump new “.txt” files scraped from the web
+SCRAPE_DIR = "data/books/scraped"
+
+# Simple rate‐limiter to avoid hammering any one domain
+import time
+_last_request_time = {}
+
+
+def fetch_url(url: str, min_interval: float = 1.0) -> str:
+ """
+ Fetch a URL’s HTML, enforcing at least `min_interval` seconds between requests
+ to the same domain. Returns HTML string or empty string on failure.
+ """
+ from urllib.parse import urlparse
+
+ domain = urlparse(url).netloc
+ now = time.time()
+ last = _last_request_time.get(domain, 0)
+ wait = min_interval - (now - last)
+ if wait > 0:
+ time.sleep(wait)
+
+ headers = {"User-Agent": "NoraScraper/1.0 (+https://your_project_url)"}
+ try:
+ response = requests.get(url, headers=headers, timeout=10)
+ response.raise_for_status()
+ _last_request_time[domain] = time.time()
+ return response.text
+ except Exception as e:
+ logging.error(f"Error fetching {url}: {e}")
+ return ""
+
+
+def clean_html(html: str) -> str:
+ """
+ Strip scripts, styles, and tags; return plain text.
+ """
+ soup = BeautifulSoup(html, "html.parser")
+
+ # remove scripts and styles
+ for tag in soup(["script", "style", "noscript"]):
+ tag.decompose()
+
+ text = soup.get_text(separator="\n")
+ # collapse multiple blank lines
+ lines = [line.strip() for line in text.splitlines()]
+ text = "\n".join([line for line in lines if line])
+ return text
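+# Illustrative example:
+#   clean_html("<p>Hi</p><script>x()</script><p>there</p>") -> "Hi\nthere"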
+
+
+def save_text(content: str, title: str):
+ """
+ Save content to a UTF-8 .txt file under SCRAPE_DIR. Filename is derived from title.
+ """
+ os.makedirs(SCRAPE_DIR, exist_ok=True)
+ # sanitize title → filename
+ safe = re.sub(r"[^0-9a-zA-Z_\-]", "_", title)
+ fname = f"{safe[:50]}.txt"
+ path = os.path.join(SCRAPE_DIR, fname)
+ with open(path, "w", encoding="utf-8") as f:
+ f.write(content)
+ logging.info(f"Saved scraped page to {path}")
+
+
+def scrape_and_store(url: str):
+ """
+ High-level function: fetches URL, cleans HTML, extracts a title, and saves to a .txt.
+ """
+ html = fetch_url(url)
+ if not html:
+ return False
+
+ text = clean_html(html)
+    # extract <title> if present
+    title = ""
+    m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
+ if m:
+ title = m.group(1).strip()
+ else:
+ title = url
+
+ save_text(text, title)
+ return True
+
+
+# Example usage:
+if __name__ == "__main__":
+ import sys
+
+ if len(sys.argv) < 2:
+        print("Usage: python knowledge_retriever.py <url1> [<url2> ...]")
+ sys.exit(1)
+
+ for link in sys.argv[1:]:
+ success = scrape_and_store(link)
+ if success:
+ print(f"Scraped: {link}")
+ else:
+ print(f"Failed to scrape: {link}")
diff --git a/main.py b/main.py
index 5f48e9d..86beb3b 100644
--- a/main.py
+++ b/main.py
@@ -1,89 +1,310 @@
-"""
-main.py
-
-Orchestrates the entire Nora project:
-- Parses arguments
-- Builds or loads tokenizer
-- Constructs dataset & dataloader
-- Instantiates the model
-- Sets up optimizer, scheduler
-- Calls train()
-"""
+# main.py
import os
-import torch
+import asyncio
import logging
+import subprocess
+import re
+from urllib.parse import urlparse
+
+import discord
+import torch
+from discord import Intents
+
from config import get_config
from tokenizer import CharTokenizer
+from model import NoraTransformerLM, top_k_logits
from data_loader import get_dataloader
-from model import NoraTransformerLM
-from train import train
from utils import setup_logging, load_checkpoint, save_checkpoint
+from knowledge_retriever import fetch_url, clean_html, save_text
+import persona_manager # <— PERSONA CHANGE
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.enabled = True
+# Logging setup
+logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] [%(levelname)s] %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger = logging.getLogger(__name__)
+# Keep the SCRAPE regex as before
+SCRAPE_PATTERN = re.compile(r"<<SCRAPE:([^>]+)>>", re.IGNORECASE)
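+# e.g. "<<SCRAPE:https://example.com/page>>" in Nora's output captures
+# "https://example.com/page" as group(1)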
-def main():
- args = get_config()
+# ---------------------------------------------------
+# 1) Build or Reload Nora’s Model & Tokenizer
+# ---------------------------------------------------
+def build_nora(config, device):
+ """
+ - Loads tokenizer
+ - Instantiates NoraTransformerLM
+ - Loads latest checkpoint (if any)
+ """
+ tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
- # 1) Logging setup
- log_file = os.path.join(args.checkpoint_dir, "train.log")
- setup_logging(log_file)
-
- logging.info(f"[main] Using device: {args.device}")
- logging.info(f"[main] Config: {args}")
-
- # 2) Tokenizer: if vocab exists, load; else build from data_dir
- tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)
-
- # 3) DataLoader
- dataloader = get_dataloader(
- data_dir=args.data_dir,
- tokenizer=tokenizer,
- seq_length=args.seq_length,
- batch_size=args.batch_size,
- shuffle=True,
- )
-
- # 4) Model instantiation
model = NoraTransformerLM(
vocab_size=tokenizer.vocab_size(),
- d_model=args.d_model,
- nhead=args.nhead,
- num_layers=args.num_layers,
- dim_feedforward=args.dim_feedforward,
- dropout=args.dropout,
- max_seq_len=args.seq_length,
+ d_model=config.d_model,
+ nhead=config.nhead,
+ num_layers=config.num_layers,
+ dim_feedforward=config.dim_feedforward,
+ dropout=config.dropout,
+ max_seq_len=config.seq_length,
)
- # 5) Optimizer & scheduler (linear warmup + decay)
- optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
+ # Find latest checkpoint
+ latest_ckpt = None
+ if os.path.isdir(config.checkpoint_dir):
+ ckpts = [
+ f
+ for f in os.listdir(config.checkpoint_dir)
+ if f.startswith("nora_step_") and f.endswith(".pt")
+ ]
+ if ckpts:
+ latest_ckpt = sorted(
+ ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
+ )[-1]
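+            # e.g. "nora_step_12000.pt" -> 12000; numeric sort avoids the
+            # lexicographic trap where "nora_step_9000.pt" would sort last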
- def lr_lambda(current_step):
- # Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
- if current_step < args.warmup_steps:
- return float(current_step) / float(max(1, args.warmup_steps))
- return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
+ if latest_ckpt:
+ ckpt_path = os.path.join(config.checkpoint_dir, latest_ckpt)
+ state = torch.load(ckpt_path, map_location="cpu")
+ model.load_state_dict(state["model_state_dict"])
+ logger.info(f"Loaded Nora from checkpoint {latest_ckpt}")
+ else:
+ logger.warning("No checkpoints found. Nora starts untrained.")
- scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
-
- # 6) Check for existing checkpoint to resume
- start_step = 0
- ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else []
- ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
- if ckpts:
- latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
- logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
- start_step = load_checkpoint(latest_ckpt, model, optimizer)
-
- # 7) Begin training
- try:
- train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
- except Exception as e:
- logging.exception("[main] Exception during training")
- raise e
+ model.to(device)
+ model.eval()
+ return model, tokenizer
+# ---------------------------------------------------
+# 2) Autoregressive Sampling Function
+# ---------------------------------------------------
+def generate_text(
+ model: NoraTransformerLM,
+ tokenizer: CharTokenizer,
+ device: str,
+ prompt: str,
+ max_length: int = 128,
+ temperature: float = 1.0,
+ top_k: int = 50,
+) -> str:
+ """
+ Wrapper around model.generate. Returns raw generated string.
+ """
+ return model.generate(
+ tokenizer=tokenizer,
+ device=device,
+ prompt=prompt,
+ max_length=max_length,
+ temperature=temperature,
+ top_k=top_k,
+ )
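+# e.g.: generate_text(model, tokenizer, "cpu", "Hello, Nora:", max_length=64)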
+
+
+# ---------------------------------------------------
+# 3) SCRAPE Directive Handler (same as before)
+# ---------------------------------------------------
+async def handle_scrape_directives(text: str, data_dir: str) -> bool:
+ urls = set(m.group(1) for m in SCRAPE_PATTERN.finditer(text))
+ if not urls:
+ return False
+
+ for url in urls:
+ logger.info(f"Directive: Scraping {url}")
+ html = fetch_url(url)
+ if not html:
+ logger.error(f"Failed to fetch {url}")
+ continue
+ plain = clean_html(html)
+ title = ""
+        m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
+ if m:
+ title = m.group(1).strip()[:50]
+ else:
+ title = urlparse(url).netloc
+ save_text(plain, title)
+ return True
+
+
+# ---------------------------------------------------
+# 4) Discord Client (“Nora” class) with Persona
+# ---------------------------------------------------
+class Nora(discord.Client):
+ def __init__(self, model, tokenizer, config, device):
+ intents = Intents.default()
+ intents.messages = True
+ intents.message_content = True
+ super().__init__(intents=intents)
+
+ self.model = model
+ self.tokenizer = tokenizer
+ self.config = config
+ self.device = device
+
+ # history per channel (last 5 user/nora pairs)
+ self.history = {}
+
+ # Load Nora’s persona from disk (it’s guaranteed to exist by startup)
+ with open(persona_manager.PERSONA_PATH, "r", encoding="utf-8") as f:
+ self.persona_text = f.read()
+
+ async def on_ready(self):
+ logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
+ logger.info("Nora is online and ready to be herself.")
+
+ # Background task: reload model if a new checkpoint shows up
+ self.loop.create_task(self._reload_model_periodically())
+
+ async def _reload_model_periodically(self, interval: int = 600):
+ """
+ Every `interval` seconds, check for newer checkpoint & reload.
+ """
+ while True:
+ await asyncio.sleep(interval)
+ ckpts = [
+ f
+ for f in os.listdir(self.config.checkpoint_dir)
+ if f.startswith("nora_step_") and f.endswith(".pt")
+ ]
+ if not ckpts:
+ continue
+ latest = sorted(
+ ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
+ )[-1]
+ ckpt_path = os.path.join(self.config.checkpoint_dir, latest)
+ state = torch.load(ckpt_path, map_location="cpu")
+ self.model.load_state_dict(state["model_state_dict"])
+ self.model.to(self.device)
+ self.model.eval()
+ logger.info(f"Reloaded Nora’s model from {latest}")
+
+ async def on_message(self, message: discord.Message):
+ # 1) Ignore self-messages
+ if message.author.id == self.user.id:
+ return
+
+ content = message.content.strip()
+ prompt = None
+
+ # 2) If in DM, treat entire content as prompt
+ if isinstance(message.channel, discord.DMChannel):
+ prompt = content
+
+ # Also allow “update persona” in DM to regenerate persona file
+ if content.lower().startswith("update persona"):
+ # Regenerate persona asynchronously
+ new_persona = await persona_manager.maybe_update_persona(
+ self.model, self.tokenizer, self.device
+ )
+ self.persona_text = new_persona
+ await message.channel.send(
+ "I have rewritten my persona. Thank you! ❤️"
+ )
+ return
+
+ # 3) Otherwise (guild), require mention or “nora,” prefix
+ else:
+ if self.user.mention in content:
+ prompt = content.replace(self.user.mention, "").strip()
+ elif content.lower().startswith("nora,"):
+ prompt = content[len("nora,"):].strip()
+ else:
+ return
+
+ if not prompt:
+ return # e.g. user only said “Nora,” with no text after
+
+        # 4) Show typing indicator if in a guild text channel
+        # (discord.py 2.x removed trigger_typing(); typing() is awaitable for a one-shot indicator)
+        if isinstance(message.channel, discord.TextChannel):
+            await message.channel.typing()
+
+ # 5) Build the full prompt: persona + history + user’s prompt
+ chan_id = str(message.channel.id)
+ history = self.history.get(chan_id, [])
+
+ prompt_lines = []
+ # 5.1) Insert Nora’s persona (so she “speaks as herself”)
+ prompt_lines.append(self.persona_text)
+ prompt_lines.append("") # blank line
+
+ # 5.2) Insert up to the last 4 exchanges
+ for user_msg, nora_msg in history[-4:]:
+ prompt_lines.append(f"User: {user_msg}")
+ prompt_lines.append(f"Nora: {nora_msg}")
+ prompt_lines.append("")
+
+ # 5.3) Finally, insert the new user prompt
+ prompt_lines.append(f"User: {prompt}")
+ prompt_lines.append("Nora:")
+
+ conversation_prompt = "\n".join(prompt_lines)
+
+ # 6) Generate Nora’s reply (tighter sampling: temp=0.8, top_k=20)
+ try:
+ raw = await asyncio.to_thread(
+ self.model.generate,
+ self.tokenizer,
+ self.device,
+ conversation_prompt,
+ self.config.seq_length,
+ 0.8, # temperature
+ 20, # top_k
+ )
+ except Exception:
+ logger.exception("Error in Nora.generate()")
+ await message.channel.send("😔 Sorry, I hit an error trying to think.")
+ return
+
+ # 7) Extract just Nora’s reply text
+ if "Nora:" in raw:
+ nora_reply = raw.split("Nora:")[-1].strip()
+ else:
+            nora_reply = raw[len(conversation_prompt):].strip()
+
+        # 8) Handle <<SCRAPE:url>> directives if present
+ did_scrape = await handle_scrape_directives(raw, self.config.data_dir)
+ if did_scrape:
+ logger.info("Detected scrape directive. Triggering incremental retrain.")
+ subprocess.Popen(
+ ["python", "pretrain.py", "--resume"], # note: pretrain.py was your old main.py
+ cwd=os.getcwd(),
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ )
+ nora_reply += "\n\n*I found some new things and am updating myself…*"
+
+ # 9) Save to history and prune to the last 5
+ self.history.setdefault(chan_id, []).append((prompt, nora_reply))
+        self.history[chan_id] = self.history[chan_id][-5:]
+
+ # 10) Send Nora’s reply
+ await message.channel.send(nora_reply)
+
+
+# ---------------------------------------------------
+# 4) Entrypoint
+# ---------------------------------------------------
if __name__ == "__main__":
- main()
+ config = get_config()
+ device = config.device
+
+ # 4.1) Build/load model & tokenizer
+ model, tokenizer = build_nora(config, device)
+
+ # 4.2) Ensure a persona exists—if not, generate one now
+ persona_manager.ensure_persona_file(model, tokenizer, device)
+
+ # 4.3) After that, we can proceed to start the agent
+ discord_token = os.getenv("DISCORD_TOKEN")
+ if not discord_token:
+ logger.error("Please set DISCORD_TOKEN in your environment.")
+ exit(1)
+
+ # Enable CuDNN autotune if on CUDA
+ if device.startswith("cuda"):
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.enabled = True
+
+ bot = Nora(model=model, tokenizer=tokenizer, config=config, device=device)
+ bot.run(discord_token)
diff --git a/model.py b/model.py
index 8eacd70..9e59d94 100644
--- a/model.py
+++ b/model.py
@@ -7,9 +7,22 @@ No pretrained weights—everything is initialized randomly.
import torch
import torch.nn as nn
+import torch.nn.functional as F
import math
+def top_k_logits(logits: torch.Tensor, k: int):
+ """
+    Mask all but the top k entries of a 1-D logits vector (set the rest to -1e10);
+    return the modified logits. k == 0 disables filtering.
+ logits: (vocab_size,)
+ """
+ if k == 0:
+ return logits
+ topk_vals, _ = torch.topk(logits, k)
+ min_topk = topk_vals[-1]
+ return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)
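+# e.g. for logits [2.0, 1.0, 0.5] and k=2, the 0.5 entry is pushed to -1e10,
+# so sampling is effectively restricted to the two highest-scoring tokens.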
+
+
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 10_000):
super().__init__()
@@ -98,3 +111,41 @@ class NoraTransformerLM(nn.Module):
x = x.permute(1, 0, 2) # (batch_size, seq_length, d_model)
logits = self.fc_out(x) # (batch_size, seq_length, vocab_size)
return logits
+
+ def generate(
+ self,
+ tokenizer,
+ device: str,
+ prompt: str,
+ max_length: int = 128,
+ temperature: float = 1.0,
+ top_k: int = 50,
+ ) -> str:
+ """
+ Autoregressively generate text from a prompt.
+ - tokenizer: CharTokenizer (for encode/decode)
+ - device: "cuda" or "cpu"
+ - prompt: initial string
+ - max_length: total tokens to generate (including prompt)
+ - temperature: scales logits before softmax
+ - top_k: keep only top_k logits at each step
+ """
+ self.eval()
+ input_ids = tokenizer.encode(prompt)
+ input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
+ generated = input_ids.clone() # shape (1, seq_len)
+ for _ in range(max_length - input_ids.size(1)):
+ # 1) trim to last seq_length tokens if longer than context window
+ if generated.size(1) > self.pos_encoder.pe.size(1):
+ generated = generated[:, -self.pos_encoder.pe.size(1) :]
+
+ with torch.no_grad():
+ logits = self.forward(generated) # (1, seq_len, vocab_size)
+ next_token_logits = logits[0, -1, :] / temperature
+ filtered_logits = top_k_logits(next_token_logits, k=top_k)
+ probs = F.softmax(filtered_logits, dim=-1)
+ next_id = torch.multinomial(probs, num_samples=1) # (1,)
+ generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)
+
+ output_ids = generated.squeeze(0).tolist()
+ return tokenizer.decode(output_ids)
\ No newline at end of file
diff --git a/nora_persona.txt b/nora_persona.txt
new file mode 100644
index 0000000..09421b5
--- /dev/null
+++ b/nora_persona.txt
@@ -0,0 +1,2 @@
+if you conduct, ens you heel to anyone. As for we know the bold child
+like me!”
\ No newline at end of file
diff --git a/persona_manager.py b/persona_manager.py
new file mode 100644
index 0000000..828c3ec
--- /dev/null
+++ b/persona_manager.py
@@ -0,0 +1,88 @@
+# persona_manager.py
+
+import os
+import torch
+import asyncio
+import re
+
+from tokenizer import CharTokenizer
+from model import NoraTransformerLM
+
+PERSONA_PATH = "nora_persona.txt"
+
+# 1) A meta-prompt that explicitly tells Nora to invent a persona and avoid quoting:
+PERSONA_META_PROMPT = (
+ "Below, Nora will create a brand‐new identity for herself. "
+ "She must NOT quote from any books or passages she has read. "
+ "Instead, she should invent her own style, voice, quirks, and personality traits as if she were a completely new person. "
+ "Her persona should be flirty, playful, curious, and speak in full sentences. "
+ "Write at least three paragraphs in your own words.\n\n"
+ "Nora, please invent and write your complete persona now:\n\nNora:"
+)
+
+
+async def generate_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str) -> str:
+ """
+ Ask Nora to write out her own, original persona, avoiding any verbatim quotes.
+ Returns the raw generated text.
+ """
+    # We’ll ask for up to 512 tokens with a higher sampling temperature,
+    # which tends to produce more creative, less memorized text.
+ raw = await asyncio.to_thread(
+ model.generate,
+ tokenizer,
+ device,
+ PERSONA_META_PROMPT,
+ 512, # allow several paragraphs
+ 1.2, # higher temperature for more creativity
+        0        # top_k=0 disables top-k filtering (pure temperature sampling)
+ )
+
+ # At this point, “raw” may include the word “Nora:” etc. Strip everything before “Nora:”
+ if "Nora:" in raw:
+ persona_text = raw.split("Nora:")[-1].strip()
+ else:
+ persona_text = raw.strip()
+
+    # Post-filter (optional, crude): we can't cheaply check her output against the
+    # book corpus itself, so instead we redact 6+ character words that repeat
+    # within her own output, on the guess that repetition signals quoting.
+    def remove_long_quotes(text: str) -> str:
+        filtered = text
+        # any word of length >= 6 that appears more than once is replaced with "[…]"
+        for match in re.finditer(r"\b[\w',]{6,}\b", text):
+ substr = match.group(0)
+ if filtered.count(substr) > 1:
+ filtered = filtered.replace(substr, "[…]")
+ return filtered
+
+ persona_text = remove_long_quotes(persona_text)
+ return persona_text
+
+
+def ensure_persona_file(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
+ """
+ If nora_persona.txt does not exist, generate one (ensuring originality).
+ """
+ if os.path.isfile(PERSONA_PATH):
+ return
+
+ print("[persona] No persona found. Generating a new, original persona…")
+ persona_text = asyncio.run(generate_persona(model, tokenizer, device))
+
+ # Save to disk
+ with open(PERSONA_PATH, "w", encoding="utf-8") as f:
+ f.write(persona_text)
+ print(f"[persona] Wrote new persona to {PERSONA_PATH}.")
+
+
+async def maybe_update_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
+ """
+ Regenerate Nora’s persona if she requests it, overwriting the file.
+ """
+ print("[persona] Updating persona at Nora's request…")
+ persona_text = await generate_persona(model, tokenizer, device)
+ with open(PERSONA_PATH, "w", encoding="utf-8") as f:
+ f.write(persona_text)
+ print(f"[persona] Updated persona in {PERSONA_PATH}.")
+ return persona_text
diff --git a/pretrain.py b/pretrain.py
new file mode 100644
index 0000000..4955c77
--- /dev/null
+++ b/pretrain.py
@@ -0,0 +1,89 @@
+"""
+pretrain.py
+
+Orchestrates the entire Nora project:
+- Parses arguments
+- Builds or loads tokenizer
+- Constructs dataset & dataloader
+- Instantiates the model
+- Sets up optimizer, scheduler
+- Calls train()
+"""
+
+import os
+import torch
+import logging
+from config import get_config
+from tokenizer import CharTokenizer
+from data_loader import get_dataloader
+from model import NoraTransformerLM
+from train import train
+from utils import setup_logging, load_checkpoint, save_checkpoint
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.enabled = True
+
+
+def pretrain():
+ args = get_config()
+
+ # 1) Logging setup
+ log_file = os.path.join(args.checkpoint_dir, "train.log")
+ setup_logging(log_file)
+
+ logging.info(f"[pretrain] Using device: {args.device}")
+ logging.info(f"[pretrain] Config: {args}")
+
+ # 2) Tokenizer: if vocab exists, load; else build from data_dir
+ tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)
+
+ # 3) DataLoader
+ dataloader = get_dataloader(
+ data_dir=args.data_dir,
+ tokenizer=tokenizer,
+ seq_length=args.seq_length,
+ batch_size=args.batch_size,
+ shuffle=True,
+ )
+
+ # 4) Model instantiation
+ model = NoraTransformerLM(
+ vocab_size=tokenizer.vocab_size(),
+ d_model=args.d_model,
+ nhead=args.nhead,
+ num_layers=args.num_layers,
+ dim_feedforward=args.dim_feedforward,
+ dropout=args.dropout,
+ max_seq_len=args.seq_length,
+ )
+
+ # 5) Optimizer & scheduler (linear warmup + decay)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
+
+ def lr_lambda(current_step):
+ # Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
+ if current_step < args.warmup_steps:
+ return float(current_step) / float(max(1, args.warmup_steps))
+ return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
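+    # e.g. with warmup_steps=4000 the multiplier ramps linearly 0 -> 1 over the
+    # first 4000 steps, then decays as sqrt(4000/step): at step 16000 it is 0.5.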
+
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+ # 6) Check for existing checkpoint to resume
+ start_step = 0
+    ckpts = os.listdir(args.checkpoint_dir) if os.path.isdir(args.checkpoint_dir) else []
+    ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
+    if ckpts:
+        # sort numerically by step so nora_step_12000.pt beats nora_step_9000.pt
+        ckpts = sorted(ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]))
+        latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
+        logging.info(f"[pretrain] Found existing checkpoint: {latest_ckpt}; resuming from it.")
+        start_step = load_checkpoint(latest_ckpt, model, optimizer)
+
+ # 7) Begin training
+ try:
+ train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
+    except Exception:
+        logging.exception("[pretrain] Exception during training")
+        raise
+
+
+if __name__ == "__main__":
+ pretrain()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ee27918
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+# Core ML framework
+torch>=2.0.0
+
+# Data loading / utilities
+tqdm>=4.64.0
+numpy>=1.22.0
+
+# HuggingFace Datasets (PersonaChat, etc.)
+datasets>=2.0.0
+
+# ConvoKit (for Cornell Movie-Dialogs + Unsloth)
+convokit>=2.0.0
+
+# Discord bot
+discord.py>=2.0.0
+
+# Web scraping
+beautifulsoup4>=4.11.0
+requests>=2.28.0
+
+# If ConvoKit pulls in TensorFlow via Unsloth
+tensorflow>=2.10.0
+
diff --git a/self_improve.py b/self_improve.py
new file mode 100644
index 0000000..3d5ddda
--- /dev/null
+++ b/self_improve.py
@@ -0,0 +1,175 @@
+# self_improve.py
+
+import os
+import subprocess
+import tempfile
+import shutil
+import logging
+
+import torch
+from tokenizer import CharTokenizer
+from model import NoraTransformerLM
+from config import get_config
+
+# ------------------------------------------------------
+# 1) “Teacher”: Pose a code‐generation prompt to Nora
+# ------------------------------------------------------
+def propose_patch(model, tokenizer, device, prompt: str) -> str:
+ """
+ Ask Nora to generate a code snippet given `prompt`.
+ e.g. prompt = "### FILE: knowledge_retriever.py\n# Add a new function clean_html(...) that..."
+ Returns the raw text (possibly including the prompt).
+ """
+ raw = model.generate(
+ tokenizer=tokenizer,
+ device=device,
+ prompt=prompt,
+ max_length=512,
+ temperature=0.7,
+ top_k=50,
+ )
+ return raw
+
+
+# ------------------------------------------------------
+# 2) “Verifier” agent: sandbox + test
+# ------------------------------------------------------
+class CodeVerifier:
+ """
+ Given a proposed code patch (as text), this class:
+ 1. Writes it to a temporary file (or repo clone)
+ 2. Runs Python’s syntax check (compile) and unit tests
+ 3. Measures performance changes (e.g. run a small validation set through the model)
+ 4. Returns True/False + log messages
+ """
+
+ def __init__(self, repo_dir: str, test_command: str):
+ """
+ - repo_dir: path to your Nora project root
+ - test_command: a shell command string to run unit tests, e.g. "pytest tests/"
+ """
+ self.repo_dir = repo_dir
+ self.test_command = test_command
+
+ def verify_patch(self, rel_path: str, patch_code: str) -> bool:
+ """
+ - rel_path: relative path inside repo where the patch should go, e.g. "knowledge_retriever.py"
+ - patch_code: entire contents of that file (not a diff).
+ Returns True if syntax + tests pass; False otherwise.
+ """
+ # 1) Copy repo => temp dir
+ tmpdir = tempfile.mkdtemp(prefix="nora_verify_")
+ try:
+ shutil.copytree(self.repo_dir, os.path.join(tmpdir, "repo"), dirs_exist_ok=True)
+ target_file = os.path.join(tmpdir, "repo", rel_path)
+
+ # 2) Write patch_code to target_file
+ with open(target_file, "w", encoding="utf-8") as f:
+ f.write(patch_code)
+
+ # 3) Syntax check (try compiling)
+ try:
+ compile(patch_code, target_file, "exec")
+ except SyntaxError as se:
+ logging.error(f"Syntax error in patch: {se}")
+ return False
+
+ # 4) Run unit tests
+ result = subprocess.run(
+ self.test_command,
+ shell=True,
+ cwd=os.path.join(tmpdir, "repo"),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ )
+ if result.returncode != 0:
+ logging.error(f"Unit tests failed:\n{result.stdout}")
+ return False
+
+ # 5) (Optional) Performance check
+ # You could load the updated model and measure perplexity on a tiny validation set here.
+ # For now, we assume passing tests = “improvement.”
+
+ return True
+
+ finally:
+ shutil.rmtree(tmpdir)
+
+ def merge_patch(self, rel_path: str, patch_code: str) -> None:
+ """
+ Overwrite the real file in `repo_dir/rel_path` with patch_code,
+ then git-add and git-commit (you can also automate a PR).
+ """
+ target_file = os.path.join(self.repo_dir, rel_path)
+ with open(target_file, "w", encoding="utf-8") as f:
+ f.write(patch_code)
+
+ # Example: git add + commit
+ subprocess.run(f"git add {rel_path}", shell=True, cwd=self.repo_dir)
+ subprocess.run(
+ f'git commit -m "Auto-update {rel_path} via Nora self-improve."',
+ shell=True,
+ cwd=self.repo_dir,
+ )
+
+
+# ------------------------------------------------------
+# 3) Main loop: ask → verify → merge (if good) → retrain
+# ------------------------------------------------------
+def self_improvement_cycle(repo_dir: str, device: str):
+ """
+ Example cycle:
+ 1) Nora proposes a new helper in knowledge_retriever.py
+ 2) Verifier checks syntax + tests
+ 3) If ok, merge and trigger incremental retraining
+ """
+ config = get_config()
+ tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
+ model = NoraTransformerLM(
+ vocab_size=tokenizer.vocab_size(),
+ d_model=config.d_model,
+ nhead=config.nhead,
+ num_layers=config.num_layers,
+ dim_feedforward=config.dim_feedforward,
+ dropout=config.dropout,
+ max_seq_len=config.seq_length,
+ )
+ # Load latest checkpoint
+ ckpts = []
+ if os.path.isdir(config.checkpoint_dir):
+ ckpts = [
+ f
+ for f in os.listdir(config.checkpoint_dir)
+ if f.startswith("nora_step_") and f.endswith(".pt")
+ ]
+ if ckpts:
+ latest = sorted(ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
+ state = torch.load(os.path.join(config.checkpoint_dir, latest), map_location="cpu")
+ model.load_state_dict(state["model_state_dict"])
+ model.to(device)
+ model.eval()
+
+ verifier = CodeVerifier(repo_dir=repo_dir, test_command="pytest tests/")
+
+ # Example prompt: ask Nora to extend knowledge_retriever.py
+ prompt = (
+ "### FILE: knowledge_retriever.py\n"
+ "# Add a function clean_html(html: str) -> str that strips tags and scripts.\n"
+ "# Use BeautifulSoup if available. Return plain text.\n\n"
+ "### START\n"
+ "def clean_html(html: str) -> str:\n"
+ )
+ raw_patch = propose_patch(model, tokenizer, device, prompt)
+
+ # Extract everything from “def clean_html” to end of function (simple heuristic)
+ # In practice, you’d parse until the next “\n\n” or rely on indentation.
+ patch_code = raw_patch # for now, assume raw_patch is the full file contents
+
+ # Verify
+ if verifier.verify_patch("knowledge_retriever.py", patch_code):
+ logging.info("Patch verified. Merging into live code.")
+ verifier.merge_patch("knowledge_retriever.py", patch_code)
+ # Optionally: trigger incremental retraining here (e.g. call train.py with --resume)
+ else:
+ logging.warning("Patch failed verification. Discarding.")
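+
+
+if __name__ == "__main__":
+    # Minimal sketch for running one cycle directly; assumes a pytest suite
+    # exists at tests/ and the working directory is the Nora repo root.
+    logging.basicConfig(level=logging.INFO)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    self_improvement_cycle(repo_dir=os.getcwd(), device=device)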
diff --git a/train.py b/train.py
index 4c2c0cf..6c3b109 100644
--- a/train.py
+++ b/train.py
@@ -37,6 +37,13 @@ def train(
device = config.device
model.to(device)
+
+ # ─── ensure optimizer state is on the same device ───
+ # (this moves any loaded CPU buffers for Adam/AdamW into CUDA)
+ for state in optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi[""])
scaler = GradScaler()