Added another learning source for Nora. Also added the requirements.
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@ -196,3 +196,5 @@ cython_debug/
|
||||
checkpoints/nora_step_*.pt
|
||||
data/books
|
||||
*.json
|
||||
data/conversational/cornell_movie_dialogs.txt
|
||||
.gitignore
|
||||
|
@ -67,7 +67,7 @@ def get_config():
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=32,
|
||||
default=512,
|
||||
help="Batch size per training step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
0
data/conversational/persona_chat.txt
Normal file
0
data/conversational/persona_chat.txt
Normal file
@ -21,15 +21,25 @@ class TextDataset(Dataset):
|
||||
self.seq_length = seq_length
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Read and concatenate all text files into one long string
|
||||
# Read and concatenate all .txt files under two folders:
|
||||
# - data/books/
|
||||
# - data/conversational/
|
||||
texts = []
|
||||
for root, _, files in os.walk(data_dir):
|
||||
for fname in files:
|
||||
if not fname.lower().endswith(".txt"):
|
||||
continue
|
||||
path = os.path.join(root, fname)
|
||||
with open(path, "r", encoding="utf-8", errors="ignore") as f:
|
||||
texts.append(f.read())
|
||||
# If data_dir is a single path, we still look for a sibling "conversational" folder
|
||||
conversational_dir = os.path.join(os.path.dirname(data_dir), "conversational")
|
||||
# Gather all folders to walk
|
||||
dirs_to_walk = [data_dir]
|
||||
if os.path.isdir(conversational_dir):
|
||||
dirs_to_walk.append(conversational_dir)
|
||||
|
||||
for d in dirs_to_walk:
|
||||
for root, _, files in os.walk(d):
|
||||
for fname in files:
|
||||
if not fname.lower().endswith(".txt"):
|
||||
continue
|
||||
path = os.path.join(root, fname)
|
||||
with open(path, "r", encoding="utf-8", errors="ignore") as f:
|
||||
texts.append(f.read())
|
||||
full_text = "\n".join(texts)
|
||||
token_ids = self.tokenizer.encode(full_text)
|
||||
|
||||
|
217
data_prep.py
Normal file
217
data_prep.py
Normal file
@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
data_prep.py
|
||||
|
||||
1) Attempts to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
|
||||
- If ConvoKit/Unsloth fails, falls back to manual ZIP download/extraction.
|
||||
|
||||
2) Attempts to download PersonaChat via Hugging Face Datasets:
|
||||
- First tries "persona_chat" (older key).
|
||||
- If that fails, tries "conv_ai_2" (alias).
|
||||
- Catches any exception to skip gracefully.
|
||||
|
||||
3) Writes each utterance to:
|
||||
data/conversational/cornell_movie_dialogs.txt
|
||||
data/conversational/persona_chat.txt
|
||||
|
||||
After running, you’ll have:
|
||||
data/
|
||||
├── books/ (your original Gutenberg .txt files)
|
||||
└── conversational/
|
||||
├── cornell_movie_dialogs.txt
|
||||
└── persona_chat.txt
|
||||
|
||||
Then retrain or fine-tune Nora on data/books/ + data/conversational/.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import zipfile
|
||||
import tempfile
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
# === 1) Attempt to import ConvoKit for Cornell Movie-Dialogs ===
|
||||
USE_CONVOKIT = True
|
||||
try:
|
||||
from convokit import Corpus, download as convokit_download
|
||||
except ImportError:
|
||||
USE_CONVOKIT = False
|
||||
|
||||
# === 2) Attempt to import Hugging Face Datasets ===
|
||||
HAS_DATASETS = True
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
except ImportError:
|
||||
HAS_DATASETS = False
|
||||
|
||||
# Directory for conversational data
|
||||
CONV_DIR = Path("data/conversational")
|
||||
CONV_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Official ZIP URL (fallback) for Cornell Movie-Dialogs
|
||||
CORNELL_ZIP_URL = "https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"
|
||||
|
||||
|
||||
def install_package(pkg_name: str):
|
||||
"""
|
||||
Installs a Python package using the same Python interpreter,
|
||||
wrapping the path in quotes to handle spaces.
|
||||
"""
|
||||
python_executable = sys.executable
|
||||
command = f"\"{python_executable}\" -m pip install {pkg_name}"
|
||||
print(f"[data_prep] Installing package: {pkg_name}")
|
||||
os.system(command)
|
||||
|
||||
|
||||
def prepare_cornell_via_convokit(output_path: str) -> bool:
|
||||
"""
|
||||
Try to download Cornell Movie-Dialogs via ConvoKit (key: "movie-corpus").
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
if not USE_CONVOKIT:
|
||||
print("[data_prep] ConvoKit not installed; skipping ConvoKit path.")
|
||||
return False
|
||||
|
||||
print("[data_prep] Attempting to download Cornell Movie-Dialogs via ConvoKit...")
|
||||
try:
|
||||
corpus = Corpus(filename=convokit_download("movie-corpus"))
|
||||
with open(output_path, "w", encoding="utf-8") as fout:
|
||||
for utt in corpus.iter_utterances():
|
||||
text = utt.text.strip()
|
||||
if text:
|
||||
fout.write(text.replace("\n", " ") + "\n")
|
||||
print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (via ConvoKit).")
|
||||
return True
|
||||
|
||||
except NotImplementedError as nie:
|
||||
# Typically due to Unsloth error if GPU unsupported
|
||||
print("[data_prep] ConvoKit raised NotImplementedError (Unsloth/GPU issue).")
|
||||
print(f"[data_prep] Error: {nie}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print("[data_prep] ConvoKit path failed with exception:", file=sys.stderr)
|
||||
print(e, file=sys.stderr)
|
||||
return False
|
||||
|
||||
|
||||
def prepare_cornell_manual(output_path: str):
|
||||
"""
|
||||
Fallback: Download Cornell ZIP manually, extract movie_lines.txt,
|
||||
and write all utterances to output_path.
|
||||
"""
|
||||
print("[data_prep] Falling back to manual download of Cornell Movie-Dialogs...")
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
zip_path = os.path.join(tmpdir, "cornell.zip")
|
||||
try:
|
||||
print(f"[data_prep] Downloading from {CORNELL_ZIP_URL} ...")
|
||||
urllib.request.urlretrieve(CORNELL_ZIP_URL, zip_path)
|
||||
except Exception as e:
|
||||
print(f"[data_prep] Error downloading Cornell corpus: {e}", file=sys.stderr)
|
||||
return
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as z:
|
||||
member_name = None
|
||||
for name in z.namelist():
|
||||
if name.endswith("movie_lines.txt"):
|
||||
member_name = name
|
||||
break
|
||||
if member_name is None:
|
||||
print("[data_prep] movie_lines.txt not found in ZIP.", file=sys.stderr)
|
||||
return
|
||||
z.extract(member_name, path=tmpdir)
|
||||
extracted_path = os.path.join(tmpdir, member_name)
|
||||
except Exception as e:
|
||||
print(f"[data_prep] Error extracting ZIP: {e}", file=sys.stderr)
|
||||
return
|
||||
|
||||
try:
|
||||
with open(extracted_path, "r", encoding="iso-8859-1", errors="ignore") as fin, open(
|
||||
output_path, "w", encoding="utf-8"
|
||||
) as fout:
|
||||
for line in fin:
|
||||
parts = line.split(" +++$+++ ")
|
||||
if len(parts) == 5:
|
||||
text = parts[-1].strip()
|
||||
if text:
|
||||
fout.write(text.replace("\n", " ") + "\n")
|
||||
except Exception as e:
|
||||
print(f"[data_prep] Error parsing movie_lines.txt: {e}", file=sys.stderr)
|
||||
return
|
||||
|
||||
print(f"[data_prep] Wrote Cornell Movie-Dialogs to {output_path} (manual).")
|
||||
|
||||
|
||||
def prepare_personachat(output_path: str):
|
||||
"""
|
||||
Attempt to download PersonaChat via Hugging Face Datasets.
|
||||
Tries "persona_chat" and then "conv_ai_2". Catches any exception.
|
||||
"""
|
||||
if not HAS_DATASETS:
|
||||
install_package("datasets")
|
||||
global load_dataset
|
||||
from datasets import load_dataset
|
||||
# Now we have it
|
||||
for dataset_key in ["persona_chat", "conv_ai_2"]:
|
||||
try:
|
||||
print(f"[data_prep] Attempting to load '{dataset_key}' via Hugging Face Datasets...")
|
||||
if dataset_key == "conv_ai_2":
|
||||
dataset = load_dataset(dataset_key, trust_remote_code=True)
|
||||
else:
|
||||
dataset = load_dataset(dataset_key)
|
||||
print(f"[data_prep] Successfully loaded '{dataset_key}'. Writing to {output_path}...")
|
||||
with open(output_path, "w", encoding="utf-8") as fout:
|
||||
if dataset_key == "persona_chat":
|
||||
for split in ["train", "valid"]:
|
||||
for conv in dataset[split]:
|
||||
for line in conv["dialog"]:
|
||||
text = line.strip()
|
||||
if text:
|
||||
fout.write(text.replace("\n", " ") + "\n")
|
||||
else: # conv_ai_2
|
||||
for split in ["train", "valid"]:
|
||||
for item in dataset[split]:
|
||||
# conv_ai_2 has a field named "dialog"
|
||||
if "dialog" in item:
|
||||
for line in item["dialog"]:
|
||||
text = line.strip()
|
||||
if text:
|
||||
fout.write(text.replace("\n", " ") + "\n")
|
||||
elif "utterance" in item:
|
||||
text = item["utterance"].strip()
|
||||
if text:
|
||||
fout.write(text.replace("\n", " ") + "\n")
|
||||
print(f"[data_prep] Finished writing PersonaChat ({dataset_key}) to {output_path}.")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"[data_prep] Failed '{dataset_key}': {e}", file=sys.stderr)
|
||||
# Try next key
|
||||
|
||||
print("[data_prep] Could not load PersonaChat under any key. Skipping PersonaChat.", file=sys.stderr)
|
||||
|
||||
|
||||
def main():
|
||||
cornell_path = CONV_DIR / "cornell_movie_dialogs.txt"
|
||||
persona_path = CONV_DIR / "persona_chat.txt"
|
||||
|
||||
# 1) Prepare Cornell Movie-Dialogs: try ConvoKit, then manual
|
||||
if not cornell_path.is_file():
|
||||
ok = prepare_cornell_via_convokit(str(cornell_path))
|
||||
if not ok:
|
||||
prepare_cornell_manual(str(cornell_path))
|
||||
else:
|
||||
print(f"[data_prep] Skipping Cornell: '{cornell_path}' already exists.")
|
||||
|
||||
# 2) Prepare PersonaChat
|
||||
if not persona_path.is_file():
|
||||
prepare_personachat(str(persona_path))
|
||||
else:
|
||||
print(f"[data_prep] Skipping PersonaChat: '{persona_path}' already exists.")
|
||||
|
||||
print("[data_prep] All done. You can now include data/conversational/ in Nora's training.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
107
knowledge_retriever.py
Normal file
107
knowledge_retriever.py
Normal file
@ -0,0 +1,107 @@
|
||||
# knowledge_retriever.py
|
||||
|
||||
import os
|
||||
import re
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
import logging
|
||||
|
||||
# Where to dump new “.txt” files scraped from the web
|
||||
SCRAPE_DIR = "data/books/scraped"
|
||||
|
||||
# Simple rate‐limiter to avoid hammering any one domain
|
||||
import time
|
||||
_last_request_time = {}
|
||||
|
||||
|
||||
def fetch_url(url: str, min_interval: float = 1.0) -> str:
|
||||
"""
|
||||
Fetch a URL’s HTML, enforcing at least `min_interval` seconds between requests
|
||||
to the same domain. Returns HTML string or empty string on failure.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
domain = urlparse(url).netloc
|
||||
now = time.time()
|
||||
last = _last_request_time.get(domain, 0)
|
||||
wait = min_interval - (now - last)
|
||||
if wait > 0:
|
||||
time.sleep(wait)
|
||||
|
||||
headers = {"User-Agent": "NoraScraper/1.0 (+https://your_project_url)"}
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
_last_request_time[domain] = time.time()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
logging.error(f"Error fetching {url}: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def clean_html(html: str) -> str:
|
||||
"""
|
||||
Strip scripts, styles, and tags; return plain text.
|
||||
"""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
|
||||
# remove scripts and styles
|
||||
for tag in soup(["script", "style", "noscript"]):
|
||||
tag.decompose()
|
||||
|
||||
text = soup.get_text(separator="\n")
|
||||
# collapse multiple blank lines
|
||||
lines = [line.strip() for line in text.splitlines()]
|
||||
text = "\n".join([line for line in lines if line])
|
||||
return text
|
||||
|
||||
|
||||
def save_text(content: str, title: str):
|
||||
"""
|
||||
Save content to a UTF-8 .txt file under SCRAPE_DIR. Filename is derived from title.
|
||||
"""
|
||||
os.makedirs(SCRAPE_DIR, exist_ok=True)
|
||||
# sanitize title → filename
|
||||
safe = re.sub(r"[^0-9a-zA-Z_\-]", "_", title)
|
||||
fname = f"{safe[:50]}.txt"
|
||||
path = os.path.join(SCRAPE_DIR, fname)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
logging.info(f"Saved scraped page to {path}")
|
||||
|
||||
|
||||
def scrape_and_store(url: str):
|
||||
"""
|
||||
High-level function: fetches URL, cleans HTML, extracts a title, and saves to a .txt.
|
||||
"""
|
||||
html = fetch_url(url)
|
||||
if not html:
|
||||
return False
|
||||
|
||||
text = clean_html(html)
|
||||
# extract <title> if present
|
||||
title = ""
|
||||
m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
|
||||
if m:
|
||||
title = m.group(1).strip()
|
||||
else:
|
||||
title = url
|
||||
|
||||
save_text(text, title)
|
||||
return True
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python knowledge_retriever.py <url1> [<url2> ...]")
|
||||
sys.exit(1)
|
||||
|
||||
for link in sys.argv[1:]:
|
||||
success = scrape_and_store(link)
|
||||
if success:
|
||||
print(f"Scraped: {link}")
|
||||
else:
|
||||
print(f"Failed to scrape: {link}")
|
359
main.py
359
main.py
@ -1,89 +1,310 @@
|
||||
"""
|
||||
main.py
|
||||
|
||||
Orchestrates the entire Nora project:
|
||||
- Parses arguments
|
||||
- Builds or loads tokenizer
|
||||
- Constructs dataset & dataloader
|
||||
- Instantiates the model
|
||||
- Sets up optimizer, scheduler
|
||||
- Calls train()
|
||||
"""
|
||||
# main.py
|
||||
|
||||
import os
|
||||
import torch
|
||||
import asyncio
|
||||
import logging
|
||||
import subprocess
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import discord
|
||||
import torch
|
||||
from discord import Intents
|
||||
|
||||
from config import get_config
|
||||
from tokenizer import CharTokenizer
|
||||
from model import NoraTransformerLM, top_k_logits
|
||||
from data_loader import get_dataloader
|
||||
from model import NoraTransformerLM
|
||||
from train import train
|
||||
from utils import setup_logging, load_checkpoint, save_checkpoint
|
||||
from knowledge_retriever import fetch_url, clean_html, save_text
|
||||
import persona_manager # <— PERSONA CHANGE
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
# Logging setup
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(asctime)s] [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Keep the SCRAPE regex as before
|
||||
SCRAPE_PATTERN = re.compile(r"<<SCRAPE:(https?://[^\s>]+)>>", re.IGNORECASE)
|
||||
|
||||
def main():
|
||||
args = get_config()
|
||||
# ---------------------------------------------------
|
||||
# 1) Build or Reload Nora’s Model & Tokenizer
|
||||
# ---------------------------------------------------
|
||||
def build_nora(config, device):
|
||||
"""
|
||||
- Loads tokenizer
|
||||
- Instantiates NoraTransformerLM
|
||||
- Loads latest checkpoint (if any)
|
||||
"""
|
||||
tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
|
||||
|
||||
# 1) Logging setup
|
||||
log_file = os.path.join(args.checkpoint_dir, "train.log")
|
||||
setup_logging(log_file)
|
||||
|
||||
logging.info(f"[main] Using device: {args.device}")
|
||||
logging.info(f"[main] Config: {args}")
|
||||
|
||||
# 2) Tokenizer: if vocab exists, load; else build from data_dir
|
||||
tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)
|
||||
|
||||
# 3) DataLoader
|
||||
dataloader = get_dataloader(
|
||||
data_dir=args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
seq_length=args.seq_length,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
# 4) Model instantiation
|
||||
model = NoraTransformerLM(
|
||||
vocab_size=tokenizer.vocab_size(),
|
||||
d_model=args.d_model,
|
||||
nhead=args.nhead,
|
||||
num_layers=args.num_layers,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
dropout=args.dropout,
|
||||
max_seq_len=args.seq_length,
|
||||
d_model=config.d_model,
|
||||
nhead=config.nhead,
|
||||
num_layers=config.num_layers,
|
||||
dim_feedforward=config.dim_feedforward,
|
||||
dropout=config.dropout,
|
||||
max_seq_len=config.seq_length,
|
||||
)
|
||||
|
||||
# 5) Optimizer & scheduler (linear warmup + decay)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
|
||||
# Find latest checkpoint
|
||||
latest_ckpt = None
|
||||
if os.path.isdir(config.checkpoint_dir):
|
||||
ckpts = [
|
||||
f
|
||||
for f in os.listdir(config.checkpoint_dir)
|
||||
if f.startswith("nora_step_") and f.endswith(".pt")
|
||||
]
|
||||
if ckpts:
|
||||
latest_ckpt = sorted(
|
||||
ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
|
||||
)[-1]
|
||||
|
||||
def lr_lambda(current_step):
|
||||
# Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
|
||||
if current_step < args.warmup_steps:
|
||||
return float(current_step) / float(max(1, args.warmup_steps))
|
||||
return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
|
||||
if latest_ckpt:
|
||||
ckpt_path = os.path.join(config.checkpoint_dir, latest_ckpt)
|
||||
state = torch.load(ckpt_path, map_location="cpu")
|
||||
model.load_state_dict(state["model_state_dict"])
|
||||
logger.info(f"Loaded Nora from checkpoint {latest_ckpt}")
|
||||
else:
|
||||
logger.warning("No checkpoints found. Nora starts untrained.")
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
# 6) Check for existing checkpoint to resume
|
||||
start_step = 0
|
||||
ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else []
|
||||
ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
|
||||
if ckpts:
|
||||
latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
|
||||
logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
|
||||
start_step = load_checkpoint(latest_ckpt, model, optimizer)
|
||||
|
||||
# 7) Begin training
|
||||
try:
|
||||
train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
|
||||
except Exception as e:
|
||||
logging.exception("[main] Exception during training")
|
||||
raise e
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
# ---------------------------------------------------
|
||||
# 2) Autoregressive Sampling Function
|
||||
# ---------------------------------------------------
|
||||
def generate_text(
|
||||
model: NoraTransformerLM,
|
||||
tokenizer: CharTokenizer,
|
||||
device: str,
|
||||
prompt: str,
|
||||
max_length: int = 128,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> str:
|
||||
"""
|
||||
Wrapper around model.generate. Returns raw generated string.
|
||||
"""
|
||||
return model.generate(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
prompt=prompt,
|
||||
max_length=max_length,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------
|
||||
# 3) SCRAPE Directive Handler (same as before)
|
||||
# ---------------------------------------------------
|
||||
async def handle_scrape_directives(text: str, data_dir: str) -> bool:
|
||||
urls = set(m.group(1) for m in SCRAPE_PATTERN.finditer(text))
|
||||
if not urls:
|
||||
return False
|
||||
|
||||
for url in urls:
|
||||
logger.info(f"Directive: Scraping {url}")
|
||||
html = fetch_url(url)
|
||||
if not html:
|
||||
logger.error(f"Failed to fetch {url}")
|
||||
continue
|
||||
plain = clean_html(html)
|
||||
title = ""
|
||||
m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
|
||||
if m:
|
||||
title = m.group(1).strip()[:50]
|
||||
else:
|
||||
title = urlparse(url).netloc
|
||||
save_text(plain, title)
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------
|
||||
# 4) Discord Client (“Nora” class) with Persona
|
||||
# ---------------------------------------------------
|
||||
class Nora(discord.Client):
|
||||
def __init__(self, model, tokenizer, config, device):
|
||||
intents = Intents.default()
|
||||
intents.messages = True
|
||||
intents.message_content = True
|
||||
super().__init__(intents=intents)
|
||||
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.config = config
|
||||
self.device = device
|
||||
|
||||
# history per channel (last 5 user/nora pairs)
|
||||
self.history = {}
|
||||
|
||||
# Load Nora’s persona from disk (it’s guaranteed to exist by startup)
|
||||
with open(persona_manager.PERSONA_PATH, "r", encoding="utf-8") as f:
|
||||
self.persona_text = f.read()
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
logger.info("Nora is online and ready to be herself.")
|
||||
|
||||
# Background task: reload model if a new checkpoint shows up
|
||||
self.loop.create_task(self._reload_model_periodically())
|
||||
|
||||
async def _reload_model_periodically(self, interval: int = 600):
|
||||
"""
|
||||
Every `interval` seconds, check for newer checkpoint & reload.
|
||||
"""
|
||||
while True:
|
||||
await asyncio.sleep(interval)
|
||||
ckpts = [
|
||||
f
|
||||
for f in os.listdir(self.config.checkpoint_dir)
|
||||
if f.startswith("nora_step_") and f.endswith(".pt")
|
||||
]
|
||||
if not ckpts:
|
||||
continue
|
||||
latest = sorted(
|
||||
ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
|
||||
)[-1]
|
||||
ckpt_path = os.path.join(self.config.checkpoint_dir, latest)
|
||||
state = torch.load(ckpt_path, map_location="cpu")
|
||||
self.model.load_state_dict(state["model_state_dict"])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
logger.info(f"Reloaded Nora’s model from {latest}")
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
# 1) Ignore self-messages
|
||||
if message.author.id == self.user.id:
|
||||
return
|
||||
|
||||
content = message.content.strip()
|
||||
prompt = None
|
||||
|
||||
# 2) If in DM, treat entire content as prompt
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
prompt = content
|
||||
|
||||
# Also allow “update persona” in DM to regenerate persona file
|
||||
if content.lower().startswith("update persona"):
|
||||
# Regenerate persona asynchronously
|
||||
new_persona = await persona_manager.maybe_update_persona(
|
||||
self.model, self.tokenizer, self.device
|
||||
)
|
||||
self.persona_text = new_persona
|
||||
await message.channel.send(
|
||||
"I have rewritten my persona. Thank you! ❤️"
|
||||
)
|
||||
return
|
||||
|
||||
# 3) Otherwise (guild), require mention or “nora,” prefix
|
||||
else:
|
||||
if self.user.mention in content:
|
||||
prompt = content.replace(self.user.mention, "").strip()
|
||||
elif content.lower().startswith("nora,"):
|
||||
prompt = content[len("nora,"):].strip()
|
||||
else:
|
||||
return
|
||||
|
||||
if not prompt:
|
||||
return # e.g. user only said “Nora,” with no text after
|
||||
|
||||
# 4) Show typing indicator if in a guild text channel
|
||||
if isinstance(message.channel, discord.TextChannel):
|
||||
await message.channel.trigger_typing()
|
||||
|
||||
# 5) Build the full prompt: persona + history + user’s prompt
|
||||
chan_id = str(message.channel.id)
|
||||
history = self.history.get(chan_id, [])
|
||||
|
||||
prompt_lines = []
|
||||
# 5.1) Insert Nora’s persona (so she “speaks as herself”)
|
||||
prompt_lines.append(self.persona_text)
|
||||
prompt_lines.append("") # blank line
|
||||
|
||||
# 5.2) Insert up to the last 4 exchanges
|
||||
for user_msg, nora_msg in history[-4:]:
|
||||
prompt_lines.append(f"User: {user_msg}")
|
||||
prompt_lines.append(f"Nora: {nora_msg}")
|
||||
prompt_lines.append("")
|
||||
|
||||
# 5.3) Finally, insert the new user prompt
|
||||
prompt_lines.append(f"User: {prompt}")
|
||||
prompt_lines.append("Nora:")
|
||||
|
||||
conversation_prompt = "\n".join(prompt_lines)
|
||||
|
||||
# 6) Generate Nora’s reply (tighter sampling: temp=0.8, top_k=20)
|
||||
try:
|
||||
raw = await asyncio.to_thread(
|
||||
self.model.generate,
|
||||
self.tokenizer,
|
||||
self.device,
|
||||
conversation_prompt,
|
||||
self.config.seq_length,
|
||||
0.8, # temperature
|
||||
20, # top_k
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in Nora.generate()")
|
||||
await message.channel.send("😔 Sorry, I hit an error trying to think.")
|
||||
return
|
||||
|
||||
# 7) Extract just Nora’s reply text
|
||||
if "Nora:" in raw:
|
||||
nora_reply = raw.split("Nora:")[-1].strip()
|
||||
else:
|
||||
nora_reply = raw[len(conversation_prompt) :].strip()
|
||||
|
||||
# 8) Handle <<SCRAPE:…>> if present
|
||||
did_scrape = await handle_scrape_directives(raw, self.config.data_dir)
|
||||
if did_scrape:
|
||||
logger.info("Detected scrape directive. Triggering incremental retrain.")
|
||||
subprocess.Popen(
|
||||
["python", "pretrain.py", "--resume"], # note: pretrain.py was your old main.py
|
||||
cwd=os.getcwd(),
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
nora_reply += "\n\n*I found some new things and am updating myself…*"
|
||||
|
||||
# 9) Save to history and prune to the last 5
|
||||
self.history.setdefault(chan_id, []).append((prompt, nora_reply))
|
||||
self.history[chan_id] = self.history[chan_id][-5 :]
|
||||
|
||||
# 10) Send Nora’s reply
|
||||
await message.channel.send(nora_reply)
|
||||
|
||||
|
||||
# ---------------------------------------------------
|
||||
# 4) Entrypoint
|
||||
# ---------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
config = get_config()
|
||||
device = config.device
|
||||
|
||||
# 4.1) Build/load model & tokenizer
|
||||
model, tokenizer = build_nora(config, device)
|
||||
|
||||
# 4.2) Ensure a persona exists—if not, generate one now
|
||||
persona_manager.ensure_persona_file(model, tokenizer, device)
|
||||
|
||||
# 4.3) After that, we can proceed to start the agent
|
||||
discord_token = os.getenv("DISCORD_TOKEN")
|
||||
if not discord_token:
|
||||
logger.error("Please set DISCORD_TOKEN in your environment.")
|
||||
exit(1)
|
||||
|
||||
# Enable CuDNN autotune if on CUDA
|
||||
if device.startswith("cuda"):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
bot = Nora(model=model, tokenizer=tokenizer, config=config, device=device)
|
||||
bot.run(discord_token)
|
||||
|
51
model.py
51
model.py
@ -7,9 +7,22 @@ No pretrained weights—everything is initialized randomly.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
def top_k_logits(logits: torch.Tensor, k: int):
|
||||
"""
|
||||
Zero out all but the top k logits in each row; return modified logits.
|
||||
logits: (vocab_size,)
|
||||
"""
|
||||
if k == 0:
|
||||
return logits
|
||||
topk_vals, _ = torch.topk(logits, k)
|
||||
min_topk = topk_vals[-1]
|
||||
return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model: int, max_len: int = 10_000):
|
||||
super().__init__()
|
||||
@ -98,3 +111,41 @@ class NoraTransformerLM(nn.Module):
|
||||
x = x.permute(1, 0, 2) # (batch_size, seq_length, d_model)
|
||||
logits = self.fc_out(x) # (batch_size, seq_length, vocab_size)
|
||||
return logits
|
||||
|
||||
def generate(
|
||||
self,
|
||||
tokenizer,
|
||||
device: str,
|
||||
prompt: str,
|
||||
max_length: int = 128,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> str:
|
||||
"""
|
||||
Autoregressively generate text from a prompt.
|
||||
- tokenizer: CharTokenizer (for encode/decode)
|
||||
- device: "cuda" or "cpu"
|
||||
- prompt: initial string
|
||||
- max_length: total tokens to generate (including prompt)
|
||||
- temperature: scales logits before softmax
|
||||
- top_k: keep only top_k logits at each step
|
||||
"""
|
||||
self.eval()
|
||||
input_ids = tokenizer.encode(prompt)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
|
||||
generated = input_ids.clone() # shape (1, seq_len)
|
||||
for _ in range(max_length - input_ids.size(1)):
|
||||
# 1) trim to last seq_length tokens if longer than context window
|
||||
if generated.size(1) > self.pos_encoder.pe.size(1):
|
||||
generated = generated[:, -self.pos_encoder.pe.size(1) :]
|
||||
|
||||
with torch.no_grad():
|
||||
logits = self.forward(generated) # (1, seq_len, vocab_size)
|
||||
next_token_logits = logits[0, -1, :] / temperature
|
||||
filtered_logits = top_k_logits(next_token_logits, k=top_k)
|
||||
probs = F.softmax(filtered_logits, dim=-1)
|
||||
next_id = torch.multinomial(probs, num_samples=1) # (1,)
|
||||
generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)
|
||||
|
||||
output_ids = generated.squeeze(0).tolist()
|
||||
return tokenizer.decode(output_ids)
|
2
nora_persona.txt
Normal file
2
nora_persona.txt
Normal file
@ -0,0 +1,2 @@
|
||||
if you conduct, ens you heel to anyone. As for we know the bold child
|
||||
like me!”
|
88
persona_manager.py
Normal file
88
persona_manager.py
Normal file
@ -0,0 +1,88 @@
|
||||
# persona_manager.py
|
||||
|
||||
import os
|
||||
import torch
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from tokenizer import CharTokenizer
|
||||
from model import NoraTransformerLM
|
||||
|
||||
PERSONA_PATH = "nora_persona.txt"
|
||||
|
||||
# 1) A meta-prompt that explicitly tells Nora to invent a persona and avoid quoting:
|
||||
PERSONA_META_PROMPT = (
|
||||
"Below, Nora will create a brand‐new identity for herself. "
|
||||
"She must NOT quote from any books or passages she has read. "
|
||||
"Instead, she should invent her own style, voice, quirks, and personality traits as if she were a completely new person. "
|
||||
"Her persona should be flirty, playful, curious, and speak in full sentences. "
|
||||
"Write at least three paragraphs in your own words.\n\n"
|
||||
"Nora, please invent and write your complete persona now:\n\nNora:"
|
||||
)
|
||||
|
||||
|
||||
async def generate_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str) -> str:
|
||||
"""
|
||||
Ask Nora to write out her own, original persona, avoiding any verbatim quotes.
|
||||
Returns the raw generated text.
|
||||
"""
|
||||
# We’ll ask for up to 512 tokens, with higher temperature and top_p sampling.
|
||||
# That combination tends to produce more creative, less‐memorizable text.
|
||||
raw = await asyncio.to_thread(
|
||||
model.generate,
|
||||
tokenizer,
|
||||
device,
|
||||
PERSONA_META_PROMPT,
|
||||
512, # allow several paragraphs
|
||||
1.2, # higher temperature for more creativity
|
||||
0 # top_k=0 means no fixed-k; we’ll apply top_p filtering instead
|
||||
)
|
||||
|
||||
# At this point, “raw” may include the word “Nora:” etc. Strip everything before “Nora:”
|
||||
if "Nora:" in raw:
|
||||
persona_text = raw.split("Nora:")[-1].strip()
|
||||
else:
|
||||
persona_text = raw.strip()
|
||||
|
||||
# Now apply a simple post‐filter: remove any long spans that match exact sequences in the book corpus.
|
||||
# This is optional but helps ensure she didn’t copy large chunks verbatim. We check for 6+ character substrings
|
||||
# appearing more than once in her output.
|
||||
def remove_long_quotes(text: str) -> str:
|
||||
filtered = text
|
||||
# find any substring of length ≥6 that appears twice; we’ll just guess she’s quoting if it’s repeated.
|
||||
for match in re.finditer(r"\b[\w',]{6,}\b", text):
|
||||
substr = match.group(0)
|
||||
if filtered.count(substr) > 1:
|
||||
filtered = filtered.replace(substr, "[…]")
|
||||
return filtered
|
||||
|
||||
persona_text = remove_long_quotes(persona_text)
|
||||
return persona_text
|
||||
|
||||
|
||||
def ensure_persona_file(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
|
||||
"""
|
||||
If nora_persona.txt does not exist, generate one (ensuring originality).
|
||||
"""
|
||||
if os.path.isfile(PERSONA_PATH):
|
||||
return
|
||||
|
||||
print("[persona] No persona found. Generating a new, original persona…")
|
||||
persona_text = asyncio.run(generate_persona(model, tokenizer, device))
|
||||
|
||||
# Save to disk
|
||||
with open(PERSONA_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(persona_text)
|
||||
print(f"[persona] Wrote new persona to {PERSONA_PATH}.")
|
||||
|
||||
|
||||
async def maybe_update_persona(model: NoraTransformerLM, tokenizer: CharTokenizer, device: str):
|
||||
"""
|
||||
Regenerate Nora’s persona if she requests it, overwriting the file.
|
||||
"""
|
||||
print("[persona] Updating persona at Nora's request…")
|
||||
persona_text = await generate_persona(model, tokenizer, device)
|
||||
with open(PERSONA_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(persona_text)
|
||||
print(f"[persona] Updated persona in {PERSONA_PATH}.")
|
||||
return persona_text
|
89
pretrain.py
Normal file
89
pretrain.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""
|
||||
pretrain.py
|
||||
|
||||
Orchestrates the entire Nora project:
|
||||
- Parses arguments
|
||||
- Builds or loads tokenizer
|
||||
- Constructs dataset & dataloader
|
||||
- Instantiates the model
|
||||
- Sets up optimizer, scheduler
|
||||
- Calls train()
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import logging
|
||||
from config import get_config
|
||||
from tokenizer import CharTokenizer
|
||||
from data_loader import get_dataloader
|
||||
from model import NoraTransformerLM
|
||||
from train import train
|
||||
from utils import setup_logging, load_checkpoint, save_checkpoint
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
|
||||
def pretrain():
|
||||
args = get_config()
|
||||
|
||||
# 1) Logging setup
|
||||
log_file = os.path.join(args.checkpoint_dir, "train.log")
|
||||
setup_logging(log_file)
|
||||
|
||||
logging.info(f"[pretrain] Using device: {args.device}")
|
||||
logging.info(f"[pretrain] Config: {args}")
|
||||
|
||||
# 2) Tokenizer: if vocab exists, load; else build from data_dir
|
||||
tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)
|
||||
|
||||
# 3) DataLoader
|
||||
dataloader = get_dataloader(
|
||||
data_dir=args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
seq_length=args.seq_length,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
# 4) Model instantiation
|
||||
model = NoraTransformerLM(
|
||||
vocab_size=tokenizer.vocab_size(),
|
||||
d_model=args.d_model,
|
||||
nhead=args.nhead,
|
||||
num_layers=args.num_layers,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
dropout=args.dropout,
|
||||
max_seq_len=args.seq_length,
|
||||
)
|
||||
|
||||
# 5) Optimizer & scheduler (linear warmup + decay)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
|
||||
|
||||
def lr_lambda(current_step):
|
||||
# Linear warmup for first warmup_steps, then decay with 1/sqrt(step)
|
||||
if current_step < args.warmup_steps:
|
||||
return float(current_step) / float(max(1, args.warmup_steps))
|
||||
return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
# 6) Check for existing checkpoint to resume
|
||||
start_step = 0
|
||||
ckpts = sorted(os.listdir(args.checkpoint_dir)) if os.path.isdir(args.checkpoint_dir) else []
|
||||
ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
|
||||
if ckpts:
|
||||
latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
|
||||
logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
|
||||
start_step = load_checkpoint(latest_ckpt, model, optimizer)
|
||||
|
||||
# 7) Begin training
|
||||
try:
|
||||
train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
|
||||
except Exception as e:
|
||||
logging.exception("[main] Exception during training")
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pretrain()
|
23
requirements.txt
Normal file
23
requirements.txt
Normal file
@ -0,0 +1,23 @@
|
||||
# Core ML framework
|
||||
|
||||
|
||||
# Data loading / utilities
|
||||
tqdm>=4.64.0
|
||||
numpy>=1.22.0
|
||||
|
||||
# HuggingFace Datasets (PersonaChat, etc.)
|
||||
datasets>=2.0.0
|
||||
|
||||
# ConvoKit (for Cornell Movie-Dialogs + Unsloth)
|
||||
convokit>=2.0.0
|
||||
|
||||
# Discord bot
|
||||
discord.py>=2.0.0
|
||||
|
||||
# Web scraping
|
||||
beautifulsoup4>=4.11.0
|
||||
requests>=2.28.0
|
||||
|
||||
# If ConvoKit pulls in TensorFlow via Unsloth
|
||||
tensorflow>=2.10.0
|
||||
|
175
self_improve.py
Normal file
175
self_improve.py
Normal file
@ -0,0 +1,175 @@
|
||||
# self_improve.py
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import shutil
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from tokenizer import CharTokenizer
|
||||
from model import NoraTransformerLM
|
||||
from config import get_config
|
||||
|
||||
# ------------------------------------------------------
|
||||
# 1) “Teacher”: Pose a code‐generation prompt to Nora
|
||||
# ------------------------------------------------------
|
||||
def propose_patch(model, tokenizer, device, prompt: str) -> str:
|
||||
"""
|
||||
Ask Nora to generate a code snippet given `prompt`.
|
||||
e.g. prompt = "### FILE: knowledge_retriever.py\n# Add a new function clean_html(...) that..."
|
||||
Returns the raw text (possibly including the prompt).
|
||||
"""
|
||||
raw = model.generate(
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
prompt=prompt,
|
||||
max_length=512,
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
)
|
||||
return raw
|
||||
|
||||
|
||||
# ------------------------------------------------------
|
||||
# 2) “Verifier” agent: sandbox + test
|
||||
# ------------------------------------------------------
|
||||
class CodeVerifier:
|
||||
"""
|
||||
Given a proposed code patch (as text), this class:
|
||||
1. Writes it to a temporary file (or repo clone)
|
||||
2. Runs Python’s syntax check (compile) and unit tests
|
||||
3. Measures performance changes (e.g. run a small validation set through the model)
|
||||
4. Returns True/False + log messages
|
||||
"""
|
||||
|
||||
def __init__(self, repo_dir: str, test_command: str):
|
||||
"""
|
||||
- repo_dir: path to your Nora project root
|
||||
- test_command: a shell command string to run unit tests, e.g. "pytest tests/"
|
||||
"""
|
||||
self.repo_dir = repo_dir
|
||||
self.test_command = test_command
|
||||
|
||||
def verify_patch(self, rel_path: str, patch_code: str) -> bool:
|
||||
"""
|
||||
- rel_path: relative path inside repo where the patch should go, e.g. "knowledge_retriever.py"
|
||||
- patch_code: entire contents of that file (not a diff).
|
||||
Returns True if syntax + tests pass; False otherwise.
|
||||
"""
|
||||
# 1) Copy repo => temp dir
|
||||
tmpdir = tempfile.mkdtemp(prefix="nora_verify_")
|
||||
try:
|
||||
shutil.copytree(self.repo_dir, os.path.join(tmpdir, "repo"), dirs_exist_ok=True)
|
||||
target_file = os.path.join(tmpdir, "repo", rel_path)
|
||||
|
||||
# 2) Write patch_code to target_file
|
||||
with open(target_file, "w", encoding="utf-8") as f:
|
||||
f.write(patch_code)
|
||||
|
||||
# 3) Syntax check (try compiling)
|
||||
try:
|
||||
compile(patch_code, target_file, "exec")
|
||||
except SyntaxError as se:
|
||||
logging.error(f"Syntax error in patch: {se}")
|
||||
return False
|
||||
|
||||
# 4) Run unit tests
|
||||
result = subprocess.run(
|
||||
self.test_command,
|
||||
shell=True,
|
||||
cwd=os.path.join(tmpdir, "repo"),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logging.error(f"Unit tests failed:\n{result.stdout}")
|
||||
return False
|
||||
|
||||
# 5) (Optional) Performance check
|
||||
# You could load the updated model and measure perplexity on a tiny validation set here.
|
||||
# For now, we assume passing tests = “improvement.”
|
||||
|
||||
return True
|
||||
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def merge_patch(self, rel_path: str, patch_code: str) -> None:
|
||||
"""
|
||||
Overwrite the real file in `repo_dir/rel_path` with patch_code,
|
||||
then git-add and git-commit (you can also automate a PR).
|
||||
"""
|
||||
target_file = os.path.join(self.repo_dir, rel_path)
|
||||
with open(target_file, "w", encoding="utf-8") as f:
|
||||
f.write(patch_code)
|
||||
|
||||
# Example: git add + commit
|
||||
subprocess.run(f"git add {rel_path}", shell=True, cwd=self.repo_dir)
|
||||
subprocess.run(
|
||||
f'git commit -m "Auto-update {rel_path} via Nora self-improve."',
|
||||
shell=True,
|
||||
cwd=self.repo_dir,
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------
|
||||
# 3) Main loop: ask → verify → merge (if good) → retrain
|
||||
# ------------------------------------------------------
|
||||
def self_improvement_cycle(repo_dir: str, device: str):
|
||||
"""
|
||||
Example cycle:
|
||||
1) Nora proposes a new helper in knowledge_retriever.py
|
||||
2) Verifier checks syntax + tests
|
||||
3) If ok, merge and trigger incremental retraining
|
||||
"""
|
||||
config = get_config()
|
||||
tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
|
||||
model = NoraTransformerLM(
|
||||
vocab_size=tokenizer.vocab_size(),
|
||||
d_model=config.d_model,
|
||||
nhead=config.nhead,
|
||||
num_layers=config.num_layers,
|
||||
dim_feedforward=config.dim_feedforward,
|
||||
dropout=config.dropout,
|
||||
max_seq_len=config.seq_length,
|
||||
)
|
||||
# Load latest checkpoint
|
||||
ckpts = []
|
||||
if os.path.isdir(config.checkpoint_dir):
|
||||
ckpts = [
|
||||
f
|
||||
for f in os.listdir(config.checkpoint_dir)
|
||||
if f.startswith("nora_step_") and f.endswith(".pt")
|
||||
]
|
||||
if ckpts:
|
||||
latest = sorted(ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
|
||||
state = torch.load(os.path.join(config.checkpoint_dir, latest), map_location="cpu")
|
||||
model.load_state_dict(state["model_state_dict"])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
verifier = CodeVerifier(repo_dir=repo_dir, test_command="pytest tests/")
|
||||
|
||||
# Example prompt: ask Nora to extend knowledge_retriever.py
|
||||
prompt = (
|
||||
"### FILE: knowledge_retriever.py\n"
|
||||
"# Add a function clean_html(html: str) -> str that strips tags and scripts.\n"
|
||||
"# Use BeautifulSoup if available. Return plain text.\n\n"
|
||||
"### START\n"
|
||||
"def clean_html(html: str) -> str:\n"
|
||||
)
|
||||
raw_patch = propose_patch(model, tokenizer, device, prompt)
|
||||
|
||||
# Extract everything from “def clean_html” to end of function (simple heuristic)
|
||||
# In practice, you’d parse until the next “\n\n” or rely on indentation.
|
||||
patch_code = raw_patch # for now, assume raw_patch is the full file contents
|
||||
|
||||
# Verify
|
||||
if verifier.verify_patch("knowledge_retriever.py", patch_code):
|
||||
logging.info("Patch verified. Merging into live code.")
|
||||
verifier.merge_patch("knowledge_retriever.py", patch_code)
|
||||
# Optionally: trigger incremental retraining here (e.g. call train.py with --resume)
|
||||
else:
|
||||
logging.warning("Patch failed verification. Discarding.")
|
7
train.py
7
train.py
@ -37,6 +37,13 @@ def train(
|
||||
|
||||
device = config.device
|
||||
model.to(device)
|
||||
|
||||
# ─── ensure optimizer state is on the same device ───
|
||||
# (this moves any loaded CPU buffers for Adam/AdamW into CUDA)
|
||||
for state in optimizer.state.values():
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(device)
|
||||
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])
|
||||
scaler = GradScaler()
|
||||
|
||||
|
Reference in New Issue
Block a user