Added another learning source for Nora. Also added the requirements.
.gitignore (vendored, 2 changes)
@@ -196,3 +196,5 @@ cython_debug/
 checkpoints/nora_step_*.pt
 data/books
 *.json
+data/conversational/cornell_movie_dialogs.txt
+.gitignore
config.py
@@ -67,7 +67,7 @@ def get_config():
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=32,
+        default=512,
         help="Batch size per training step.",
     )
     parser.add_argument(
data/conversational/persona_chat.txt (new file, 0 lines)
data_loader.py
@@ -21,15 +21,25 @@ class TextDataset(Dataset):
         self.seq_length = seq_length
         self.tokenizer = tokenizer

-        # Read and concatenate all text files into one long string
+        # Read and concatenate all .txt files under two folders:
+        #   - data/books/
+        #   - data/conversational/
         texts = []
-        for root, _, files in os.walk(data_dir):
-            for fname in files:
-                if not fname.lower().endswith(".txt"):
-                    continue
-                path = os.path.join(root, fname)
-                with open(path, "r", encoding="utf-8", errors="ignore") as f:
-                    texts.append(f.read())
+        # If data_dir is a single path, we still look for a sibling "conversational" folder
+        conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational")
+        # Gather all folders to walk
+        dirs_to_walk = [data_dir]
+        if os.path.isdir(conversational_dir):
+            dirs_to_walk.append(conversational_dir)
+
+        for d in dirs_to_walk:
+            for root, _, files in os.walk(d):
+                for fname in files:
+                    if not fname.lower().endswith(".txt"):
+                        continue
+                    path = os.path.join(root, fname)
+                    with open(path, "r", encoding="utf-8", errors="ignore") as f:
+                        texts.append(f.read())
         full_text = "\n".join(texts)
         token_ids = self.tokenizer.encode(full_text)
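A note on the sibling-folder lookup above: os.path.dirname peels data_dir back to its parent, so the conversational folder is found automatically. A minimal sketch (the data_dir value is illustrative; the real one comes from config.py):

import os

data_dir = "data/books"  # illustrative value
conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational")
print(conversational_dir)  # -> data/conversational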
data_prep.py (new file, 217 lines)
@@ -0,0 +1,217 @@
#!/usr/bin/env python3
"""
data_prep.py

1) Attempts to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
   - If ConvoKit/Unsloth fails, falls back to manual ZIP download/extraction.

2) Attempts to download PersonaChat via Hugging Face Datasets:
   - First tries "persona_chat" (older key).
   - If that fails, tries "conv_ai_2" (alias).
   - Catches any exception to skip gracefully.

3) Writes each utterance to:
   data/conversational/cornell_movie_dialogs.txt
   data/conversational/persona_chat.txt

After running, you’ll have:

data/
├── books/            (your original Gutenberg .txt files)
└── conversational/
    ├── cornell_movie_dialogs.txt
    └── persona_chat.txt

Then retrain or fine-tune Nora on data/books/ + data/conversational/.
"""

import os
import sys
import zipfile
import tempfile
import urllib.request
from pathlib import Path

# === 1) Attempt to import ConvoKit for Cornell Movie-Dialogs ===
USE_CONVOKIT = True
try:
    from convokit import Corpus, download as convokit_download
except ImportError:
    USE_CONVOKIT = False

# === 2) Attempt to import Hugging Face Datasets ===
HAS_DATASETS = True
try:
    from datasets import load_dataset
except ImportError:
    HAS_DATASETS = False

# Directory for conversational data
CONV_DIR = Path("data/conversational")
CONV_DIR.mkdir(parents=True, exist_ok=True)

# Official ZIP URL (fallback) for Cornell Movie-Dialogs
CORNELL_ZIP_URL = "https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"


def install_package(pkg_name: str):
    """
    Installs a Python package using the same Python interpreter,
    wrapping the path in quotes to handle spaces.
    """
    python_executable = sys.executable
    command = f"\"{python_executable}\" -m pip install {pkg_name}"
    print(f"[data_prep] Installing package: {pkg_name}")
    os.system(command)


def prepare_cornell_via_convokit(output_path: str) -> bool:
    """
    Try to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
    Returns True if successful, False otherwise.
    """
    if not USE_CONVOKIT:
        print("[data_prep] ConvoKit not installed; skipping ConvoKit path.")
        return False

    print("[data_prep] Attempting to download Cornell Movie-Dialogs via ConvoKit...")
    try:
        corpus = Corpus(filename=convokit_download("movie-corpus"))
        with open(output_path, "w", encoding="utf-8") as fout:
            for utt in corpus.iter_utterances():
                text = utt.text.strip()
                if text:
                    fout.write(text.replace("\n", " ") + "\n")
        print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (via ConvoKit).")
        return True

    except NotImplementedError as nie:
        # Typically due to Unsloth error if GPU unsupported
        print("[data_prep] ConvoKit raised NotImplementedError (Unsloth/GPU issue).")
        print(f"[data_prep] Error: {nie}")
        return False

    except Exception as e:
        print("[data_prep] ConvoKit path failed with exception:", file=sys.stderr)
        print(e, file=sys.stderr)
        return False


def prepare_cornell_manual(output_path: str):
    """
    Fallback: Download Cornell ZIP manually, extract movie_lines.txt,
    and write all utterances to output_path.
    """
    print("[data_prep] Falling back to manual download of Cornell Movie-Dialogs...")
    with tempfile.TemporaryDirectory() as tmpdir:
        zip_path = os.path.join(tmpdir, "cornell.zip")
        try:
            print(f"[data_prep] Downloading from {CORNELL_ZIP_URL} ...")
            urllib.request.urlretrieve(CORNELL_ZIP_URL, zip_path)
        except Exception as e:
            print(f"[data_prep] Error downloading Cornell corpus: {e}", file=sys.stderr)
            return

        try:
            with zipfile.ZipFile(zip_path, "r") as z:
                member_name = None
                for name in z.namelist():
                    if name.endswith("movie_lines.txt"):
                        member_name = name
                        break
                if member_name is None:
                    print("[data_prep] movie_lines.txt not found in ZIP.", file=sys.stderr)
                    return
                z.extract(member_name, path=tmpdir)
                extracted_path = os.path.join(tmpdir, member_name)
        except Exception as e:
            print(f"[data_prep] Error extracting ZIP: {e}", file=sys.stderr)
            return

        try:
            with open(extracted_path, "r", encoding="iso-8859-1", errors="ignore") as fin, open(
                output_path, "w", encoding="utf-8"
            ) as fout:
                for line in fin:
                    parts = line.split(" +++$+++ ")
                    if len(parts) == 5:
                        text = parts[-1].strip()
                        if text:
                            fout.write(text.replace("\n", " ") + "\n")
        except Exception as e:
            print(f"[data_prep] Error parsing movie_lines.txt: {e}", file=sys.stderr)
            return

    print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (manual).")


def prepare_personachat(output_path: str):
    """
    Attempt to download PersonaChat via Hugging Face Datasets.
    Tries "persona_chat" and then "conv_ai_2". Catches any exception.
    """
    if not HAS_DATASETS:
        install_package("datasets")
        global load_dataset
        from datasets import load_dataset
        # Now we have it

    for dataset_key in ["persona_chat", "conv_ai_2"]:
        try:
            print(f"[data_prep] Attempting to load '{dataset_key}' via Hugging Face Datasets...")
            if dataset_key == "conv_ai_2":
                dataset = load_dataset(dataset_key, trust_remote_code=True)
            else:
                dataset = load_dataset(dataset_key)
            print(f"[data_prep] Successfully loaded '{dataset_key}'. Writing to {output_path}...")
            with open(output_path, "w", encoding="utf-8") as fout:
                if dataset_key == "persona_chat":
                    for split in ["train", "valid"]:
                        for conv in dataset[split]:
                            for line in conv["dialog"]:
                                text = line.strip()
                                if text:
                                    fout.write(text.replace("\n", " ") + "\n")
                else:  # conv_ai_2
                    for split in ["train", "valid"]:
                        for item in dataset[split]:
                            # conv_ai_2 has a field named "dialog"
                            if "dialog" in item:
                                for line in item["dialog"]:
                                    text = line.strip()
                                    if text:
                                        fout.write(text.replace("\n", " ") + "\n")
                            elif "utterance" in item:
                                text = item["utterance"].strip()
                                if text:
                                    fout.write(text.replace("\n", " ") + "\n")
            print(f"[data_prep] Finished writing PersonaChat ({dataset_key}) to {output_path}.")
            return
        except Exception as e:
            print(f"[data_prep] Failed '{dataset_key}': {e}", file=sys.stderr)
            # Try next key

    print("[data_prep] Could not load PersonaChat under any key. Skipping PersonaChat.", file=sys.stderr)


def main():
    cornell_path = CONV_DIR / "cornell_movie_dialogs.txt"
    persona_path = CONV_DIR / "persona_chat.txt"

    # 1) Prepare Cornell Movie-Dialogs: try ConvoKit, then manual
    if not cornell_path.is_file():
        ok = prepare_cornell_via_convokit(str(cornell_path))
        if not ok:
            prepare_cornell_manual(str(cornell_path))
    else:
        print(f"[data_prep] Skipping Cornell: '{cornell_path}' already exists.")

    # 2) Prepare PersonaChat
    if not persona_path.is_file():
        prepare_personachat(str(persona_path))
    else:
        print(f"[data_prep] Skipping PersonaChat: '{persona_path}' already exists.")

    print("[data_prep] All done. You can now include data/conversational/ in Nora's training.")


if __name__ == "__main__":
    main()
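For context on the " +++$+++ " split in prepare_cornell_manual: each movie_lines.txt record carries five " +++$+++ "-separated fields (line ID, character ID, movie ID, character name, utterance), which is why the parser keeps only 5-field lines and takes the last field. A small sketch with an illustrative record:

sample = "L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!"
parts = sample.split(" +++$+++ ")
assert len(parts) == 5
print(parts[-1].strip())  # -> They do not!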
knowledge_retriever.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# knowledge_retriever.py

import os
import re
import requests
from bs4 import BeautifulSoup
import logging

# Where to dump new “.txt” files scraped from the web
SCRAPE_DIR = "data/books/scraped"

# Simple rate-limiter to avoid hammering any one domain
import time
_last_request_time = {}


def fetch_url(url: str, min_interval: float = 1.0) -> str:
    """
    Fetch a URL’s HTML, enforcing at least `min_interval` seconds between requests
    to the same domain. Returns HTML string or empty string on failure.
    """
    from urllib.parse import urlparse

    domain = urlparse(url).netloc
    now = time.time()
    last = _last_request_time.get(domain, 0)
    wait = min_interval - (now - last)
    if wait > 0:
        time.sleep(wait)

    headers = {"User-Agent": "NoraScraper/1.0 (+https://your_project_url)"}
    try:
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        _last_request_time[domain] = time.time()
        return response.text
    except Exception as e:
        logging.error(f"Error fetching {url}: {e}")
        return ""


def clean_html(html: str) -> str:
    """
    Strip scripts, styles, and tags; return plain text.
    """
    soup = BeautifulSoup(html, "html.parser")

    # remove scripts and styles
    for tag in soup(["script", "style", "noscript"]):
        tag.decompose()

    text = soup.get_text(separator="\n")
    # collapse multiple blank lines
    lines = [line.strip() for line in text.splitlines()]
    text = "\n".join([line for line in lines if line])
    return text


def save_text(content: str, title: str):
    """
    Save content to a UTF-8 .txt file under SCRAPE_DIR. Filename is derived from title.
    """
    os.makedirs(SCRAPE_DIR, exist_ok=True)
    # sanitize title → filename
    safe = re.sub(r"[^0-9a-zA-Z_\-]", "_", title)
    fname = f"{safe[:50]}.txt"
    path = os.path.join(SCRAPE_DIR, fname)
    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    logging.info(f"Saved scraped page to {path}")


def scrape_and_store(url: str):
    """
    High-level function: fetches URL, cleans HTML, extracts a title, and saves to a .txt.
    """
    html = fetch_url(url)
    if not html:
        return False

    text = clean_html(html)
    # extract <title> if present
    title = ""
    m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
    if m:
        title = m.group(1).strip()
    else:
        title = url

    save_text(text, title)
    return True


# Example usage:
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python knowledge_retriever.py <url1> [<url2> ...]")
        sys.exit(1)

    for link in sys.argv[1:]:
        success = scrape_and_store(link)
        if success:
            print(f"Scraped: {link}")
        else:
            print(f"Failed to scrape: {link}")
main.py (359 changes; rewritten wholesale: the old training entrypoint is preserved below as pretrain.py, and main.py now hosts the Discord agent)
@@ -1,89 +1,310 @@
# main.py

import os
import asyncio
import logging
import subprocess
import re
from urllib.parse import urlparse

import discord
import torch
from discord import Intents

from config import get_config
from tokenizer import CharTokenizer
from model import NoraTransformerLM, top_k_logits
from data_loader import get_dataloader
from utils import setup_logging, load_checkpoint, save_checkpoint
from knowledge_retriever import fetch_url, clean_html, save_text
import persona_manager  # <— PERSONA CHANGE

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Keep the SCRAPE regex as before
SCRAPE_PATTERN = re.compile(r"<<SCRAPE:(https?://[^\s>]+)>>", re.IGNORECASE)


# ---------------------------------------------------
# 1) Build or Reload Nora’s Model & Tokenizer
# ---------------------------------------------------
def build_nora(config, device):
    """
    - Loads tokenizer
    - Instantiates NoraTransformerLM
    - Loads latest checkpoint (if any)
    """
    tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)

    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=config.d_model,
        nhead=config.nhead,
        num_layers=config.num_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
        max_seq_len=config.seq_length,
    )

    # Find latest checkpoint
    latest_ckpt = None
    if os.path.isdir(config.checkpoint_dir):
        ckpts = [
            f
            for f in os.listdir(config.checkpoint_dir)
            if f.startswith("nora_step_") and f.endswith(".pt")
        ]
        if ckpts:
            latest_ckpt = sorted(
                ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]

    if latest_ckpt:
        ckpt_path = os.path.join(config.checkpoint_dir, latest_ckpt)
        state = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
        logger.info(f"Loaded Nora from checkpoint {latest_ckpt}")
    else:
        logger.warning("No checkpoints found. Nora starts untrained.")

    model.to(device)
    model.eval()
    return model, tokenizer


# ---------------------------------------------------
# 2) Autoregressive Sampling Function
# ---------------------------------------------------
def generate_text(
    model: NoraTransformerLM,
    tokenizer: CharTokenizer,
    device: str,
    prompt: str,
    max_length: int = 128,
    temperature: float = 1.0,
    top_k: int = 50,
) -> str:
    """
    Wrapper around model.generate. Returns raw generated string.
    """
    return model.generate(
        tokenizer=tokenizer,
        device=device,
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
    )


# ---------------------------------------------------
# 3) SCRAPE Directive Handler (same as before)
# ---------------------------------------------------
async def handle_scrape_directives(text: str, data_dir: str) -> bool:
    urls = set(m.group(1) for m in SCRAPE_PATTERN.finditer(text))
    if not urls:
        return False

    for url in urls:
        logger.info(f"Directive: Scraping {url}")
        html = fetch_url(url)
        if not html:
            logger.error(f"Failed to fetch {url}")
            continue
        plain = clean_html(html)
        title = ""
        m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
        if m:
            title = m.group(1).strip()[:50]
        else:
            title = urlparse(url).netloc
        save_text(plain, title)
    return True


# ---------------------------------------------------
# 4) Discord Client (“Nora” class) with Persona
# ---------------------------------------------------
class Nora(discord.Client):
    def __init__(self, model, tokenizer, config, device):
        intents = Intents.default()
        intents.messages = True
        intents.message_content = True
        super().__init__(intents=intents)

        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = device

        # history per channel (last 5 user/nora pairs)
        self.history = {}

        # Load Nora’s persona from disk (it’s guaranteed to exist by startup)
        with open(persona_manager.PERSONA_PATH, "r", encoding="utf-8") as f:
            self.persona_text = f.read()

    async def on_ready(self):
        logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
        logger.info("Nora is online and ready to be herself.")

        # Background task: reload model if a new checkpoint shows up
        self.loop.create_task(self._reload_model_periodically())

    async def _reload_model_periodically(self, interval: int = 600):
        """
        Every `interval` seconds, check for newer checkpoint & reload.
        """
        while True:
            await asyncio.sleep(interval)
            ckpts = [
                f
                for f in os.listdir(self.config.checkpoint_dir)
                if f.startswith("nora_step_") and f.endswith(".pt")
            ]
            if not ckpts:
                continue
            latest = sorted(
                ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            ckpt_path = os.path.join(self.config.checkpoint_dir, latest)
            state = torch.load(ckpt_path, map_location="cpu")
            self.model.load_state_dict(state["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
            logger.info(f"Reloaded Nora’s model from {latest}")

    async def on_message(self, message: discord.Message):
        # 1) Ignore self-messages
        if message.author.id == self.user.id:
            return

        content = message.content.strip()
        prompt = None

        # 2) If in DM, treat entire content as prompt
        if isinstance(message.channel, discord.DMChannel):
            prompt = content

            # Also allow “update persona” in DM to regenerate persona file
            if content.lower().startswith("update persona"):
                # Regenerate persona asynchronously
                new_persona = await persona_manager.maybe_update_persona(
                    self.model, self.tokenizer, self.device
                )
                self.persona_text = new_persona
                await message.channel.send(
                    "I have rewritten my persona. Thank you! ❤️"
                )
                return

        # 3) Otherwise (guild), require mention or “nora,” prefix
        else:
            if self.user.mention in content:
                prompt = content.replace(self.user.mention, "").strip()
            elif content.lower().startswith("nora,"):
                prompt = content[len("nora,"):].strip()
            else:
                return

        if not prompt:
            return  # e.g. user only said “Nora,” with no text after

        # 4) Show typing indicator if in a guild text channel
        if isinstance(message.channel, discord.TextChannel):
            await message.channel.trigger_typing()

        # 5) Build the full prompt: persona + history + user’s prompt
        chan_id = str(message.channel.id)
        history = self.history.get(chan_id, [])

        prompt_lines = []
        # 5.1) Insert Nora’s persona (so she “speaks as herself”)
        prompt_lines.append(self.persona_text)
        prompt_lines.append("")  # blank line

        # 5.2) Insert up to the last 4 exchanges
        for user_msg, nora_msg in history[-4:]:
            prompt_lines.append(f"User: {user_msg}")
            prompt_lines.append(f"Nora: {nora_msg}")
            prompt_lines.append("")

        # 5.3) Finally, insert the new user prompt
        prompt_lines.append(f"User: {prompt}")
        prompt_lines.append("Nora:")

        conversation_prompt = "\n".join(prompt_lines)

        # 6) Generate Nora’s reply (tighter sampling: temp=0.8, top_k=20)
        try:
            raw = await asyncio.to_thread(
                self.model.generate,
                self.tokenizer,
                self.device,
                conversation_prompt,
                self.config.seq_length,
                0.8,  # temperature
                20,  # top_k
            )
        except Exception:
            logger.exception("Error in Nora.generate()")
            await message.channel.send("😔 Sorry, I hit an error trying to think.")
            return

        # 7) Extract just Nora’s reply text
        if "Nora:" in raw:
            nora_reply = raw.split("Nora:")[-1].strip()
        else:
            nora_reply = raw[len(conversation_prompt):].strip()

        # 8) Handle <<SCRAPE:…>> if present
        did_scrape = await handle_scrape_directives(raw, self.config.data_dir)
        if did_scrape:
            logger.info("Detected scrape directive. Triggering incremental retrain.")
            subprocess.Popen(
                ["python", "pretrain.py", "--resume"],  # note: pretrain.py was your old main.py
                cwd=os.getcwd(),
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            nora_reply += "\n\n*I found some new things and am updating myself…*"

        # 9) Save to history and prune to the last 5
        self.history.setdefault(chan_id, []).append((prompt, nora_reply))
        self.history[chan_id] = self.history[chan_id][-5:]

        # 10) Send Nora’s reply
        await message.channel.send(nora_reply)


# ---------------------------------------------------
# 5) Entrypoint
# ---------------------------------------------------
if __name__ == "__main__":
    config = get_config()
    device = config.device

    # 5.1) Build/load model & tokenizer
    model, tokenizer = build_nora(config, device)

    # 5.2) Ensure a persona exists—if not, generate one now
    persona_manager.ensure_persona_file(model, tokenizer, device)

    # 5.3) After that, we can proceed to start the agent
    discord_token = os.getenv("DISCORD_TOKEN")
    if not discord_token:
        logger.error("Please set DISCORD_TOKEN in your environment.")
        exit(1)

    # Enable CuDNN autotune if on CUDA
    if device.startswith("cuda"):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

    bot = Nora(model=model, tokenizer=tokenizer, config=config, device=device)
    bot.run(discord_token)
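To make the prompt assembly in on_message easier to review, here is a standalone sketch of the string it builds (all values illustrative):

persona_text = "I am Nora. I am playful and curious."  # loaded from nora_persona.txt
history = [("hi", "hello!")]                           # last (user, nora) pairs for this channel
prompt = "how are you?"

lines = [persona_text, ""]
for user_msg, nora_msg in history[-4:]:
    lines += [f"User: {user_msg}", f"Nora: {nora_msg}", ""]
lines += [f"User: {prompt}", "Nora:"]
conversation_prompt = "\n".join(lines)
# The model continues after the trailing "Nora:" marker, and on_message
# then splits the generation on "Nora:" to recover just the reply.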
model.py (51 changes)
@@ -7,9 +7,22 @@ No pretrained weights—everything is initialized randomly.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math
+
+
+def top_k_logits(logits: torch.Tensor, k: int):
+    """
+    Zero out all but the top k logits in each row; return modified logits.
+    logits: (vocab_size,)
+    """
+    if k == 0:
+        return logits
+    topk_vals, _ = torch.topk(logits, k)
+    min_topk = topk_vals[-1]
+    return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)
+
+
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model: int, max_len: int = 10_000):
         super().__init__()
@@ -98,3 +111,41 @@ class NoraTransformerLM(nn.Module):
         x = x.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
         logits = self.fc_out(x)  # (batch_size, seq_length, vocab_size)
         return logits
+
+    def generate(
+        self,
+        tokenizer,
+        device: str,
+        prompt: str,
+        max_length: int = 128,
+        temperature: float = 1.0,
+        top_k: int = 50,
+    ) -> str:
+        """
+        Autoregressively generate text from a prompt.
+        - tokenizer: CharTokenizer (for encode/decode)
+        - device: "cuda" or "cpu"
+        - prompt: initial string
+        - max_length: total tokens to generate (including prompt)
+        - temperature: scales logits before softmax
+        - top_k: keep only top_k logits at each step
+        """
+        self.eval()
+        input_ids = tokenizer.encode(prompt)
+        input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
+        generated = input_ids.clone()  # shape (1, seq_len)
+        for _ in range(max_length - input_ids.size(1)):
+            # 1) trim to last seq_length tokens if longer than context window
+            if generated.size(1) > self.pos_encoder.pe.size(1):
+                generated = generated[:, -self.pos_encoder.pe.size(1):]
+
+            with torch.no_grad():
+                logits = self.forward(generated)  # (1, seq_len, vocab_size)
+            next_token_logits = logits[0, -1, :] / temperature
+            filtered_logits = top_k_logits(next_token_logits, k=top_k)
+            probs = F.softmax(filtered_logits, dim=-1)
+            next_id = torch.multinomial(probs, num_samples=1)  # (1,)
+            generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)
+
+        output_ids = generated.squeeze(0).tolist()
+        return tokenizer.decode(output_ids)
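A quick sanity check of top_k_logits on a toy tensor (runnable as written):

import torch
from model import top_k_logits

logits = torch.tensor([2.0, 0.5, 1.0, -1.0])
print(top_k_logits(logits, k=2))
# -> tensor([ 2.0000e+00, -1.0000e+10,  1.0000e+00, -1.0000e+10])
# Only the two largest logits survive; softmax then gives the masked
# entries effectively zero probability (note they are set to -1e10,
# not literally zeroed as the docstring says).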
nora_persona.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
if you conduct, ens you heel to anyone. As for we know the bold child
like me!”
persona_manager.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# persona_manager.py

import os
import torch
import asyncio
import re

from tokenizer import CharTokenizer
from model import NoraTransformerLM

PERSONA_PATH = "nora_persona.txt"

# 1) A meta-prompt that explicitly tells Nora to invent a persona and avoid quoting:
PERSONA_META_PROMPT = (
    "Below, Nora will create a brand-new identity for herself. "
    "She must NOT quote from any books or passages she has read. "
    "Instead, she should invent her own style, voice, quirks, and personality traits as if she were a completely new person. "
    "Her persona should be flirty, playful, curious, and speak in full sentences. "
    "Write at least three paragraphs in your own words.\n\n"
    "Nora, please invent and write your complete persona now:\n\nNora:"
)


async def generate_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str) -> str:
    """
    Ask Nora to write out her own, original persona, avoiding any verbatim quotes.
    Returns the raw generated text.
    """
    # We’ll ask for up to 512 tokens at a higher temperature. Note: top_k=0
    # leaves the logits unfiltered in model.generate (no top_p filter is
    # implemented there), so this is plain temperature sampling, which tends
    # to produce more creative, less-memorized text.
    raw = await asyncio.to_thread(
        model.generate,
        tokenizer,
        device,
        PERSONA_META_PROMPT,
        512,  # allow several paragraphs
        1.2,  # higher temperature for more creativity
        0,    # top_k=0: no top-k filtering applied
    )

    # At this point, “raw” may include the word “Nora:” etc. Strip everything before “Nora:”
    if "Nora:" in raw:
        persona_text = raw.split("Nora:")[-1].strip()
    else:
        persona_text = raw.strip()

    # Now apply a simple post-filter: remove any long spans that match exact sequences in the book corpus.
    # This is optional but helps ensure she didn’t copy large chunks verbatim. We check for 6+ character substrings
    # appearing more than once in her output.
    def remove_long_quotes(text: str) -> str:
        filtered = text
        # find any substring of length ≥6 that appears twice; we’ll just guess she’s quoting if it’s repeated.
        for match in re.finditer(r"\b[\w',]{6,}\b", text):
            substr = match.group(0)
            if filtered.count(substr) > 1:
                filtered = filtered.replace(substr, "[…]")
        return filtered

    persona_text = remove_long_quotes(persona_text)
    return persona_text


def ensure_persona_file(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
    """
    If nora_persona.txt does not exist, generate one (ensuring originality).
    """
    if os.path.isfile(PERSONA_PATH):
        return

    print("[persona] No persona found. Generating a new, original persona…")
    persona_text = asyncio.run(generate_persona(model, tokenizer, device))

    # Save to disk
    with open(PERSONA_PATH, "w", encoding="utf-8") as f:
        f.write(persona_text)
    print(f"[persona] Wrote new persona to {PERSONA_PATH}.")


async def maybe_update_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
    """
    Regenerate Nora’s persona if she requests it, overwriting the file.
    """
    print("[persona] Updating persona at Nora's request…")
    persona_text = await generate_persona(model, tokenizer, device)
    with open(PERSONA_PATH, "w", encoding="utf-8") as f:
        f.write(persona_text)
    print(f"[persona] Updated persona in {PERSONA_PATH}.")
    return persona_text
pretrain.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""
pretrain.py

Orchestrates the entire Nora project:
- Parses arguments
- Builds or loads tokenizer
- Constructs dataset & dataloader
- Instantiates the model
- Sets up optimizer, scheduler
- Calls train()
"""

import os
import torch
import logging
from config import get_config
from tokenizer import CharTokenizer
from data_loader import get_dataloader
from model import NoraTransformerLM
from train import train
from utils import setup_logging, load_checkpoint, save_checkpoint

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


def pretrain():
    args = get_config()

    # 1) Logging setup
    log_file = os.path.join(args.checkpoint_dir, "train.log")
    setup_logging(log_file)

    logging.info(f"[pretrain] Using device: {args.device}")
    logging.info(f"[pretrain] Config: {args}")

    # 2) Tokenizer: if vocab exists, load; else build from data_dir
    tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)

    # 3) DataLoader
    dataloader = get_dataloader(
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # 4) Model instantiation
    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=args.d_model,
        nhead=args.nhead,
        num_layers=args.num_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        max_seq_len=args.seq_length,
    )

    # 5) Optimizer & scheduler (linear warmup + decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)

    def lr_lambda(current_step):
        # Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
        if current_step < args.warmup_steps:
            return float(current_step) / float(max(1, args.warmup_steps))
        return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # 6) Check for existing checkpoint to resume
    start_step = 0
    ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else []
    ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
    if ckpts:
        latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
        logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
        start_step = load_checkpoint(latest_ckpt, model, optimizer)

    # 7) Begin training
    try:
        train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
    except Exception as e:
        logging.exception("[main] Exception during training")
        raise e


if __name__ == "__main__":
    pretrain()
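The lr_lambda above is the usual inverse-square-root schedule: the multiplier climbs linearly to 1.0 at warmup_steps, then decays as sqrt(warmup_steps / step). A quick check (warmup_steps = 4000 is an assumed value for illustration; the real one comes from get_config()):

warmup_steps = 4000  # assumption for illustration

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return (warmup_steps ** 0.5) * (step ** -0.5)

print(lr_lambda(2000))   # 0.5 (halfway through warmup)
print(lr_lambda(4000))   # 1.0 (peak)
print(lr_lambda(16000))  # 0.5 (decayed by sqrt(4000 / 16000))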
requirements.txt (new file, 23 lines)
@@ -0,0 +1,23 @@
# Core ML framework

# Data loading / utilities
tqdm>=4.64.0
numpy>=1.22.0

# HuggingFace Datasets (PersonaChat, etc.)
datasets>=2.0.0

# ConvoKit (for Cornell Movie-Dialogs + Unsloth)
convokit>=2.0.0

# Discord bot
discord.py>=2.0.0

# Web scraping
beautifulsoup4>=4.11.0
requests>=2.28.0

# If ConvoKit pulls in TensorFlow via Unsloth
tensorflow>=2.10.0
self_improve.py (new file, 175 lines)
@@ -0,0 +1,175 @@
# self_improve.py

import os
import subprocess
import tempfile
import shutil
import logging

import torch
from tokenizer import CharTokenizer
from model import NoraTransformerLM
from config import get_config


# ------------------------------------------------------
# 1) “Teacher”: Pose a code-generation prompt to Nora
# ------------------------------------------------------
def propose_patch(model, tokenizer, device, prompt: str) -> str:
    """
    Ask Nora to generate a code snippet given `prompt`.
    e.g. prompt = "### FILE: knowledge_retriever.py\n# Add a new function clean_html(...) that..."
    Returns the raw text (possibly including the prompt).
    """
    raw = model.generate(
        tokenizer=tokenizer,
        device=device,
        prompt=prompt,
        max_length=512,
        temperature=0.7,
        top_k=50,
    )
    return raw


# ------------------------------------------------------
# 2) “Verifier” agent: sandbox + test
# ------------------------------------------------------
class CodeVerifier:
    """
    Given a proposed code patch (as text), this class:
      1. Writes it to a temporary file (or repo clone)
      2. Runs Python’s syntax check (compile) and unit tests
      3. Measures performance changes (e.g. run a small validation set through the model)
      4. Returns True/False + log messages
    """

    def __init__(self, repo_dir: str, test_command: str):
        """
        - repo_dir: path to your Nora project root
        - test_command: a shell command string to run unit tests, e.g. "pytest tests/"
        """
        self.repo_dir = repo_dir
        self.test_command = test_command

    def verify_patch(self, rel_path: str, patch_code: str) -> bool:
        """
        - rel_path: relative path inside repo where the patch should go, e.g. "knowledge_retriever.py"
        - patch_code: entire contents of that file (not a diff).
        Returns True if syntax + tests pass; False otherwise.
        """
        # 1) Copy repo => temp dir
        tmpdir = tempfile.mkdtemp(prefix="nora_verify_")
        try:
            shutil.copytree(self.repo_dir, os.path.join(tmpdir, "repo"), dirs_exist_ok=True)
            target_file = os.path.join(tmpdir, "repo", rel_path)

            # 2) Write patch_code to target_file
            with open(target_file, "w", encoding="utf-8") as f:
                f.write(patch_code)

            # 3) Syntax check (try compiling)
            try:
                compile(patch_code, target_file, "exec")
            except SyntaxError as se:
                logging.error(f"Syntax error in patch: {se}")
                return False

            # 4) Run unit tests
            result = subprocess.run(
                self.test_command,
                shell=True,
                cwd=os.path.join(tmpdir, "repo"),
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            )
            if result.returncode != 0:
                logging.error(f"Unit tests failed:\n{result.stdout}")
                return False

            # 5) (Optional) Performance check
            # You could load the updated model and measure perplexity on a tiny validation set here.
            # For now, we assume passing tests = “improvement.”

            return True

        finally:
            shutil.rmtree(tmpdir)

    def merge_patch(self, rel_path: str, patch_code: str) -> None:
        """
        Overwrite the real file in `repo_dir/rel_path` with patch_code,
        then git-add and git-commit (you can also automate a PR).
        """
        target_file = os.path.join(self.repo_dir, rel_path)
        with open(target_file, "w", encoding="utf-8") as f:
            f.write(patch_code)

        # Example: git add + commit
        subprocess.run(f"git add {rel_path}", shell=True, cwd=self.repo_dir)
        subprocess.run(
            f'git commit -m "Auto-update {rel_path} via Nora self-improve."',
            shell=True,
            cwd=self.repo_dir,
        )


# ------------------------------------------------------
# 3) Main loop: ask → verify → merge (if good) → retrain
# ------------------------------------------------------
def self_improvement_cycle(repo_dir: str, device: str):
    """
    Example cycle:
      1) Nora proposes a new helper in knowledge_retriever.py
      2) Verifier checks syntax + tests
      3) If ok, merge and trigger incremental retraining
    """
    config = get_config()
    tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=config.d_model,
        nhead=config.nhead,
        num_layers=config.num_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
        max_seq_len=config.seq_length,
    )
    # Load latest checkpoint
    ckpts = []
    if os.path.isdir(config.checkpoint_dir):
        ckpts = [
            f
            for f in os.listdir(config.checkpoint_dir)
            if f.startswith("nora_step_") and f.endswith(".pt")
        ]
    if ckpts:
        latest = sorted(ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
        state = torch.load(os.path.join(config.checkpoint_dir, latest), map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
    model.to(device)
    model.eval()

    verifier = CodeVerifier(repo_dir=repo_dir, test_command="pytest tests/")

    # Example prompt: ask Nora to extend knowledge_retriever.py
    prompt = (
        "### FILE: knowledge_retriever.py\n"
        "# Add a function clean_html(html: str) -> str that strips tags and scripts.\n"
        "# Use BeautifulSoup if available. Return plain text.\n\n"
        "### START\n"
        "def clean_html(html: str) -> str:\n"
    )
    raw_patch = propose_patch(model, tokenizer, device, prompt)

    # Extract everything from “def clean_html” to end of function (simple heuristic)
    # In practice, you’d parse until the next “\n\n” or rely on indentation.
    patch_code = raw_patch  # for now, assume raw_patch is the full file contents

    # Verify
    if verifier.verify_patch("knowledge_retriever.py", patch_code):
        logging.info("Patch verified. Merging into live code.")
        verifier.merge_patch("knowledge_retriever.py", patch_code)
        # Optionally: trigger incremental retraining here (e.g. call train.py with --resume)
    else:
        logging.warning("Patch failed verification. Discarding.")
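A hypothetical driver for the cycle, not part of this commit (it assumes the repo root is the current working directory and that a tests/ suite exists for the verifier to run):

import os
from self_improve import self_improvement_cycle

if __name__ == "__main__":
    # "cpu" keeps the sketch runnable without a GPU; pass "cuda" if available.
    self_improvement_cycle(repo_dir=os.getcwd(), device="cpu")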
train.py (7 changes)
@@ -37,6 +37,13 @@ def train(
     device = config.device
     model.to(device)
+
+    # ─── ensure optimizer state is on the same device ───
+    # (this moves any loaded CPU buffers for Adam/AdamW into CUDA)
+    for state in optimizer.state.values():
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                state[k] = v.to(device)
     criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])
     scaler = GradScaler()