# main.py
import os
import asyncio
import logging
import subprocess
import re
from urllib.parse import urlparse
import discord
import torch
from discord import Intents
from config import get_config
from tokenizer import CharTokenizer
from model import NoraTransformerLM, top_k_logits
from data_loader import get_dataloader
from utils import setup_logging, load_checkpoint, save_checkpoint
from knowledge_retriever import fetch_url, clean_html, save_text
import persona_manager  # <-- PERSONA CHANGE

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Keep the SCRAPE regex as before
SCRAPE_PATTERN = re.compile(r"<<SCRAPE:(https?://[^\s>]+)>>", re.IGNORECASE)
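
# Illustrative example of the directive this pattern matches (the URL is made up):
#   "Let me look that up. <<SCRAPE:https://example.com/some-article>>"
# handle_scrape_directives() below extracts each such URL from a generated reply
# and saves the cleaned page text for later training.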

# ---------------------------------------------------
# 1) Build or Reload Nora's Model & Tokenizer
# ---------------------------------------------------
def build_nora(config, device):
    """
    - Loads the tokenizer
    - Instantiates NoraTransformerLM
    - Loads the latest checkpoint (if any)
    """
    tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=config.d_model,
        nhead=config.nhead,
        num_layers=config.num_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
        max_seq_len=config.seq_length,
    )

    # Find the latest checkpoint
    latest_ckpt = None
    if os.path.isdir(config.checkpoint_dir):
        ckpts = [
            f
            for f in os.listdir(config.checkpoint_dir)
            if f.startswith("nora_step_") and f.endswith(".pt")
        ]
        if ckpts:
            latest_ckpt = sorted(
                ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]

    if latest_ckpt:
        ckpt_path = os.path.join(config.checkpoint_dir, latest_ckpt)
        state = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
        logger.info(f"Loaded Nora from checkpoint {latest_ckpt}")
    else:
        logger.warning("No checkpoints found. Nora starts untrained.")

    model.to(device)
    model.eval()
    return model, tokenizer
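
# Checkpoint naming assumed by build_nora() and the reload loop below: files in
# config.checkpoint_dir named like "nora_step_1000.pt", "nora_step_2500.pt", ...,
# where the highest step number is treated as the latest. The step values here
# are only illustrative.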

# ---------------------------------------------------
# 2) Autoregressive Sampling Function
# ---------------------------------------------------
def generate_text(
    model: NoraTransformerLM,
    tokenizer: CharTokenizer,
    device: str,
    prompt: str,
    max_length: int = 128,
    temperature: float = 1.0,
    top_k: int = 50,
) -> str:
    """
    Wrapper around model.generate. Returns the raw generated string.
    """
    return model.generate(
        tokenizer=tokenizer,
        device=device,
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
    )
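
# Example usage sketch (assumes a model/tokenizer built by build_nora(); the
# prompt and sampling values are illustrative, not canonical settings):
#   model, tokenizer = build_nora(config, "cpu")
#   print(generate_text(model, tokenizer, "cpu", prompt="Hello, Nora.",
#                       max_length=64, temperature=0.8, top_k=20))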

# ---------------------------------------------------
# 3) SCRAPE Directive Handler (same as before)
# ---------------------------------------------------
async def handle_scrape_directives(text: str, data_dir: str) -> bool:
    urls = set(m.group(1) for m in SCRAPE_PATTERN.finditer(text))
    if not urls:
        return False

    for url in urls:
        logger.info(f"Directive: Scraping {url}")
        html = fetch_url(url)
        if not html:
            logger.error(f"Failed to fetch {url}")
            continue
        plain = clean_html(html)
        m = re.search(r"<title>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
        if m:
            title = m.group(1).strip()[:50]
        else:
            title = urlparse(url).netloc
        save_text(plain, title)
    return True
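
# Standalone example (run outside the Discord event loop; URL and data dir are
# illustrative only):
#   asyncio.run(handle_scrape_directives("<<SCRAPE:https://example.com/page>>", "data/"))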

# ---------------------------------------------------
# 4) Discord Client ("Nora" class) with Persona
# ---------------------------------------------------
class Nora(discord.Client):
    def __init__(self, model, tokenizer, config, device):
        intents = Intents.default()
        intents.messages = True
        intents.message_content = True
        super().__init__(intents=intents)
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = device
        # history per channel (last 5 user/Nora pairs)
        self.history = {}
        # Load Nora's persona from disk (ensure_persona_file() in the entrypoint
        # guarantees it exists before the client is constructed)
        with open(persona_manager.PERSONA_PATH, "r", encoding="utf-8") as f:
            self.persona_text = f.read()

    async def on_ready(self):
        logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
        logger.info("Nora is online and ready to be herself.")
        # Background task: reload the model if a new checkpoint shows up
        self.loop.create_task(self._reload_model_periodically())

    async def _reload_model_periodically(self, interval: int = 600):
        """
        Every `interval` seconds, check for a newer checkpoint and reload it.
        """
        last_loaded = None
        while True:
            await asyncio.sleep(interval)
            if not os.path.isdir(self.config.checkpoint_dir):
                continue
            ckpts = [
                f
                for f in os.listdir(self.config.checkpoint_dir)
                if f.startswith("nora_step_") and f.endswith(".pt")
            ]
            if not ckpts:
                continue
            latest = sorted(
                ckpts, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            if latest == last_loaded:
                continue  # no new checkpoint since the last reload
            ckpt_path = os.path.join(self.config.checkpoint_dir, latest)
            state = torch.load(ckpt_path, map_location="cpu")
            self.model.load_state_dict(state["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
            last_loaded = latest
            logger.info(f"Reloaded Nora's model from {latest}")

    async def on_message(self, message: discord.Message):
        # 1) Ignore self-messages
        if message.author.id == self.user.id:
            return

        content = message.content.strip()
        prompt = None

        # 2) If in a DM, treat the entire content as the prompt
        if isinstance(message.channel, discord.DMChannel):
            prompt = content
            # Also allow "update persona" in a DM to regenerate the persona file
            if content.lower().startswith("update persona"):
                # Regenerate the persona asynchronously
                new_persona = await persona_manager.maybe_update_persona(
                    self.model, self.tokenizer, self.device
                )
                self.persona_text = new_persona
                await message.channel.send(
                    "I have rewritten my persona. Thank you! ❤️"
                )
                return
        # 3) Otherwise (guild), require a mention or the "nora," prefix
        else:
            if self.user.mention in content:
                prompt = content.replace(self.user.mention, "").strip()
            elif content.lower().startswith("nora,"):
                prompt = content[len("nora,"):].strip()
            else:
                return

        if not prompt:
            return  # e.g. the user only said "Nora," with no text after

        # 4) Show a typing indicator if in a guild text channel
        # (discord.py 2.x removed Messageable.trigger_typing(); awaiting typing()
        # is assumed here to fire a one-shot indicator)
        if isinstance(message.channel, discord.TextChannel):
            await message.channel.typing()

        # 5) Build the full prompt: persona + history + the user's prompt
        chan_id = str(message.channel.id)
        history = self.history.get(chan_id, [])
        prompt_lines = []
        # 5.1) Insert Nora's persona (so she "speaks as herself")
        prompt_lines.append(self.persona_text)
        prompt_lines.append("")  # blank line
        # 5.2) Insert up to the last 4 exchanges
        for user_msg, nora_msg in history[-4:]:
            prompt_lines.append(f"User: {user_msg}")
            prompt_lines.append(f"Nora: {nora_msg}")
            prompt_lines.append("")
        # 5.3) Finally, insert the new user prompt
        prompt_lines.append(f"User: {prompt}")
        prompt_lines.append("Nora:")
        conversation_prompt = "\n".join(prompt_lines)

        # 6) Generate Nora's reply (tighter sampling: temperature=0.8, top_k=20)
        try:
            raw = await asyncio.to_thread(
                self.model.generate,
                self.tokenizer,
                self.device,
                conversation_prompt,
                self.config.seq_length,
                0.8,  # temperature
                20,   # top_k
            )
        except Exception:
            logger.exception("Error in Nora.generate()")
            await message.channel.send("😔 Sorry, I hit an error trying to think.")
            return

        # 7) Extract just Nora's reply text
        if "Nora:" in raw:
            nora_reply = raw.split("Nora:")[-1].strip()
        else:
            nora_reply = raw[len(conversation_prompt):].strip()

        # 8) Handle <<SCRAPE:…>> if present
        did_scrape = await handle_scrape_directives(raw, self.config.data_dir)
        if did_scrape:
            logger.info("Detected scrape directive. Triggering incremental retrain.")
            subprocess.Popen(
                ["python", "pretrain.py", "--resume"],  # note: pretrain.py was your old main.py
                cwd=os.getcwd(),
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            nora_reply += "\n\n*I found some new things and am updating myself…*"

        # 9) Save to history and prune to the last 5 exchanges
        self.history.setdefault(chan_id, []).append((prompt, nora_reply))
        self.history[chan_id] = self.history[chan_id][-5:]

        # 10) Send Nora's reply
        await message.channel.send(nora_reply)
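
# For reference, the prompt assembled in on_message() step 5 looks roughly like
# this (the persona text and messages are illustrative):
#
#   <persona text>
#
#   User: earlier message
#   Nora: earlier reply
#
#   User: current message
#   Nora:
#
# The model is expected to continue the text after the trailing "Nora:" line,
# which is why step 7 splits the raw output on "Nora:".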

# ---------------------------------------------------
# 5) Entrypoint
# ---------------------------------------------------
if __name__ == "__main__":
    config = get_config()
    device = config.device

    # 5.1) Build/load the model & tokenizer
    model, tokenizer = build_nora(config, device)

    # 5.2) Ensure a persona exists; if not, generate one now
    persona_manager.ensure_persona_file(model, tokenizer, device)

    # 5.3) After that, we can start the agent
    discord_token = os.getenv("DISCORD_TOKEN")
    if not discord_token:
        logger.error("Please set DISCORD_TOKEN in your environment.")
        raise SystemExit(1)

    # Enable cuDNN autotune if on CUDA
    if device.startswith("cuda"):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

    bot = Nora(model=model, tokenizer=tokenizer, config=config, device=device)
    bot.run(discord_token)
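
# Typical run, sketched (the token value is a placeholder; model settings come
# from config.py):
#   export DISCORD_TOKEN="your-bot-token"
#   python main.py
# Because intents.message_content is enabled above, the bot's Discord application
# must have the Message Content intent turned on in the developer portal.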