Added another learning source for Nora. Also added the requirements.

2025-06-09 14:25:11 -04:00
parent da23742671
commit 5d53ba7cb8
14 changed files with 1070 additions and 78 deletions

2
.gitignore vendored
View File

@@ -196,3 +196,5 @@ cython_debug/
 checkpoints/nora_step_*.pt
 data/books
 *.json
+data/conversational/cornell_movie_dialogs.txt
+.gitignore

config.py
View File

@@ -67,7 +67,7 @@ def get_config():
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=32,
+        default=512,
         help="Batch size per training step.",
     )
     parser.add_argument(
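The sixteenfold-larger default assumes ample GPU memory; it can still be overridden per run, e.g. python pretrain.py --batch_size 64 on a smaller card (pretrain.py reads this parser via get_config()).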


data_loader.py
View File

@@ -21,15 +21,25 @@ class TextDataset(Dataset):
         self.seq_length = seq_length
         self.tokenizer = tokenizer
-        # Read and concatenate all text files into one long string
+        # Read and concatenate all .txt files under two folders:
+        #   - data/books/
+        #   - data/conversational/
         texts = []
-        for root, _, files in os.walk(data_dir):
-            for fname in files:
-                if not fname.lower().endswith(".txt"):
-                    continue
-                path = os.path.join(root, fname)
-                with open(path, "r", encoding="utf-8", errors="ignore") as f:
-                    texts.append(f.read())
+        # If data_dir is a single path, we still look for a sibling "conversational" folder
+        conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational")
+        # Gather all folders to walk
+        dirs_to_walk = [data_dir]
+        if os.path.isdir(conversational_dir):
+            dirs_to_walk.append(conversational_dir)
+
+        for d in dirs_to_walk:
+            for root, _, files in os.walk(d):
+                for fname in files:
+                    if not fname.lower().endswith(".txt"):
+                        continue
+                    path = os.path.join(root, fname)
+                    with open(path, "r", encoding="utf-8", errors="ignore") as f:
+                        texts.append(f.read())
         full_text = "\n".join(texts)
         token_ids = self.tokenizer.encode(full_text)
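A minimal sketch of the sibling-folder lookup above, assuming the default layout where data_dir is "data/books":

import os

data_dir = "data/books"
# os.path.dirname("data/books") == "data", so the sibling resolves to "data/conversational"
conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational")
dirs_to_walk = [data_dir]
if os.path.isdir(conversational_dir):
    dirs_to_walk.append(conversational_dir)
print(dirs_to_walk)  # ['data/books', 'data/conversational'] once data_prep.py has run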

217
data_prep.py Normal file
View File

@@ -0,0 +1,217 @@
#!/usr/bin/env python3
"""
data_prep.py
1) Attempts to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
- If ConvoKit/Unsloth fails, falls back to manual ZIP download/extraction.
2) Attempts to download PersonaChat via Hugging Face Datasets:
- First tries "persona_chat" (older key).
- If that fails, tries "conv_ai_2" (alias).
- Catches any exception to skip gracefully.
3) Writes each utterance to:
data/conversational/cornell_movie_dialogs.txt
data/conversational/persona_chat.txt
After running, you'll have:
data/
├── books/ (your original Gutenberg .txt files)
└── conversational/
├── cornell_movie_dialogs.txt
└── persona_chat.txt
Then retrain or fine-tune Nora on data/books/ + data/conversational/.
"""
import os
import sys
import zipfile
import tempfile
import urllib.request
from pathlib import Path
# === 1) Attempt to import ConvoKit for Cornell Movie-Dialogs ===
USE_CONVOKIT = True
try:
from convokit import Corpus, download as convokit_download
except ImportError:
USE_CONVOKIT = False
# === 2) Attempt to import Hugging Face Datasets ===
HAS_DATASETS = True
try:
from datasets import load_dataset
except ImportError:
HAS_DATASETS = False
# Directory for conversational data
CONV_DIR = Path("data/conversational")
CONV_DIR.mkdir(parents=True, exist_ok=True)
# Official ZIP URL (fallback) for Cornell Movie-Dialogs
CORNELL_ZIP_URL = "https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"
def install_package(pkg_name: str):
"""
Installs a Python package using the same Python interpreter,
wrapping the path in quotes to handle spaces.
"""
python_executable = sys.executable
command = f"\"{python_executable}\" -m pip install {pkg_name}"
print(f"[data_prep] Installing package: {pkg_name}")
os.system(command)
def prepare_cornell_via_convokit(output_path: str) -> bool:
"""
Try to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
Returns True if successful, False otherwise.
"""
if not USE_CONVOKIT:
print("[data_prep] ConvoKit not installed; skipping ConvoKit path.")
return False
print("[data_prep] Attempting to download Cornell Movie-Dialogs via ConvoKit...")
try:
corpus = Corpus(filename=convokit_download("movie-corpus"))
with open(output_path, "w", encoding="utf-8") as fout:
for utt in corpus.iter_utterances():
text = utt.text.strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (via ConvoKit).")
return True
except NotImplementedError as nie:
# Typically due to Unsloth error if GPU unsupported
print("[data_prep] ConvoKit raised NotImplementedError (Unsloth/GPU issue).")
print(f"[data_prep] Error: {nie}")
return False
except Exception as e:
print("[data_prep] ConvoKit path failed with exception:", file=sys.stderr)
print(e, file=sys.stderr)
return False
def prepare_cornell_manual(output_path: str):
"""
Fallback: Download Cornell ZIP manually, extract movie_lines.txt,
and write all utterances to output_path.
"""
print("[data_prep] Falling back to manual download of Cornell Movie-Dialogs...")
with tempfile.TemporaryDirectory() as tmpdir:
zip_path = os.path.join(tmpdir, "cornell.zip")
try:
print(f"[data_prep] Downloading from {CORNELL_ZIP_URL} ...")
urllib.request.urlretrieve(CORNELL_ZIP_URL, zip_path)
except Exception as e:
print(f"[data_prep] Error downloading Cornell corpus: {e}", file=sys.stderr)
return
try:
with zipfile.ZipFile(zip_path, "r") as z:
member_name = None
for name in z.namelist():
if name.endswith("movie_lines.txt"):
member_name = name
break
if member_name is None:
print("[data_prep] movie_lines.txt not found in ZIP.", file=sys.stderr)
return
z.extract(member_name, path=tmpdir)
extracted_path = os.path.join(tmpdir, member_name)
except Exception as e:
print(f"[data_prep] Error extracting ZIP: {e}", file=sys.stderr)
return
try:
with open(extracted_path, "r", encoding="iso-8859-1", errors="ignore") as fin, open(
output_path, "w", encoding="utf-8"
) as fout:
for line in fin:
parts = line.split(" +++$+++ ")
if len(parts) == 5:
text = parts[-1].strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
except Exception as e:
print(f"[data_prep] Error parsing movie_lines.txt: {e}", file=sys.stderr)
return
print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (manual).")
def prepare_personachat(output_path: str):
"""
Attempt to download PersonaChat via Hugging Face Datasets.
Tries "persona_chat" and then "conv_ai_2". Catches any exception.
"""
if not HAS_DATASETS:
install_package("datasets")
global load_dataset
from datasets import load_dataset
# Now we have it
for dataset_key in ["persona_chat", "conv_ai_2"]:
try:
print(f"[data_prep] Attempting to load '{dataset_key}' via Hugging Face Datasets...")
if dataset_key == "conv_ai_2":
dataset = load_dataset(dataset_key, trust_remote_code=True)
else:
dataset = load_dataset(dataset_key)
print(f"[data_prep] Successfully loaded '{dataset_key}'. Writing to {output_path}...")
with open(output_path, "w", encoding="utf-8") as fout:
if dataset_key == "persona_chat":
for split in ["train", "valid"]:
for conv in dataset[split]:
for line in conv["dialog"]:
text = line.strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
else: # conv_ai_2
for split in ["train", "valid"]:
for item in dataset[split]:
# conv_ai_2 has a field named "dialog"
if "dialog" in item:
for line in item["dialog"]:
text = line.strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
elif "utterance" in item:
text = item["utterance"].strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
print(f"[data_prep] Finished writing PersonaChat ({dataset_key}) to {output_path}.")
return
except Exception as e:
print(f"[data_prep] Failed '{dataset_key}': {e}", file=sys.stderr)
# Try next key
print("[data_prep] Could not load PersonaChat under any key. Skipping PersonaChat.", file=sys.stderr)
def main():
cornell_path = CONV_DIR / "cornell_movie_dialogs.txt"
persona_path = CONV_DIR / "persona_chat.txt"
# 1) Prepare Cornell Movie-Dialogs: try ConvoKit, then manual
if not cornell_path.is_file():
ok = prepare_cornell_via_convokit(str(cornell_path))
if not ok:
prepare_cornell_manual(str(cornell_path))
else:
print(f"[data_prep] Skipping Cornell: '{cornell_path}' already exists.")
# 2) Prepare PersonaChat
if not persona_path.is_file():
prepare_personachat(str(persona_path))
else:
print(f"[data_prep] Skipping PersonaChat: '{persona_path}' already exists.")
print("[data_prep] All done. You can now include data/conversational/ in Nora's training.")
if __name__ == "__main__":
main()
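A usage note: main() is idempotent, skipping any output file that already exists. The manual fallback can also be driven directly; a minimal sketch:

from data_prep import prepare_cornell_manual, CONV_DIR

# Bypass ConvoKit entirely and pull the official ZIP mirror instead.
prepare_cornell_manual(str(CONV_DIR / "cornell_movie_dialogs.txt"))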

107
knowledge_retriever.py Normal file
View File

@@ -0,0 +1,107 @@
# knowledge_retriever.py
import os
import re
import requests
from bs4 import BeautifulSoup
import logging
# Where to dump new ".txt" files scraped from the web
SCRAPE_DIR = "data/books/scraped"
# Simple rate limiter to avoid hammering any one domain
import time
_last_request_time = {}
def fetch_url(url: str, min_interval: float = 1.0) -> str:
"""
Fetch a URL's HTML, enforcing at least `min_interval` seconds between requests
to the same domain. Returns HTML string or empty string on failure.
"""
from urllib.parse import urlparse
domain = urlparse(url).netloc
now = time.time()
last = _last_request_time.get(domain, 0)
wait = min_interval - (now - last)
if wait > 0:
time.sleep(wait)
headers = {"User-Agent": "NoraScraper/1.0 (+https://your_project_url)"}
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
_last_request_time[domain] = time.time()
return response.text
except Exception as e:
logging.error(f"Error fetching {url}: {e}")
return ""
def clean_html(html: str) -> str:
"""
Strip scripts, styles, and tags; return plain text.
"""
soup = BeautifulSoup(html, "html.parser")
# remove scripts and styles
for tag in soup(["script", "style", "noscript"]):
tag.decompose()
text = soup.get_text(separator="\n")
# collapse multiple blank lines
lines = [line.strip() for line in text.splitlines()]
text = "\n".join([line for line in lines if line])
return text
def save_text(content: str, title: str):
"""
Save content to a UTF-8 .txt file under SCRAPE_DIR. Filename is derived from title.
"""
os.makedirs(SCRAPE_DIR, exist_ok=True)
# sanitize title → filename
safe = re.sub(r"[^0-9a-zA-Z_\-]", "_", title)
fname = f"{safe[:50]}.txt"
path = os.path.join(SCRAPE_DIR, fname)
with open(path, "w", encoding="utf-8") as f:
f.write(content)
logging.info(f"Saved scraped page to {path}")
def scrape_and_store(url: str):
"""
High-level function: fetches URL, cleans HTML, extracts a title, and saves to a .txt.
"""
html = fetch_url(url)
if not html:
return False
text = clean_html(html)
# extract <title> if present
title = ""
m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
if m:
title = m.group(1).strip()
else:
title = url
save_text(text, title)
return True
# Example usage:
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
print("Usage: python knowledge_retriever.py <url1> [<url2> ...]")
sys.exit(1)
for link in sys.argv[1:]:
success = scrape_and_store(link)
if success:
print(f"Scraped: {link}")
else:
print(f"Failed to scrape: {link}")

359
main.py
View File

@@ -1,89 +1,310 @@
# main.py
import os
import asyncio
import logging
import subprocess
import re
from urllib.parse import urlparse

import discord
import torch
from discord import Intents

from config import get_config
from tokenizer import CharTokenizer
from model import NoraTransformerLM, top_k_logits
from data_loader import get_dataloader
from utils import setup_logging, load_checkpoint, save_checkpoint
from knowledge_retriever import fetch_url, clean_html, save_text
import persona_manager  # <-- PERSONA CHANGE

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Keep the SCRAPE regex as before
SCRAPE_PATTERN = re.compile(r"<<SCRAPE:(https?://[^\s>]+)>>", re.IGNORECASE)

# ---------------------------------------------------
# 1) Build or Reload Nora's Model & Tokenizer
# ---------------------------------------------------
def build_nora(config, device):
    """
    - Loads tokenizer
    - Instantiates NoraTransformerLM
    - Loads latest checkpoint (if any)
    """
    tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=config.d_model,
        nhead=config.nhead,
        num_layers=config.num_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
        max_seq_len=config.seq_length,
    )
    # Find latest checkpoint
    latest_ckpt = None
    if os.path.isdir(config.checkpoint_dir):
        ckpts = [
            f
            for f in os.listdir(config.checkpoint_dir)
            if f.startswith("nora_step_") and f.endswith(".pt")
        ]
        if ckpts:
            latest_ckpt = sorted(
                ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
    if latest_ckpt:
        ckpt_path = os.path.join(config.checkpoint_dir, latest_ckpt)
        state = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
        logger.info(f"Loaded Nora from checkpoint {latest_ckpt}")
    else:
        logger.warning("No checkpoints found. Nora starts untrained.")
    model.to(device)
    model.eval()
    return model, tokenizer
# ---------------------------------------------------
# 2) Autoregressive Sampling Function
# ---------------------------------------------------
def generate_text(
    model: NoraTransformerLM,
    tokenizer: CharTokenizer,
    device: str,
    prompt: str,
    max_length: int = 128,
    temperature: float = 1.0,
    top_k: int = 50,
) -> str:
    """
    Wrapper around model.generate. Returns raw generated string.
    """
    return model.generate(
        tokenizer=tokenizer,
        device=device,
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
    )
# ---------------------------------------------------
# 3) SCRAPE Directive Handler (same as before)
# ---------------------------------------------------
async def handle_scrape_directives(text: str, data_dir: str) -> bool:
    urls = set(m.group(1) for m in SCRAPE_PATTERN.finditer(text))
    if not urls:
        return False
    for url in urls:
        logger.info(f"Directive: Scraping {url}")
        html = fetch_url(url)
        if not html:
            logger.error(f"Failed to fetch {url}")
            continue
        plain = clean_html(html)
        title = ""
        m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
        if m:
            title = m.group(1).strip()[:50]
        else:
            title = urlparse(url).netloc
        save_text(plain, title)
    return True
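# Example of a directive Nora might embed in a reply (the URL is illustrative):
#   "Sure! <<SCRAPE:https://www.gutenberg.org/files/11/11-0.txt>> Give me a moment."
# handle_scrape_directives() fetches each match, strips the HTML, and saves a .txt
# under data/books/scraped/ via save_text().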
# ---------------------------------------------------
# 4) Discord Client (“Nora” class) with Persona
# ---------------------------------------------------
class Nora(discord.Client):
    def __init__(self, model, tokenizer, config, device):
        intents = Intents.default()
        intents.messages = True
        intents.message_content = True
        super().__init__(intents=intents)
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = device
        # history per channel (last 5 user/nora pairs)
        self.history = {}
        # Load Nora's persona from disk (it's guaranteed to exist by startup)
        with open(persona_manager.PERSONA_PATH, "r", encoding="utf-8") as f:
            self.persona_text = f.read()

    async def on_ready(self):
        logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
        logger.info("Nora is online and ready to be herself.")
        # Background task: reload model if a new checkpoint shows up
        self.loop.create_task(self._reload_model_periodically())

    async def _reload_model_periodically(self, interval: int = 600):
        """
        Every `interval` seconds, check for newer checkpoint & reload.
        """
        while True:
            await asyncio.sleep(interval)
            ckpts = [
                f
                for f in os.listdir(self.config.checkpoint_dir)
                if f.startswith("nora_step_") and f.endswith(".pt")
            ]
            if not ckpts:
                continue
            latest = sorted(
                ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            ckpt_path = os.path.join(self.config.checkpoint_dir, latest)
            state = torch.load(ckpt_path, map_location="cpu")
            self.model.load_state_dict(state["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
            logger.info(f"Reloaded Nora's model from {latest}")
    async def on_message(self, message: discord.Message):
        # 1) Ignore self-messages
        if message.author.id == self.user.id:
            return
        content = message.content.strip()
        prompt = None

        # 2) If in DM, treat entire content as prompt
        if isinstance(message.channel, discord.DMChannel):
            prompt = content
            # Also allow "update persona" in DM to regenerate persona file
            if content.lower().startswith("update persona"):
                # Regenerate persona asynchronously
                new_persona = await persona_manager.maybe_update_persona(
                    self.model, self.tokenizer, self.device
                )
                self.persona_text = new_persona
                await message.channel.send(
                    "I have rewritten my persona. Thank you! ❤️"
                )
                return
        # 3) Otherwise (guild), require mention or "nora," prefix
        else:
            if self.user.mention in content:
                prompt = content.replace(self.user.mention, "").strip()
            elif content.lower().startswith("nora,"):
                prompt = content[len("nora,"):].strip()
            else:
                return

        if not prompt:
            return  # e.g. user only said "Nora," with no text after

        # 4) Show typing indicator if in a guild text channel
        if isinstance(message.channel, discord.TextChannel):
            await message.channel.trigger_typing()

        # 5) Build the full prompt: persona + history + user's prompt
        chan_id = str(message.channel.id)
        history = self.history.get(chan_id, [])
        prompt_lines = []
        # 5.1) Insert Nora's persona (so she "speaks as herself")
        prompt_lines.append(self.persona_text)
        prompt_lines.append("")  # blank line
        # 5.2) Insert up to the last 4 exchanges
        for user_msg, nora_msg in history[-4:]:
            prompt_lines.append(f"User: {user_msg}")
            prompt_lines.append(f"Nora: {nora_msg}")
            prompt_lines.append("")
        # 5.3) Finally, insert the new user prompt
        prompt_lines.append(f"User: {prompt}")
        prompt_lines.append("Nora:")
        conversation_prompt = "\n".join(prompt_lines)

        # 6) Generate Nora's reply (tighter sampling: temp=0.8, top_k=20)
        try:
            raw = await asyncio.to_thread(
                self.model.generate,
                self.tokenizer,
                self.device,
                conversation_prompt,
                self.config.seq_length,
                0.8,  # temperature
                20,  # top_k
            )
        except Exception:
            logger.exception("Error in Nora.generate()")
            await message.channel.send("😔 Sorry, I hit an error trying to think.")
            return

        # 7) Extract just Nora's reply text
        if "Nora:" in raw:
            nora_reply = raw.split("Nora:")[-1].strip()
        else:
            nora_reply = raw[len(conversation_prompt):].strip()

        # 8) Handle <<SCRAPE:…>> if present
        did_scrape = await handle_scrape_directives(raw, self.config.data_dir)
        if did_scrape:
            logger.info("Detected scrape directive. Triggering incremental retrain.")
            subprocess.Popen(
                ["python", "pretrain.py", "--resume"],  # note: pretrain.py was your old main.py
                cwd=os.getcwd(),
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            nora_reply += "\n\n*I found some new things and am updating myself…*"

        # 9) Save to history and prune to the last 5
        self.history.setdefault(chan_id, []).append((prompt, nora_reply))
        self.history[chan_id] = self.history[chan_id][-5:]

        # 10) Send Nora's reply
        await message.channel.send(nora_reply)
# ---------------------------------------------------
# 5) Entrypoint
# ---------------------------------------------------
if __name__ == "__main__":
    config = get_config()
    device = config.device

    # 5.1) Build/load model & tokenizer
    model, tokenizer = build_nora(config, device)

    # 5.2) Ensure a persona exists; if not, generate one now
    persona_manager.ensure_persona_file(model, tokenizer, device)

    # 5.3) After that, we can proceed to start the agent
    discord_token = os.getenv("DISCORD_TOKEN")
    if not discord_token:
        logger.error("Please set DISCORD_TOKEN in your environment.")
        exit(1)

    # Enable CuDNN autotune if on CUDA
    if device.startswith("cuda"):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

    bot = Nora(model=model, tokenizer=tokenizer, config=config, device=device)
    bot.run(discord_token)
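For clarity, the conversation prompt assembled in step 5 of on_message has this shape for a one-exchange history (illustrative text; the first block is the contents of nora_persona.txt):

<persona text>

User: hi nora
Nora: hey yourself!

User: what are you reading?
Nora:

Whatever the model emits after the final "Nora:" is taken as the reply in step 7.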

model.py
View File

@@ -7,9 +7,22 @@ No pretrained weights—everything is initialized randomly.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math


+def top_k_logits(logits: torch.Tensor, k: int):
+    """
+    Zero out all but the top k logits in each row; return modified logits.
+    logits: (vocab_size,)
+    """
+    if k == 0:
+        return logits
+    topk_vals, _ = torch.topk(logits, k)
+    min_topk = topk_vals[-1]
+    return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)
+
+
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model: int, max_len: int = 10_000):
         super().__init__()
@@ -98,3 +111,41 @@ class NoraTransformerLM(nn.Module):
         x = x.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
         logits = self.fc_out(x)  # (batch_size, seq_length, vocab_size)
         return logits
+
+    def generate(
+        self,
+        tokenizer,
+        device: str,
+        prompt: str,
+        max_length: int = 128,
+        temperature: float = 1.0,
+        top_k: int = 50,
+    ) -> str:
+        """
+        Autoregressively generate text from a prompt.
+        - tokenizer: CharTokenizer (for encode/decode)
+        - device: "cuda" or "cpu"
+        - prompt: initial string
+        - max_length: total tokens to generate (including prompt)
+        - temperature: scales logits before softmax
+        - top_k: keep only top_k logits at each step
+        """
+        self.eval()
+        input_ids = tokenizer.encode(prompt)
+        input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
+        generated = input_ids.clone()  # shape (1, seq_len)
+        for _ in range(max_length - input_ids.size(1)):
+            # 1) trim to last seq_length tokens if longer than context window
+            if generated.size(1) > self.pos_encoder.pe.size(1):
+                generated = generated[:, -self.pos_encoder.pe.size(1):]
+            with torch.no_grad():
+                logits = self.forward(generated)  # (1, seq_len, vocab_size)
+            next_token_logits = logits[0, -1, :] / temperature
+            filtered_logits = top_k_logits(next_token_logits, k=top_k)
+            probs = F.softmax(filtered_logits, dim=-1)
+            next_id = torch.multinomial(probs, num_samples=1)  # (1,)
+            generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)
+        output_ids = generated.squeeze(0).tolist()
+        return tokenizer.decode(output_ids)
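A tiny worked example of top_k_logits, which generate() relies on:

import torch
from model import top_k_logits

logits = torch.tensor([2.0, 0.5, 1.0, -1.0])
print(top_k_logits(logits, k=2))
# tensor([ 2.0000e+00, -1.0000e+10,  1.0000e+00, -1.0000e+10])
# Only the two largest logits survive; after softmax the masked entries carry
# essentially zero probability, so sampling stays inside the top-k set.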

2
nora_persona.txt Normal file
View File

@@ -0,0 +1,2 @@
if you conduct, ens you heel to anyone. As for we know the bold child
like me!”

88
persona_manager.py Normal file
View File

@@ -0,0 +1,88 @@
# persona_manager.py
import os
import torch
import asyncio
import re
from tokenizer import CharTokenizer
from model import NoraTransformerLM
PERSONA_PATH = "nora_persona.txt"
# 1) A meta-prompt that explicitly tells Nora to invent a persona and avoid quoting:
PERSONA_META_PROMPT = (
"Below, Nora will create a brandnew identity for herself. "
"She must NOT quote from any books or passages she has read. "
"Instead, she should invent her own style, voice, quirks, and personality traits as if she were a completely new person. "
"Her persona should be flirty, playful, curious, and speak in full sentences. "
"Write at least three paragraphs in your own words.\n\n"
"Nora, please invent and write your complete persona now:\n\nNora:"
)
async def generate_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str) -> str:
"""
Ask Nora to write out her own, original persona, avoiding any verbatim quotes.
Returns the raw generated text.
"""
# We'll ask for up to 512 tokens, with a higher temperature and no top-k cap.
# That combination tends to produce more creative, less-memorizable text.
raw = await asyncio.to_thread(
model.generate,
tokenizer,
device,
PERSONA_META_PROMPT,
512, # allow several paragraphs
1.2, # higher temperature for more creativity
0 # top_k=0 means no top-k filtering; sampling draws from the full distribution
)
# At this point, "raw" may include the word "Nora:" etc. Strip everything before "Nora:"
if "Nora:" in raw:
persona_text = raw.split("Nora:")[-1].strip()
else:
persona_text = raw.strip()
# Now apply a simple post-filter: remove any long spans that match exact sequences in the book corpus.
# This is optional but helps ensure she didn't copy large chunks verbatim. We check for 6+ character substrings
# appearing more than once in her output.
def remove_long_quotes(text: str) -> str:
filtered = text
# find any substring of length ≥6 that appears twice; we'll just guess she's quoting if it's repeated.
for match in re.finditer(r"\b[\w',]{6,}\b", text):
substr = match.group(0)
if filtered.count(substr) > 1:
filtered = filtered.replace(substr, "[…]")
return filtered
persona_text = remove_long_quotes(persona_text)
return persona_text
def ensure_persona_file(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
"""
If nora_persona.txt does not exist, generate one (ensuring originality).
"""
if os.path.isfile(PERSONA_PATH):
return
print("[persona] No persona found. Generating a new, original persona…")
persona_text = asyncio.run(generate_persona(model, tokenizer, device))
# Save to disk
with open(PERSONA_PATH, "w", encoding="utf-8") as f:
f.write(persona_text)
print(f"[persona] Wrote new persona to {PERSONA_PATH}.")
async def maybe_update_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
"""
Regenerate Noras persona if she requests it, overwriting the file.
"""
print("[persona] Updating persona at Nora's request…")
persona_text = await generate_persona(model, tokenizer, device)
with open(PERSONA_PATH, "w", encoding="utf-8") as f:
f.write(persona_text)
print(f"[persona] Updated persona in {PERSONA_PATH}.")
return persona_text
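A worked example of the inner remove_long_quotes helper (treated as standalone here for illustration):

text = "curious minds stay curious forever"
# "curious" is 6+ characters and occurs twice, so every occurrence is masked;
# "forever" also matches the pattern but occurs only once, so it survives.
print(remove_long_quotes(text))  # […] minds stay […] forever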

89
pretrain.py Normal file
View File

@@ -0,0 +1,89 @@
"""
pretrain.py
Orchestrates the entire Nora project:
- Parses arguments
- Builds or loads tokenizer
- Constructs dataset & dataloader
- Instantiates the model
- Sets up optimizer, scheduler
- Calls train()
"""
import os
import torch
import logging
from config import get_config
from tokenizer import CharTokenizer
from data_loader import get_dataloader
from model import NoraTransformerLM
from train import train
from utils import setup_logging, load_checkpoint, save_checkpoint
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
def pretrain():
args = get_config()
# 1) Logging setup
log_file = os.path.join(args.checkpoint_dir, "train.log")
setup_logging(log_file)
logging.info(f"[pretrain] Using device: {args.device}")
logging.info(f"[pretrain] Config: {args}")
# 2) Tokenizer: if vocab exists, load; else build from data_dir
tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)
# 3) DataLoader
dataloader = get_dataloader(
data_dir=args.data_dir,
tokenizer=tokenizer,
seq_length=args.seq_length,
batch_size=args.batch_size,
shuffle=True,
)
# 4) Model instantiation
model = NoraTransformerLM(
vocab_size=tokenizer.vocab_size(),
d_model=args.d_model,
nhead=args.nhead,
num_layers=args.num_layers,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
max_seq_len=args.seq_length,
)
# 5) Optimizer & scheduler (linear warmup + decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
def lr_lambda(current_step):
# Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
if current_step < args.warmup_steps:
return float(current_step) / float(max(1, args.warmup_steps))
return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# 6) Check for existing checkpoint to resume
start_step = 0
ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else []
ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
if ckpts:
latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
start_step = load_checkpoint(latest_ckpt, model, optimizer)
# 7) Begin training
try:
train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
except Exception as e:
logging.exception("[main] Exception during training")
raise e
if __name__ == "__main__":
pretrain()
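The warmup/decay schedule in step 5 is easy to sanity-check; a sketch assuming warmup_steps=4000 (the real value comes from get_config()):

warmup_steps = 4000  # assumed for illustration

def lr_lambda(current_step):
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return (warmup_steps ** 0.5) * float(current_step ** -0.5)

print(lr_lambda(1000))   # 0.25 -> linear ramp toward the peak
print(lr_lambda(4000))   # 1.0  -> peak multiplier at the end of warmup
print(lr_lambda(16000))  # 0.5  -> 1/sqrt(step) decay: sqrt(4000/16000)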

23
requirements.txt Normal file
View File

@@ -0,0 +1,23 @@
# Core ML framework
# Data loading / utilities
tqdm>=4.64.0
numpy>=1.22.0
# HuggingFace Datasets (PersonaChat, etc.)
datasets>=2.0.0
# ConvoKit (for Cornell Movie-Dialogs + Unsloth)
convokit>=2.0.0
# Discord bot
discord.py>=2.0.0
# Web scraping
beautifulsoup4>=4.11.0
requests>=2.28.0
# If ConvoKit pulls in TensorFlow via Unsloth
tensorflow>=2.10.0

175
self_improve.py Normal file
View File

@@ -0,0 +1,175 @@
# self_improve.py
import os
import subprocess
import tempfile
import shutil
import logging
import torch
from tokenizer import CharTokenizer
from model import NoraTransformerLM
from config import get_config
# ------------------------------------------------------
# 1) "Teacher": Pose a code-generation prompt to Nora
# ------------------------------------------------------
def propose_patch(model, tokenizer, device, prompt: str) -> str:
"""
Ask Nora to generate a code snippet given `prompt`.
e.g. prompt = "### FILE: knowledge_retriever.py\n# Add a new function clean_html(...) that..."
Returns the raw text (possibly including the prompt).
"""
raw = model.generate(
tokenizer=tokenizer,
device=device,
prompt=prompt,
max_length=512,
temperature=0.7,
top_k=50,
)
return raw
# ------------------------------------------------------
# 2) “Verifier” agent: sandbox + test
# ------------------------------------------------------
class CodeVerifier:
"""
Given a proposed code patch (as text), this class:
1. Writes it to a temporary file (or repo clone)
2. Runs Python's syntax check (compile) and unit tests
3. Measures performance changes (e.g. run a small validation set through the model)
4. Returns True/False + log messages
"""
def __init__(self, repo_dir: str, test_command: str):
"""
- repo_dir: path to your Nora project root
- test_command: a shell command string to run unit tests, e.g. "pytest tests/"
"""
self.repo_dir = repo_dir
self.test_command = test_command
def verify_patch(self, rel_path: str, patch_code: str) -> bool:
"""
- rel_path: relative path inside repo where the patch should go, e.g. "knowledge_retriever.py"
- patch_code: entire contents of that file (not a diff).
Returns True if syntax + tests pass; False otherwise.
"""
# 1) Copy repo => temp dir
tmpdir = tempfile.mkdtemp(prefix="nora_verify_")
try:
shutil.copytree(self.repo_dir, os.path.join(tmpdir, "repo"), dirs_exist_ok=True)
target_file = os.path.join(tmpdir, "repo", rel_path)
# 2) Write patch_code to target_file
with open(target_file, "w", encoding="utf-8") as f:
f.write(patch_code)
# 3) Syntax check (try compiling)
try:
compile(patch_code, target_file, "exec")
except SyntaxError as se:
logging.error(f"Syntax error in patch: {se}")
return False
# 4) Run unit tests
result = subprocess.run(
self.test_command,
shell=True,
cwd=os.path.join(tmpdir, "repo"),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
if result.returncode != 0:
logging.error(f"Unit tests failed:\n{result.stdout}")
return False
# 5) (Optional) Performance check
# You could load the updated model and measure perplexity on a tiny validation set here.
# For now, we assume passing tests = “improvement.”
return True
finally:
shutil.rmtree(tmpdir)
def merge_patch(self, rel_path: str, patch_code: str) -> None:
"""
Overwrite the real file in `repo_dir/rel_path` with patch_code,
then git-add and git-commit (you can also automate a PR).
"""
target_file = os.path.join(self.repo_dir, rel_path)
with open(target_file, "w", encoding="utf-8") as f:
f.write(patch_code)
# Example: git add + commit
subprocess.run(f"git add {rel_path}", shell=True, cwd=self.repo_dir)
subprocess.run(
f'git commit -m "Auto-update {rel_path} via Nora self-improve."',
shell=True,
cwd=self.repo_dir,
)
# ------------------------------------------------------
# 3) Main loop: ask → verify → merge (if good) → retrain
# ------------------------------------------------------
def self_improvement_cycle(repo_dir: str, device: str):
"""
Example cycle:
1) Nora proposes a new helper in knowledge_retriever.py
2) Verifier checks syntax + tests
3) If ok, merge and trigger incremental retraining
"""
config = get_config()
tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
model = NoraTransformerLM(
vocab_size=tokenizer.vocab_size(),
d_model=config.d_model,
nhead=config.nhead,
num_layers=config.num_layers,
dim_feedforward=config.dim_feedforward,
dropout=config.dropout,
max_seq_len=config.seq_length,
)
# Load latest checkpoint
ckpts = []
if os.path.isdir(config.checkpoint_dir):
ckpts = [
f
for f in os.listdir(config.checkpoint_dir)
if f.startswith("nora_step_") and f.endswith(".pt")
]
if ckpts:
latest = sorted(ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
state = torch.load(os.path.join(config.checkpoint_dir, latest), map_location="cpu")
model.load_state_dict(state["model_state_dict"])
model.to(device)
model.eval()
verifier = CodeVerifier(repo_dir=repo_dir, test_command="pytest tests/")
# Example prompt: ask Nora to extend knowledge_retriever.py
prompt = (
"### FILE: knowledge_retriever.py\n"
"# Add a function clean_html(html: str) -> str that strips tags and scripts.\n"
"# Use BeautifulSoup if available. Return plain text.\n\n"
"### START\n"
"def clean_html(html: str) -> str:\n"
)
raw_patch = propose_patch(model, tokenizer, device, prompt)
# Extract everything from "def clean_html" to end of function (simple heuristic)
# In practice, you'd parse until the next "\n\n" or rely on indentation.
patch_code = raw_patch # for now, assume raw_patch is the full file contents
# Verify
if verifier.verify_patch("knowledge_retriever.py", patch_code):
logging.info("Patch verified. Merging into live code.")
verifier.merge_patch("knowledge_retriever.py", patch_code)
# Optionally: trigger incremental retraining here (e.g. call train.py with --resume)
else:
logging.warning("Patch failed verification. Discarding.")

train.py
View File

@@ -37,6 +37,13 @@ def train(
     device = config.device
     model.to(device)

+    # ─── ensure optimizer state is on the same device ───
+    # (this moves any loaded CPU buffers for Adam/AdamW into CUDA)
+    for state in optimizer.state.values():
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                state[k] = v.to(device)
+
     criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])
     scaler = GradScaler()
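Context for this hunk (an assumption about how load_checkpoint restores state): when a checkpoint is read with torch.load(..., map_location="cpu"), Adam's exp_avg / exp_avg_sq buffers can come back as CPU tensors while the parameters already live on CUDA, and the first optimizer.step() then fails with a device-mismatch RuntimeError. The added loop walks every optimizer state tensor and moves it onto config.device before training resumes.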