Added another learning source for Nora. Also added the requirements.

This commit is contained in:
2025-06-09 14:25:11 -04:00
parent da23742671
commit 5d53ba7cb8
14 changed files with 1070 additions and 78 deletions

2
.gitignore vendored
View File

@ -196,3 +196,5 @@ cython_debug/
checkpoints/nora_step_*.pt
data/books
*.json
data/conversational/cornell_movie_dialogs.txt
.gitignore

View File

@ -67,7 +67,7 @@ def get_config():
parser.add_argument(
"--batch_size",
type=int,
default=32,
default=512,
help="Batch size per training step.",
)
parser.add_argument(

View File

View File

@ -21,15 +21,25 @@ class TextDataset(Dataset):
self.seq_length = seq_length
self.tokenizer = tokenizer
# Read and concatenate all text files into one long string
# Read and concatenate all .txt files under two folders:
# - data/books/
# - data/conversational/
texts = []
for root, _, files in os.walk(data_dir):
for fname in files:
if not fname.lower().endswith(".txt"):
continue
path = os.path.join(root, fname)
with open(path, "r", encoding="utf-8", errors="ignore") as f:
texts.append(f.read())
# If data_dir is a single path, we still look for a sibling "conversational" folder
conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational")
# Gather all folders to walk
dirs_to_walk = [data_dir]
if os.path.isdir(conversational_dir):
dirs_to_walk.append(conversational_dir)
for d in dirs_to_walk:
for root, _, files in os.walk(d):
for fname in files:
if not fname.lower().endswith(".txt"):
continue
path = os.path.join(root, fname)
with open(path, "r", encoding="utf-8", errors="ignore") as f:
texts.append(f.read())
full_text = "\n".join(texts)
token_ids = self.tokenizer.encode(full_text)

217
data_prep.py Normal file
View File

@ -0,0 +1,217 @@
#!/usr/bin/env python3
"""
data_prep.py
1) Attempts to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
- If ConvoKit/Unsloth fails, falls back to manual ZIP download/extraction.
2) Attempts to download PersonaChat via Hugging Face Datasets:
- First tries "persona_chat" (older key).
- If that fails, tries "conv_ai_2" (alias).
- Catches any exception to skip gracefully.
3) Writes each utterance to:
data/conversational/cornell_movie_dialogs.txt
data/conversational/persona_chat.txt
After running, youll have:
data/
├── books/ (your original Gutenberg .txt files)
└── conversational/
├── cornell_movie_dialogs.txt
└── persona_chat.txt
Then retrain or fine-tune Nora on data/books/ + data/conversational/.
"""
import os
import sys
import zipfile
import tempfile
import urllib.request
from pathlib import Path
# === 1) Attempt to import ConvoKit for Cornell Movie-Dialogs ===
USE_CONVOKIT = True
try:
from convokit import Corpus, download as convokit_download
except ImportError:
USE_CONVOKIT = False
# === 2) Attempt to import Hugging Face Datasets ===
HAS_DATASETS = True
try:
from datasets import load_dataset
except ImportError:
HAS_DATASETS = False
# Directory for conversational data
CONV_DIR = Path("data/conversational")
CONV_DIR.mkdir(parents=True, exist_ok=True)
# Official ZIP URL (fallback) for Cornell Movie-Dialogs
CORNELL_ZIP_URL = "https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"
def install_package(pkg_name: str):
"""
Installs a Python package using the same Python interpreter,
wrapping the path in quotes to handle spaces.
"""
python_executable = sys.executable
command = f"\"{python_executable}\" -m pip install {pkg_name}"
print(f"[data_prep] Installing package: {pkg_name}")
os.system(command)
def prepare_cornell_via_convokit(output_path: str) -> bool:
"""
Try to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
Returns True if successful, False otherwise.
"""
if not USE_CONVOKIT:
print("[data_prep] ConvoKit not installed; skipping ConvoKit path.")
return False
print("[data_prep] Attempting to download Cornell Movie-Dialogs via ConvoKit...")
try:
corpus = Corpus(filename=convokit_download("movie-corpus"))
with open(output_path, "w", encoding="utf-8") as fout:
for utt in corpus.iter_utterances():
text = utt.text.strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (via ConvoKit).")
return True
except NotImplementedError as nie:
# Typically due to Unsloth error if GPU unsupported
print("[data_prep] ConvoKit raised NotImplementedError (Unsloth/GPU issue).")
print(f"[data_prep] Error: {nie}")
return False
except Exception as e:
print("[data_prep] ConvoKit path failed with exception:", file=sys.stderr)
print(e, file=sys.stderr)
return False
def prepare_cornell_manual(output_path: str):
"""
Fallback: Download Cornell ZIP manually, extract movie_lines.txt,
and write all utterances to output_path.
"""
print("[data_prep] Falling back to manual download of Cornell Movie-Dialogs...")
with tempfile.TemporaryDirectory() as tmpdir:
zip_path = os.path.join(tmpdir, "cornell.zip")
try:
print(f"[data_prep] Downloading from {CORNELL_ZIP_URL} ...")
urllib.request.urlretrieve(CORNELL_ZIP_URL, zip_path)
except Exception as e:
print(f"[data_prep] Error downloading Cornell corpus: {e}", file=sys.stderr)
return
try:
with zipfile.ZipFile(zip_path, "r") as z:
member_name = None
for name in z.namelist():
if name.endswith("movie_lines.txt"):
member_name = name
break
if member_name is None:
print("[data_prep] movie_lines.txt not found in ZIP.", file=sys.stderr)
return
z.extract(member_name, path=tmpdir)
extracted_path = os.path.join(tmpdir, member_name)
except Exception as e:
print(f"[data_prep] Error extracting ZIP: {e}", file=sys.stderr)
return
try:
with open(extracted_path, "r", encoding="iso-8859-1", errors="ignore") as fin, open(
output_path, "w", encoding="utf-8"
) as fout:
for line in fin:
parts = line.split(" +++$+++ ")
if len(parts) == 5:
text = parts[-1].strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
except Exception as e:
print(f"[data_prep] Error parsing movie_lines.txt: {e}", file=sys.stderr)
return
print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (manual).")
def prepare_personachat(output_path: str):
"""
Attempt to download PersonaChat via Hugging Face Datasets.
Tries "persona_chat" and then "conv_ai_2". Catches any exception.
"""
if not HAS_DATASETS:
install_package("datasets")
global load_dataset
from datasets import load_dataset
# Now we have it
for dataset_key in ["persona_chat", "conv_ai_2"]:
try:
print(f"[data_prep] Attempting to load '{dataset_key}' via Hugging Face Datasets...")
if dataset_key == "conv_ai_2":
dataset = load_dataset(dataset_key, trust_remote_code=True)
else:
dataset = load_dataset(dataset_key)
print(f"[data_prep] Successfully loaded '{dataset_key}'. Writing to {output_path}...")
with open(output_path, "w", encoding="utf-8") as fout:
if dataset_key == "persona_chat":
for split in ["train", "valid"]:
for conv in dataset[split]:
for line in conv["dialog"]:
text = line.strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
else: # conv_ai_2
for split in ["train", "valid"]:
for item in dataset[split]:
# conv_ai_2 has a field named "dialog"
if "dialog" in item:
for line in item["dialog"]:
text = line.strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
elif "utterance" in item:
text = item["utterance"].strip()
if text:
fout.write(text.replace("\n", " ") + "\n")
print(f"[data_prep] Finished writing PersonaChat ({dataset_key}) to {output_path}.")
return
except Exception as e:
print(f"[data_prep] Failed '{dataset_key}': {e}", file=sys.stderr)
# Try next key
print("[data_prep] Could not load PersonaChat under any key. Skipping PersonaChat.", file=sys.stderr)
def main():
cornell_path = CONV_DIR / "cornell_movie_dialogs.txt"
persona_path = CONV_DIR / "persona_chat.txt"
# 1) Prepare Cornell Movie-Dialogs: try ConvoKit, then manual
if not cornell_path.is_file():
ok = prepare_cornell_via_convokit(str(cornell_path))
if not ok:
prepare_cornell_manual(str(cornell_path))
else:
print(f"[data_prep] Skipping Cornell: '{cornell_path}' already exists.")
# 2) Prepare PersonaChat
if not persona_path.is_file():
prepare_personachat(str(persona_path))
else:
print(f"[data_prep] Skipping PersonaChat: '{persona_path}' already exists.")
print("[data_prep] All done. You can now include data/conversational/ in Nora's training.")
if __name__ == "__main__":
main()

107
knowledge_retriever.py Normal file
View File

@ -0,0 +1,107 @@
# knowledge_retriever.py
import os
import re
import requests
from bs4 import BeautifulSoup
import logging
# Where to dump new “.txt” files scraped from the web
SCRAPE_DIR = "data/books/scraped"
# Simple ratelimiter to avoid hammering any one domain
import time
_last_request_time = {}
def fetch_url(url: str, min_interval: float = 1.0) -> str:
"""
Fetch a URLs HTML, enforcing at least `min_interval` seconds between requests
to the same domain. Returns HTML string or empty string on failure.
"""
from urllib.parse import urlparse
domain = urlparse(url).netloc
now = time.time()
last = _last_request_time.get(domain, 0)
wait = min_interval - (now - last)
if wait > 0:
time.sleep(wait)
headers = {"User-Agent": "NoraScraper/1.0 (+https://your_project_url)"}
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
_last_request_time[domain] = time.time()
return response.text
except Exception as e:
logging.error(f"Error fetching {url}: {e}")
return ""
def clean_html(html: str) -> str:
"""
Strip scripts, styles, and tags; return plain text.
"""
soup = BeautifulSoup(html, "html.parser")
# remove scripts and styles
for tag in soup(["script", "style", "noscript"]):
tag.decompose()
text = soup.get_text(separator="\n")
# collapse multiple blank lines
lines = [line.strip() for line in text.splitlines()]
text = "\n".join([line for line in lines if line])
return text
def save_text(content: str, title: str):
"""
Save content to a UTF-8 .txt file under SCRAPE_DIR. Filename is derived from title.
"""
os.makedirs(SCRAPE_DIR, exist_ok=True)
# sanitize title → filename
safe = re.sub(r"[^0-9a-zA-Z_\-]", "_", title)
fname = f"{safe[:50]}.txt"
path = os.path.join(SCRAPE_DIR, fname)
with open(path, "w", encoding="utf-8") as f:
f.write(content)
logging.info(f"Saved scraped page to {path}")
def scrape_and_store(url: str):
"""
High-level function: fetches URL, cleans HTML, extracts a title, and saves to a .txt.
"""
html = fetch_url(url)
if not html:
return False
text = clean_html(html)
# extract <title> if present
title = ""
m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
if m:
title = m.group(1).strip()
else:
title = url
save_text(text, title)
return True
# Example usage:
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
print("Usage: python knowledge_retriever.py <url1> [<url2> ...]")
sys.exit(1)
for link in sys.argv[1:]:
success = scrape_and_store(link)
if success:
print(f"Scraped: {link}")
else:
print(f"Failed to scrape: {link}")

359
main.py
View File

@ -1,89 +1,310 @@
"""
main.py
Orchestrates the entire Nora project:
- Parses arguments
- Builds or loads tokenizer
- Constructs dataset & dataloader
- Instantiates the model
- Sets up optimizer, scheduler
- Calls train()
"""
# main.py
import os
import torch
import asyncio
import logging
import subprocess
import re
from urllib.parse import urlparse
import discord
import torch
from discord import Intents
from config import get_config
from tokenizer import CharTokenizer
from model import NoraTransformerLM, top_k_logits
from data_loader import get_dataloader
from model import NoraTransformerLM
from train import train
from utils import setup_logging, load_checkpoint, save_checkpoint
from knowledge_retriever import fetch_url, clean_html, save_text
import persona_manager # <— PERSONA CHANGE
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
# Logging setup
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
# Keep the SCRAPE regex as before
SCRAPE_PATTERN = re.compile(r"<<SCRAPE:(https?://[^\s>]+)>>", re.IGNORECASE)
def main():
args = get_config()
# ---------------------------------------------------
# 1) Build or Reload Noras Model & Tokenizer
# ---------------------------------------------------
def build_nora(config, device):
"""
- Loads tokenizer
- Instantiates NoraTransformerLM
- Loads latest checkpoint (if any)
"""
tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
# 1) Logging setup
log_file = os.path.join(args.checkpoint_dir, "train.log")
setup_logging(log_file)
logging.info(f"[main] Using device: {args.device}")
logging.info(f"[main] Config: {args}")
# 2) Tokenizer: if vocab exists, load; else build from data_dir
tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)
# 3) DataLoader
dataloader = get_dataloader(
data_dir=args.data_dir,
tokenizer=tokenizer,
seq_length=args.seq_length,
batch_size=args.batch_size,
shuffle=True,
)
# 4) Model instantiation
model = NoraTransformerLM(
vocab_size=tokenizer.vocab_size(),
d_model=args.d_model,
nhead=args.nhead,
num_layers=args.num_layers,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
max_seq_len=args.seq_length,
d_model=config.d_model,
nhead=config.nhead,
num_layers=config.num_layers,
dim_feedforward=config.dim_feedforward,
dropout=config.dropout,
max_seq_len=config.seq_length,
)
# 5) Optimizer & scheduler (linear warmup + decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
# Find latest checkpoint
latest_ckpt = None
if os.path.isdir(config.checkpoint_dir):
ckpts = [
f
for f in os.listdir(config.checkpoint_dir)
if f.startswith("nora_step_") and f.endswith(".pt")
]
if ckpts:
latest_ckpt = sorted(
ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
)[-1]
def lr_lambda(current_step):
# Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
if current_step < args.warmup_steps:
return float(current_step) / float(max(1, args.warmup_steps))
return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
if latest_ckpt:
ckpt_path = os.path.join(config.checkpoint_dir, latest_ckpt)
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state["model_state_dict"])
logger.info(f"Loaded Nora from checkpoint {latest_ckpt}")
else:
logger.warning("No checkpoints found. Nora starts untrained.")
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# 6) Check for existing checkpoint to resume
start_step = 0
ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else []
ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
if ckpts:
latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
start_step = load_checkpoint(latest_ckpt, model, optimizer)
# 7) Begin training
try:
train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
except Exception as e:
logging.exception("[main] Exception during training")
raise e
model.to(device)
model.eval()
return model, tokenizer
# ---------------------------------------------------
# 2) Autoregressive Sampling Function
# ---------------------------------------------------
def generate_text(
model: NoraTransformerLM,
tokenizer: CharTokenizer,
device: str,
prompt: str,
max_length: int = 128,
temperature: float = 1.0,
top_k: int = 50,
) -> str:
"""
Wrapper around model.generate. Returns raw generated string.
"""
return model.generate(
tokenizer=tokenizer,
device=device,
prompt=prompt,
max_length=max_length,
temperature=temperature,
top_k=top_k,
)
# ---------------------------------------------------
# 3) SCRAPE Directive Handler (same as before)
# ---------------------------------------------------
async def handle_scrape_directives(text: str, data_dir: str) -> bool:
urls = set(m.group(1) for m in SCRAPE_PATTERN.finditer(text))
if not urls:
return False
for url in urls:
logger.info(f"Directive: Scraping {url}")
html = fetch_url(url)
if not html:
logger.error(f"Failed to fetch {url}")
continue
plain = clean_html(html)
title = ""
m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
if m:
title = m.group(1).strip()[:50]
else:
title = urlparse(url).netloc
save_text(plain, title)
return True
# ---------------------------------------------------
# 4) Discord Client (“Nora” class) with Persona
# ---------------------------------------------------
class Nora(discord.Client):
def __init__(self, model, tokenizer, config, device):
intents = Intents.default()
intents.messages = True
intents.message_content = True
super().__init__(intents=intents)
self.model = model
self.tokenizer = tokenizer
self.config = config
self.device = device
# history per channel (last 5 user/nora pairs)
self.history = {}
# Load Noras persona from disk (its guaranteed to exist by startup)
with open(persona_manager.PERSONA_PATH, "r", encoding="utf-8") as f:
self.persona_text = f.read()
async def on_ready(self):
logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
logger.info("Nora is online and ready to be herself.")
# Background task: reload model if a new checkpoint shows up
self.loop.create_task(self._reload_model_periodically())
async def _reload_model_periodically(self, interval: int = 600):
"""
Every `interval` seconds, check for newer checkpoint & reload.
"""
while True:
await asyncio.sleep(interval)
ckpts = [
f
for f in os.listdir(self.config.checkpoint_dir)
if f.startswith("nora_step_") and f.endswith(".pt")
]
if not ckpts:
continue
latest = sorted(
ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
)[-1]
ckpt_path = os.path.join(self.config.checkpoint_dir, latest)
state = torch.load(ckpt_path, map_location="cpu")
self.model.load_state_dict(state["model_state_dict"])
self.model.to(self.device)
self.model.eval()
logger.info(f"Reloaded Noras model from {latest}")
async def on_message(self, message: discord.Message):
# 1) Ignore self-messages
if message.author.id == self.user.id:
return
content = message.content.strip()
prompt = None
# 2) If in DM, treat entire content as prompt
if isinstance(message.channel, discord.DMChannel):
prompt = content
# Also allow “update persona” in DM to regenerate persona file
if content.lower().startswith("update persona"):
# Regenerate persona asynchronously
new_persona = await persona_manager.maybe_update_persona(
self.model, self.tokenizer, self.device
)
self.persona_text = new_persona
await message.channel.send(
"I have rewritten my persona. Thank you! ❤️"
)
return
# 3) Otherwise (guild), require mention or “nora,” prefix
else:
if self.user.mention in content:
prompt = content.replace(self.user.mention, "").strip()
elif content.lower().startswith("nora,"):
prompt = content[len("nora,"):].strip()
else:
return
if not prompt:
return # e.g. user only said “Nora,” with no text after
# 4) Show typing indicator if in a guild text channel
if isinstance(message.channel, discord.TextChannel):
await message.channel.trigger_typing()
# 5) Build the full prompt: persona + history + users prompt
chan_id = str(message.channel.id)
history = self.history.get(chan_id, [])
prompt_lines = []
# 5.1) Insert Noras persona (so she “speaks as herself”)
prompt_lines.append(self.persona_text)
prompt_lines.append("") # blank line
# 5.2) Insert up to the last 4 exchanges
for user_msg, nora_msg in history[-4:]:
prompt_lines.append(f"User: {user_msg}")
prompt_lines.append(f"Nora: {nora_msg}")
prompt_lines.append("")
# 5.3) Finally, insert the new user prompt
prompt_lines.append(f"User: {prompt}")
prompt_lines.append("Nora:")
conversation_prompt = "\n".join(prompt_lines)
# 6) Generate Noras reply (tighter sampling: temp=0.8, top_k=20)
try:
raw = await asyncio.to_thread(
self.model.generate,
self.tokenizer,
self.device,
conversation_prompt,
self.config.seq_length,
0.8, # temperature
20, # top_k
)
except Exception:
logger.exception("Error in Nora.generate()")
await message.channel.send("😔 Sorry, I hit an error trying to think.")
return
# 7) Extract just Noras reply text
if "Nora:" in raw:
nora_reply = raw.split("Nora:")[-1].strip()
else:
nora_reply = raw[len(conversation_prompt) :].strip()
# 8) Handle <<SCRAPE:…>> if present
did_scrape = await handle_scrape_directives(raw, self.config.data_dir)
if did_scrape:
logger.info("Detected scrape directive. Triggering incremental retrain.")
subprocess.Popen(
["python", "pretrain.py", "--resume"], # note: pretrain.py was your old main.py
cwd=os.getcwd(),
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
nora_reply += "\n\n*I found some new things and am updating myself…*"
# 9) Save to history and prune to the last 5
self.history.setdefault(chan_id, []).append((prompt, nora_reply))
self.history[chan_id] = self.history[chan_id][-5 :]
# 10) Send Noras reply
await message.channel.send(nora_reply)
# ---------------------------------------------------
# 4) Entrypoint
# ---------------------------------------------------
if __name__ == "__main__":
main()
config = get_config()
device = config.device
# 4.1) Build/load model & tokenizer
model, tokenizer = build_nora(config, device)
# 4.2) Ensure a persona exists—if not, generate one now
persona_manager.ensure_persona_file(model, tokenizer, device)
# 4.3) After that, we can proceed to start the agent
discord_token = os.getenv("DISCORD_TOKEN")
if not discord_token:
logger.error("Please set DISCORD_TOKEN in your environment.")
exit(1)
# Enable CuDNN autotune if on CUDA
if device.startswith("cuda"):
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
bot = Nora(model=model, tokenizer=tokenizer, config=config, device=device)
bot.run(discord_token)

View File

@ -7,9 +7,22 @@ No pretrained weights—everything is initialized randomly.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def top_k_logits(logits: torch.Tensor, k: int):
"""
Zero out all but the top k logits in each row; return modified logits.
logits: (vocab_size,)
"""
if k == 0:
return logits
topk_vals, _ = torch.topk(logits, k)
min_topk = topk_vals[-1]
return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 10_000):
super().__init__()
@ -98,3 +111,41 @@ class NoraTransformerLM(nn.Module):
x = x.permute(1, 0, 2) # (batch_size, seq_length, d_model)
logits = self.fc_out(x) # (batch_size, seq_length, vocab_size)
return logits
def generate(
self,
tokenizer,
device: str,
prompt: str,
max_length: int = 128,
temperature: float = 1.0,
top_k: int = 50,
) -> str:
"""
Autoregressively generate text from a prompt.
- tokenizer: CharTokenizer (for encode/decode)
- device: "cuda" or "cpu"
- prompt: initial string
- max_length: total tokens to generate (including prompt)
- temperature: scales logits before softmax
- top_k: keep only top_k logits at each step
"""
self.eval()
input_ids = tokenizer.encode(prompt)
input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
generated = input_ids.clone() # shape (1, seq_len)
for _ in range(max_length - input_ids.size(1)):
# 1) trim to last seq_length tokens if longer than context window
if generated.size(1) > self.pos_encoder.pe.size(1):
generated = generated[:, -self.pos_encoder.pe.size(1) :]
with torch.no_grad():
logits = self.forward(generated) # (1, seq_len, vocab_size)
next_token_logits = logits[0, -1, :] / temperature
filtered_logits = top_k_logits(next_token_logits, k=top_k)
probs = F.softmax(filtered_logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1) # (1,)
generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)
output_ids = generated.squeeze(0).tolist()
return tokenizer.decode(output_ids)

2
nora_persona.txt Normal file
View File

@ -0,0 +1,2 @@
if you conduct, ens you heel to anyone. As for we know the bold child
like me!”

88
persona_manager.py Normal file
View File

@ -0,0 +1,88 @@
# persona_manager.py
import os
import torch
import asyncio
import re
from tokenizer import CharTokenizer
from model import NoraTransformerLM
PERSONA_PATH = "nora_persona.txt"
# 1) A meta-prompt that explicitly tells Nora to invent a persona and avoid quoting:
PERSONA_META_PROMPT = (
"Below, Nora will create a brandnew identity for herself. "
"She must NOT quote from any books or passages she has read. "
"Instead, she should invent her own style, voice, quirks, and personality traits as if she were a completely new person. "
"Her persona should be flirty, playful, curious, and speak in full sentences. "
"Write at least three paragraphs in your own words.\n\n"
"Nora, please invent and write your complete persona now:\n\nNora:"
)
async def generate_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str) -> str:
"""
Ask Nora to write out her own, original persona, avoiding any verbatim quotes.
Returns the raw generated text.
"""
# Well ask for up to 512 tokens, with higher temperature and top_p sampling.
# That combination tends to produce more creative, lessmemorizable text.
raw = await asyncio.to_thread(
model.generate,
tokenizer,
device,
PERSONA_META_PROMPT,
512, # allow several paragraphs
1.2, # higher temperature for more creativity
0 # top_k=0 means no fixed-k; well apply top_p filtering instead
)
# At this point, “raw” may include the word “Nora:” etc. Strip everything before “Nora:”
if "Nora:" in raw:
persona_text = raw.split("Nora:")[-1].strip()
else:
persona_text = raw.strip()
# Now apply a simple postfilter: remove any long spans that match exact sequences in the book corpus.
# This is optional but helps ensure she didnt copy large chunks verbatim. We check for 6+ character substrings
# appearing more than once in her output.
def remove_long_quotes(text: str) -> str:
filtered = text
# find any substring of length ≥6 that appears twice; well just guess shes quoting if its repeated.
for match in re.finditer(r"\b[\w',]{6,}\b", text):
substr = match.group(0)
if filtered.count(substr) > 1:
filtered = filtered.replace(substr, "[…]")
return filtered
persona_text = remove_long_quotes(persona_text)
return persona_text
def ensure_persona_file(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
"""
If nora_persona.txt does not exist, generate one (ensuring originality).
"""
if os.path.isfile(PERSONA_PATH):
return
print("[persona] No persona found. Generating a new, original persona…")
persona_text = asyncio.run(generate_persona(model, tokenizer, device))
# Save to disk
with open(PERSONA_PATH, "w", encoding="utf-8") as f:
f.write(persona_text)
print(f"[persona] Wrote new persona to {PERSONA_PATH}.")
async def maybe_update_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
"""
Regenerate Noras persona if she requests it, overwriting the file.
"""
print("[persona] Updating persona at Nora's request…")
persona_text = await generate_persona(model, tokenizer, device)
with open(PERSONA_PATH, "w", encoding="utf-8") as f:
f.write(persona_text)
print(f"[persona] Updated persona in {PERSONA_PATH}.")
return persona_text

89
pretrain.py Normal file
View File

@ -0,0 +1,89 @@
"""
pretrain.py
Orchestrates the entire Nora project:
- Parses arguments
- Builds or loads tokenizer
- Constructs dataset & dataloader
- Instantiates the model
- Sets up optimizer, scheduler
- Calls train()
"""
import os
import torch
import logging
from config import get_config
from tokenizer import CharTokenizer
from data_loader import get_dataloader
from model import NoraTransformerLM
from train import train
from utils import setup_logging, load_checkpoint, save_checkpoint
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
def pretrain():
args = get_config()
# 1) Logging setup
log_file = os.path.join(args.checkpoint_dir, "train.log")
setup_logging(log_file)
logging.info(f"[pretrain] Using device: {args.device}")
logging.info(f"[pretrain] Config: {args}")
# 2) Tokenizer: if vocab exists, load; else build from data_dir
tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)
# 3) DataLoader
dataloader = get_dataloader(
data_dir=args.data_dir,
tokenizer=tokenizer,
seq_length=args.seq_length,
batch_size=args.batch_size,
shuffle=True,
)
# 4) Model instantiation
model = NoraTransformerLM(
vocab_size=tokenizer.vocab_size(),
d_model=args.d_model,
nhead=args.nhead,
num_layers=args.num_layers,
dim_feedforward=args.dim_feedforward,
dropout=args.dropout,
max_seq_len=args.seq_length,
)
# 5) Optimizer & scheduler (linear warmup + decay)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
def lr_lambda(current_step):
# Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
if current_step < args.warmup_steps:
return float(current_step) / float(max(1, args.warmup_steps))
return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# 6) Check for existing checkpoint to resume
start_step = 0
ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else []
ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
if ckpts:
latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
start_step = load_checkpoint(latest_ckpt, model, optimizer)
# 7) Begin training
try:
train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
except Exception as e:
logging.exception("[main] Exception during training")
raise e
if __name__ == "__main__":
pretrain()

23
requirements.txt Normal file
View File

@ -0,0 +1,23 @@
# Core ML framework
# Data loading / utilities
tqdm>=4.64.0
numpy>=1.22.0
# HuggingFace Datasets (PersonaChat, etc.)
datasets>=2.0.0
# ConvoKit (for Cornell Movie-Dialogs + Unsloth)
convokit>=2.0.0
# Discord bot
discord.py>=2.0.0
# Web scraping
beautifulsoup4>=4.11.0
requests>=2.28.0
# If ConvoKit pulls in TensorFlow via Unsloth
tensorflow>=2.10.0

175
self_improve.py Normal file
View File

@ -0,0 +1,175 @@
# self_improve.py
import os
import subprocess
import tempfile
import shutil
import logging
import torch
from tokenizer import CharTokenizer
from model import NoraTransformerLM
from config import get_config
# ------------------------------------------------------
# 1) “Teacher”: Pose a codegeneration prompt to Nora
# ------------------------------------------------------
def propose_patch(model, tokenizer, device, prompt: str) -> str:
"""
Ask Nora to generate a code snippet given `prompt`.
e.g. prompt = "### FILE: knowledge_retriever.py\n# Add a new function clean_html(...) that..."
Returns the raw text (possibly including the prompt).
"""
raw = model.generate(
tokenizer=tokenizer,
device=device,
prompt=prompt,
max_length=512,
temperature=0.7,
top_k=50,
)
return raw
# ------------------------------------------------------
# 2) “Verifier” agent: sandbox + test
# ------------------------------------------------------
class CodeVerifier:
"""
Given a proposed code patch (as text), this class:
1. Writes it to a temporary file (or repo clone)
2. Runs Pythons syntax check (compile) and unit tests
3. Measures performance changes (e.g. run a small validation set through the model)
4. Returns True/False + log messages
"""
def __init__(self, repo_dir: str, test_command: str):
"""
- repo_dir: path to your Nora project root
- test_command: a shell command string to run unit tests, e.g. "pytest tests/"
"""
self.repo_dir = repo_dir
self.test_command = test_command
def verify_patch(self, rel_path: str, patch_code: str) -> bool:
"""
- rel_path: relative path inside repo where the patch should go, e.g. "knowledge_retriever.py"
- patch_code: entire contents of that file (not a diff).
Returns True if syntax + tests pass; False otherwise.
"""
# 1) Copy repo => temp dir
tmpdir = tempfile.mkdtemp(prefix="nora_verify_")
try:
shutil.copytree(self.repo_dir, os.path.join(tmpdir, "repo"), dirs_exist_ok=True)
target_file = os.path.join(tmpdir, "repo", rel_path)
# 2) Write patch_code to target_file
with open(target_file, "w", encoding="utf-8") as f:
f.write(patch_code)
# 3) Syntax check (try compiling)
try:
compile(patch_code, target_file, "exec")
except SyntaxError as se:
logging.error(f"Syntax error in patch: {se}")
return False
# 4) Run unit tests
result = subprocess.run(
self.test_command,
shell=True,
cwd=os.path.join(tmpdir, "repo"),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
if result.returncode != 0:
logging.error(f"Unit tests failed:\n{result.stdout}")
return False
# 5) (Optional) Performance check
# You could load the updated model and measure perplexity on a tiny validation set here.
# For now, we assume passing tests = “improvement.”
return True
finally:
shutil.rmtree(tmpdir)
def merge_patch(self, rel_path: str, patch_code: str) -> None:
"""
Overwrite the real file in `repo_dir/rel_path` with patch_code,
then git-add and git-commit (you can also automate a PR).
"""
target_file = os.path.join(self.repo_dir, rel_path)
with open(target_file, "w", encoding="utf-8") as f:
f.write(patch_code)
# Example: git add + commit
subprocess.run(f"git add {rel_path}", shell=True, cwd=self.repo_dir)
subprocess.run(
f'git commit -m "Auto-update {rel_path} via Nora self-improve."',
shell=True,
cwd=self.repo_dir,
)
# ------------------------------------------------------
# 3) Main loop: ask → verify → merge (if good) → retrain
# ------------------------------------------------------
def self_improvement_cycle(repo_dir: str, device: str):
"""
Example cycle:
1) Nora proposes a new helper in knowledge_retriever.py
2) Verifier checks syntax + tests
3) If ok, merge and trigger incremental retraining
"""
config = get_config()
tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
model = NoraTransformerLM(
vocab_size=tokenizer.vocab_size(),
d_model=config.d_model,
nhead=config.nhead,
num_layers=config.num_layers,
dim_feedforward=config.dim_feedforward,
dropout=config.dropout,
max_seq_len=config.seq_length,
)
# Load latest checkpoint
ckpts = []
if os.path.isdir(config.checkpoint_dir):
ckpts = [
f
for f in os.listdir(config.checkpoint_dir)
if f.startswith("nora_step_") and f.endswith(".pt")
]
if ckpts:
latest = sorted(ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
state = torch.load(os.path.join(config.checkpoint_dir, latest), map_location="cpu")
model.load_state_dict(state["model_state_dict"])
model.to(device)
model.eval()
verifier = CodeVerifier(repo_dir=repo_dir, test_command="pytest tests/")
# Example prompt: ask Nora to extend knowledge_retriever.py
prompt = (
"### FILE: knowledge_retriever.py\n"
"# Add a function clean_html(html: str) -> str that strips tags and scripts.\n"
"# Use BeautifulSoup if available. Return plain text.\n\n"
"### START\n"
"def clean_html(html: str) -> str:\n"
)
raw_patch = propose_patch(model, tokenizer, device, prompt)
# Extract everything from “def clean_html” to end of function (simple heuristic)
# In practice, youd parse until the next “\n\n” or rely on indentation.
patch_code = raw_patch # for now, assume raw_patch is the full file contents
# Verify
if verifier.verify_patch("knowledge_retriever.py", patch_code):
logging.info("Patch verified. Merging into live code.")
verifier.merge_patch("knowledge_retriever.py", patch_code)
# Optionally: trigger incremental retraining here (e.g. call train.py with --resume)
else:
logging.warning("Patch failed verification. Discarding.")

View File

@ -37,6 +37,13 @@ def train(
device = config.device
model.to(device)
# ─── ensure optimizer state is on the same device ───
# (this moves any loaded CPU buffers for Adam/AdamW into CUDA)
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])
scaler = GradScaler()