Added another learning source for Nora. Also added the requirements.

This commit is contained in:
2025-06-09 14:25:11 -04:00
parent da23742671
commit 5d53ba7cb8
14 changed files with 1070 additions and 78 deletions

359
main.py
View File

@ -1,89 +1,310 @@
"""
main.py
Orchestrates the entire Nora project:
- Parses arguments
- Builds or loads tokenizer
- Constructs dataset & dataloader
- Instantiates the model
- Sets up optimizer, scheduler
- Calls train()
"""
# main.py
import os
import torch
import asyncio
import logging
import subprocess
import re
from urllib.parse import urlparse
import discord
import torch
from discord import Intents
from config import get_config
from tokenizer import CharTokenizer
from model import NoraTransformerLM, top_k_logits
from data_loader import get_dataloader
from model import NoraTransformerLM
from train import train
from utils import setup_logging, load_checkpoint, save_checkpoint
from knowledge_retriever import fetch_url, clean_html, save_text
import persona_manager # <— PERSONA CHANGE
# Enable cuDNN and its benchmark mode at import time. The __main__ section
# sets the same two flags again when the configured device is CUDA.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
# Logging setup: root logger at INFO, records formatted as
# "[YYYY-MM-DD HH:MM:SS] [LEVEL] message".
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger used by the functions and the Discord client below.
logger = logging.getLogger(__name__)
# Directive pattern: matches "<<SCRAPE:http(s)://...>>" case-insensitively.
# Group 1 captures the URL, which runs up to the first whitespace or ">".
SCRAPE_PATTERN = re.compile(r"<<SCRAPE:(https?://[^\s>]+)>>", re.IGNORECASE)
def _checkpoint_step(name):
    """Return the step number encoded in a ``nora_step_<N>.pt`` file name.

    Returns ``None`` for any name that does not start with ``nora_step_``,
    does not end with ``.pt``, or whose step part is not an integer.
    """
    if not (name.startswith("nora_step_") and name.endswith(".pt")):
        return None
    try:
        return int(name.split("_")[-1].split(".")[0])
    except ValueError:
        return None


def _latest_checkpoint_name(checkpoint_dir):
    """Return the file name of the highest-step checkpoint, or ``None``.

    Checkpoints are ranked by their numeric step, so ``nora_step_1000.pt``
    wins over ``nora_step_999.pt`` (a plain string sort picks the wrong one).
    ``None`` is returned when the directory does not exist or contains no
    usable checkpoint file.
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    steps = {name: _checkpoint_step(name) for name in os.listdir(checkpoint_dir)}
    numbered = [name for name, step in steps.items() if step is not None]
    return max(numbered, key=steps.__getitem__, default=None)


def _warmup_inverse_sqrt(current_step, warmup_steps):
    """Return the learning-rate multiplier for *current_step*.

    Linear warmup over the first ``warmup_steps`` steps, then decay
    proportional to ``1 / sqrt(step)``.  ``warmup_steps`` is clamped to at
    least 1 so that step 0 with ``warmup_steps == 0`` no longer evaluates
    ``0 ** -0.5`` (a ZeroDivisionError).
    """
    warmup = max(1, warmup_steps)
    if current_step < warmup:
        return float(current_step) / float(warmup)
    return (warmup ** 0.5) * float(current_step ** -0.5)


def main():
    """Training entry point: build everything from the config and call train().

    Steps: read the config, set up file logging, build/load the tokenizer,
    build the dataloader, instantiate the model, create the AdamW optimizer
    and warmup/decay scheduler, resume from the newest checkpoint if one
    exists, then run ``train``.

    Raises:
        Exception: anything raised by ``train`` is logged and re-raised.
    """
    args = get_config()

    # 1) Logging setup
    log_file = os.path.join(args.checkpoint_dir, "train.log")
    setup_logging(log_file)
    logging.info(f"[main] Using device: {args.device}")
    logging.info(f"[main] Config: {args}")

    # 2) Tokenizer: if vocab exists, load; else build from data_dir
    tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)

    # 3) DataLoader
    dataloader = get_dataloader(
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # 4) Model instantiation
    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=args.d_model,
        nhead=args.nhead,
        num_layers=args.num_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        max_seq_len=args.seq_length,
    )

    # 5) Optimizer & scheduler (linear warmup + 1/sqrt(step) decay)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9
    )

    def lr_lambda(current_step):
        return _warmup_inverse_sqrt(current_step, args.warmup_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # 6) Check for an existing checkpoint to resume from.  The value returned
    #    by load_checkpoint is used as the step to continue at.
    start_step = 0
    latest_name = _latest_checkpoint_name(args.checkpoint_dir)
    if latest_name is not None:
        latest_ckpt = os.path.join(args.checkpoint_dir, latest_name)
        logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
        start_step = load_checkpoint(latest_ckpt, model, optimizer)

    # 7) Begin training
    try:
        train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
    except Exception:
        logging.exception("[main] Exception during training")
        raise  # bare raise keeps the original traceback intact


# ---------------------------------------------------
# 1) Build or Reload Nora's Model & Tokenizer
# ---------------------------------------------------
def build_nora(config, device):
    """Build Nora's tokenizer and model for inference.

    - Loads (or builds) the tokenizer from ``config.vocab_path`` / ``config.data_dir``
    - Instantiates NoraTransformerLM from the config's hyperparameters
    - Loads the weights of the highest-step checkpoint in
      ``config.checkpoint_dir`` if there is one; otherwise logs a warning and
      leaves the model untrained
    - Moves the model to ``device`` and switches it to eval mode

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)

    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=config.d_model,
        nhead=config.nhead,
        num_layers=config.num_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
        max_seq_len=config.seq_length,
    )

    # Find the latest checkpoint (by numeric step) and load its weights.
    latest_ckpt = _latest_checkpoint_name(config.checkpoint_dir)
    if latest_ckpt:
        ckpt_path = os.path.join(config.checkpoint_dir, latest_ckpt)
        # Load onto CPU first; the model is moved to the target device below.
        state = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
        logger.info(f"Loaded Nora from checkpoint {latest_ckpt}")
    else:
        logger.warning("No checkpoints found. Nora starts untrained.")

    model.to(device)
    model.eval()
    return model, tokenizer
# ---------------------------------------------------
# 2) Autoregressive Sampling Function
# ---------------------------------------------------
def generate_text(
    model: NoraTransformerLM,
    tokenizer: CharTokenizer,
    device: str,
    prompt: str,
    max_length: int = 128,
    temperature: float = 1.0,
    top_k: int = 50,
) -> str:
    """Generate text from *prompt* by delegating to ``model.generate``.

    Every argument is forwarded by keyword, unchanged; whatever
    ``model.generate`` returns is handed back as-is (the raw generated
    string, per the original contract of this wrapper).
    """
    sampling_options = {
        "max_length": max_length,
        "temperature": temperature,
        "top_k": top_k,
    }
    generated = model.generate(
        tokenizer=tokenizer,
        device=device,
        prompt=prompt,
        **sampling_options,
    )
    return generated
# ---------------------------------------------------
# 3) SCRAPE Directive Handler (same as before)
# ---------------------------------------------------
async def handle_scrape_directives(text: str, data_dir: str) -> bool:
    """Fetch and store every page named by a ``<<SCRAPE:url>>`` directive in *text*.

    Each distinct URL is fetched once, converted to plain text with
    ``clean_html`` and stored with ``save_text`` under a title taken from the
    page's ``<title>`` tag (first 50 characters) or, when that is missing or
    blank, the URL's host name.

    Args:
        text: Text to scan for directives.
        data_dir: Accepted for interface compatibility; not used here
            (``save_text`` is called without it).

    Returns:
        True if at least one page was fetched and saved, so callers only
        react (e.g. retrain) when new data actually exists. False when there
        were no directives or every fetch failed.
    """
    # Sorted so the processing order is deterministic across runs.
    urls = sorted({m.group(1) for m in SCRAPE_PATTERN.finditer(text)})
    if not urls:
        return False

    saved_any = False
    for url in urls:
        logger.info(f"Directive: Scraping {url}")
        # fetch_url is a plain (synchronous) call; run it in a worker thread
        # so the event loop keeps serving Discord while the request is in flight.
        html = await asyncio.to_thread(fetch_url, url)
        if not html:
            logger.error(f"Failed to fetch {url}")
            continue

        plain = clean_html(html)
        m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
        title = m.group(1).strip()[:50] if m else ""
        if not title:
            # No <title>, or an empty one: fall back to the host name.
            title = urlparse(url).netloc
        save_text(plain, title)
        saved_any = True

    return saved_any
# ---------------------------------------------------
# 4) Discord Client (“Nora” class) with Persona
# ---------------------------------------------------
class Nora(discord.Client):
    """Discord client that answers messages in Nora's voice.

    For each message addressed to her, the client builds a prompt from the
    persona text, the recent per-channel history and the new message, samples
    a reply from the language model in a worker thread, handles any
    ``<<SCRAPE:url>>`` directive in the newly generated text, and sends the
    reply back to the channel.
    """

    # Discord rejects messages longer than this many characters.
    MAX_MESSAGE_LENGTH = 2000

    def __init__(self, model, tokenizer, config, device):
        """Store the model handles and load the persona text.

        The persona file at ``persona_manager.PERSONA_PATH`` must already
        exist (the startup code creates it before constructing the client).
        """
        intents = Intents.default()
        intents.messages = True
        intents.message_content = True
        super().__init__(intents=intents)

        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = device

        # str(channel id) -> list of (user prompt, Nora reply); last 5 kept.
        self.history = {}

        # Background reload task handle and the checkpoint it last loaded.
        self._reload_task = None
        self._loaded_checkpoint = None

        # Load Nora's persona from disk (it's guaranteed to exist by startup)
        with open(persona_manager.PERSONA_PATH, "r", encoding="utf-8") as f:
            self.persona_text = f.read()

    async def on_ready(self):
        """Log the login and start the checkpoint-reload task (once)."""
        logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
        logger.info("Nora is online and ready to be herself.")

        # on_ready may be dispatched again after a reconnect; without this
        # guard every reconnect would add another reload loop.
        if self._reload_task is None or self._reload_task.done():
            self._reload_task = self.loop.create_task(
                self._reload_model_periodically()
            )

    def _latest_checkpoint_name(self):
        """Return the highest-step ``nora_step_<N>.pt`` file name, or None."""
        ckpt_dir = self.config.checkpoint_dir
        if not os.path.isdir(ckpt_dir):
            return None
        steps = {}
        for name in os.listdir(ckpt_dir):
            if not (name.startswith("nora_step_") and name.endswith(".pt")):
                continue
            try:
                steps[name] = int(name.split("_")[-1].split(".")[0])
            except ValueError:
                continue  # e.g. "nora_step_final.pt": no numeric step
        return max(steps, key=steps.__getitem__, default=None)

    async def _reload_model_periodically(self, interval: int = 600):
        """
        Every `interval` seconds, check for a newer checkpoint & reload it.

        A checkpoint that is already loaded is skipped, and any failure
        (missing directory, half-written file, ...) is logged and retried on
        the next cycle instead of silently ending the task.
        """
        while True:
            await asyncio.sleep(interval)
            try:
                latest = self._latest_checkpoint_name()
                if latest is None or latest == self._loaded_checkpoint:
                    continue

                ckpt_path = os.path.join(self.config.checkpoint_dir, latest)
                # Deserialize off the event loop so the bot stays responsive.
                state = await asyncio.to_thread(
                    torch.load, ckpt_path, map_location="cpu"
                )
                # NOTE(review): a generation running in a worker thread can
                # overlap with this weight swap; confirm that is acceptable.
                self.model.load_state_dict(state["model_state_dict"])
                self.model.to(self.device)
                self.model.eval()
                self._loaded_checkpoint = latest
                logger.info(f"Reloaded Noras model from {latest}")
            except Exception:
                logger.exception("Checkpoint reload failed; will retry next cycle.")

    def _build_conversation_prompt(self, chan_id, prompt):
        """Assemble persona + up to the last 4 exchanges + the new user prompt."""
        lines = [self.persona_text, ""]
        for user_msg, nora_msg in self.history.get(chan_id, [])[-4:]:
            lines.append(f"User: {user_msg}")
            lines.append(f"Nora: {nora_msg}")
            lines.append("")
        lines.append(f"User: {prompt}")
        lines.append("Nora:")
        return "\n".join(lines)

    async def _generate(self, conversation_prompt):
        """Sample a reply in a worker thread (temperature 0.8, top_k 20)."""
        return await asyncio.to_thread(
            self.model.generate,
            self.tokenizer,
            self.device,
            conversation_prompt,
            self.config.seq_length,
            0.8,  # temperature
            20,  # top_k
        )

    async def on_message(self, message: discord.Message):
        """Reply to DMs, mentions, and messages starting with "nora,"."""
        # 1) Ignore self-messages
        if message.author.id == self.user.id:
            return

        content = message.content.strip()

        # 2) In a DM the whole message is the prompt, except the
        #    "update persona" command, which regenerates the persona file.
        if isinstance(message.channel, discord.DMChannel):
            if content.lower().startswith("update persona"):
                new_persona = await persona_manager.maybe_update_persona(
                    self.model, self.tokenizer, self.device
                )
                self.persona_text = new_persona
                await message.channel.send(
                    "I have rewritten my persona. Thank you! ❤️"
                )
                return
            prompt = content
        # 3) Anywhere else, require a mention or a "nora," prefix
        elif self.user.mention in content:
            prompt = content.replace(self.user.mention, "").strip()
        elif content.lower().startswith("nora,"):
            prompt = content[len("nora,"):].strip()
        else:
            return

        if not prompt:
            return  # e.g. user only said "Nora," with no text after

        # 4) Build the full prompt: persona + history + user's prompt
        chan_id = str(message.channel.id)
        conversation_prompt = self._build_conversation_prompt(chan_id, prompt)

        # 5) Generate Nora's reply, showing the typing indicator in guild text
        #    channels while the model runs.  `typing()` is used as an async
        #    context manager because `trigger_typing()` no longer exists in
        #    discord.py 2.x (which `Intents.message_content` requires).
        try:
            if isinstance(message.channel, discord.TextChannel):
                async with message.channel.typing():
                    raw = await self._generate(conversation_prompt)
            else:
                raw = await self._generate(conversation_prompt)
        except Exception:
            logger.exception("Error in Nora.generate()")
            await message.channel.send("😔 Sorry, I hit an error trying to think.")
            return

        # 6) Extract just Nora's reply text
        if "Nora:" in raw:
            nora_reply = raw.split("Nora:")[-1].strip()
        else:
            nora_reply = raw[len(conversation_prompt):].strip()

        # 7) Handle <<SCRAPE:…>> directives, but only in text produced by this
        #    call.  Scanning the whole output would also match directives
        #    sitting in the prompt (user text and earlier replies kept in the
        #    history), re-scraping and re-training on every later message.
        if raw.startswith(conversation_prompt):
            generated = raw[len(conversation_prompt):]
        else:
            generated = raw
        did_scrape = await handle_scrape_directives(generated, self.config.data_dir)
        if did_scrape:
            logger.info("Detected scrape directive. Triggering incremental retrain.")
            subprocess.Popen(
                ["python", "pretrain.py", "--resume"],  # note: pretrain.py was your old main.py
                cwd=os.getcwd(),
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            nora_reply += "\n\n*I found some new things and am updating myself…*"

        # 8) Discord refuses empty messages; nothing useful to send or store.
        if not nora_reply:
            logger.warning("Generated reply was empty; nothing sent.")
            return

        # 9) Save to history and prune to the last 5
        self.history.setdefault(chan_id, []).append((prompt, nora_reply))
        self.history[chan_id] = self.history[chan_id][-5:]

        # 10) Send Nora's reply, clipped to Discord's message length limit
        await message.channel.send(nora_reply[: self.MAX_MESSAGE_LENGTH])
# ---------------------------------------------------
# 5) Entrypoint
# ---------------------------------------------------
if __name__ == "__main__":
    # NOTE(review): this call to main() (the training routine defined earlier
    # in this file) runs to completion before the bot starts. It looks like
    # a leftover from the previous revision of this script — confirm whether
    # training should still run here.
    main()

    config = get_config()
    device = config.device

    # Build/load model & tokenizer
    model, tokenizer = build_nora(config, device)

    # Ensure a persona exists — if not, generate one now
    persona_manager.ensure_persona_file(model, tokenizer, device)

    # After that, we can proceed to start the agent
    discord_token = os.getenv("DISCORD_TOKEN")
    if not discord_token:
        logger.error("Please set DISCORD_TOKEN in your environment.")
        # SystemExit instead of exit(): `exit` is injected by the `site`
        # module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)

    # Enable CuDNN autotune if on CUDA (str() so a torch.device works too)
    if str(device).startswith("cuda"):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

    bot = Nora(model=model, tokenizer=tokenizer, config=config, device=device)
    bot.run(discord_token)