# Source listing metadata: 176 lines, 6.4 KiB, Python
# self_improve.py
|
||
|
||
import os
|
||
import subprocess
|
||
import tempfile
|
||
import shutil
|
||
import logging
|
||
|
||
import torch
|
||
from tokenizer import CharTokenizer
|
||
from model import NoraTransformerLM
|
||
from config import get_config
|
||
|
||
# ------------------------------------------------------
|
||
# 1) “Teacher”: Pose a code‐generation prompt to Nora
|
||
# ------------------------------------------------------
|
||
def propose_patch(
    model,
    tokenizer,
    device,
    prompt: str,
    *,
    max_length: int = 512,
    temperature: float = 0.7,
    top_k: int = 50,
) -> str:
    """Ask Nora to generate a code snippet for ``prompt``.

    Example prompt::

        "### FILE: knowledge_retriever.py\\n# Add a new function clean_html(...) that..."

    Args:
        model: Object exposing ``generate(tokenizer=..., device=..., prompt=...,
            max_length=..., temperature=..., top_k=...)``.
        tokenizer: Passed through to ``model.generate`` unchanged.
        device: Passed through to ``model.generate`` unchanged.
        prompt: Text the model should continue.
        max_length: Forwarded to ``model.generate`` (default 512, the value
            that used to be hard-coded here).
        temperature: Forwarded to ``model.generate`` (default 0.7).
        top_k: Forwarded to ``model.generate`` (default 50).

    Returns:
        Whatever ``model.generate`` returns, unmodified. The raw text may
        include the prompt itself, so callers should not assume it is only
        the newly generated part.
    """
    return model.generate(
        tokenizer=tokenizer,
        device=device,
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
    )
|
||
|
||
|
||
# ------------------------------------------------------
|
||
# 2) “Verifier” agent: sandbox + test
|
||
# ------------------------------------------------------
|
||
class CodeVerifier:
    """Check a proposed full-file patch against a throwaway copy of the repo.

    Given the new contents of one file, this class:

    1. Byte-compiles the text to reject syntax errors early.
    2. Copies the repo to a temporary directory and writes the patch there.
    3. Runs the configured test command inside that copy.
    4. Reports success/failure as a bool, logging the reason on failure.

    No performance measurement is done; passing tests is treated as success.

    NOTE: the temporary copy only keeps the real working tree untouched while
    tests run. It is NOT a security sandbox: ``test_command`` (and therefore
    any generated code the tests import) runs with this process's privileges.
    """

    def __init__(self, repo_dir: str, test_command: str, test_timeout=None):
        """
        Args:
            repo_dir: Path to the project root that patches apply to.
            test_command: Shell command string used to run the tests,
                e.g. ``"pytest tests/"``. It is executed with ``shell=True``,
                so it must come from trusted configuration, never from
                model output.
            test_timeout: Optional number of seconds after which the test
                run is abandoned and treated as a failure. ``None`` (the
                default) waits indefinitely, as before.
        """
        self.repo_dir = repo_dir
        self.test_command = test_command
        self.test_timeout = test_timeout

    @staticmethod
    def _resolve_inside(base_dir: str, rel_path: str) -> str:
        """Return the real path of ``rel_path`` under ``base_dir``.

        Raises:
            ValueError: if the resolved path is ``base_dir`` itself or lies
                outside it (absolute path, ``..`` components, or a symlink
                that leads out of the tree).
        """
        base = os.path.realpath(base_dir)
        # os.path.join discards `base` when rel_path is absolute, so the
        # containment check below is what actually stops an escape.
        target = os.path.realpath(os.path.join(base, rel_path))
        if target == base or os.path.commonpath([base, target]) != base:
            raise ValueError(f"Path {rel_path!r} is not a file inside {base_dir!r}")
        return target

    def verify_patch(self, rel_path: str, patch_code: str) -> bool:
        """Return True if ``patch_code`` compiles and the test command passes.

        Args:
            rel_path: Path, relative to the repo root, of the file the patch
                replaces, e.g. ``"knowledge_retriever.py"``.
            patch_code: Entire new contents of that file (not a diff).

        Returns:
            True when the syntax check and the tests both succeed; False when
            the path escapes the repo copy, the code does not compile, the
            tests exit non-zero, or the tests exceed ``test_timeout``.
        """
        # 1) Syntax check first: it needs no filesystem work, so a broken
        #    patch is rejected before the repo is copied.
        try:
            compile(patch_code, rel_path, "exec")
        except (SyntaxError, ValueError) as err:
            # ValueError covers NUL bytes in the source on Python versions
            # that do not report them as SyntaxError.
            logging.error(f"Syntax error in patch: {err}")
            return False

        tmpdir = tempfile.mkdtemp(prefix="nora_verify_")
        try:
            sandbox = os.path.join(tmpdir, "repo")

            # 2) Refuse targets that would land outside the temp copy;
            #    otherwise an unverified patch could overwrite a real file.
            try:
                target_file = self._resolve_inside(sandbox, rel_path)
            except ValueError as err:
                logging.error(f"Rejecting patch target: {err}")
                return False

            # 3) Copy repo => temp dir and drop the patch in place.
            shutil.copytree(self.repo_dir, sandbox, dirs_exist_ok=True)
            # The patch may introduce a file in a directory that does not
            # exist in the repo yet.
            os.makedirs(os.path.dirname(target_file), exist_ok=True)
            with open(target_file, "w", encoding="utf-8") as f:
                f.write(patch_code)

            # 4) Run unit tests inside the copy.
            try:
                result = subprocess.run(
                    self.test_command,
                    shell=True,  # test_command is documented as a shell string
                    cwd=sandbox,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                    errors="replace",  # undecodable test output must not raise
                    timeout=self.test_timeout,
                )
            except subprocess.TimeoutExpired:
                logging.error(
                    f"Unit tests timed out after {self.test_timeout} seconds."
                )
                return False

            if result.returncode != 0:
                logging.error(f"Unit tests failed:\n{result.stdout}")
                return False

            # 5) (Optional) Performance check would go here. For now,
            #    passing tests is taken to mean "improvement".
            return True
        finally:
            # ignore_errors: a cleanup failure must not replace the verdict
            # (an exception raised in `finally` would discard the return value).
            shutil.rmtree(tmpdir, ignore_errors=True)

    def merge_patch(self, rel_path: str, patch_code: str) -> None:
        """Overwrite ``repo_dir/rel_path`` with ``patch_code`` and commit it.

        Runs ``git add`` on the file followed by ``git commit`` in
        ``repo_dir``. Git failures (git missing, not a repository, nothing
        to commit, ...) are logged and do not raise; if ``git add`` fails the
        commit is skipped so unrelated staged changes are not committed under
        this message.

        Args:
            rel_path: Path of the file relative to the repo root.
            patch_code: Entire new contents of that file.

        Raises:
            ValueError: if ``rel_path`` resolves outside ``repo_dir``.
        """
        target_file = self._resolve_inside(self.repo_dir, rel_path)
        os.makedirs(os.path.dirname(target_file), exist_ok=True)
        with open(target_file, "w", encoding="utf-8") as f:
            f.write(patch_code)

        commit_message = f"Auto-update {rel_path} via Nora self-improve."
        # Argument lists (no shell) so spaces or shell metacharacters in
        # rel_path cannot change the command; "--" stops a path that starts
        # with "-" from being parsed as an option.
        git_commands = (
            ["git", "add", "--", rel_path],
            ["git", "commit", "-m", commit_message],
        )
        for argv in git_commands:
            try:
                completed = subprocess.run(argv, cwd=self.repo_dir)
            except OSError as err:
                logging.error(f"Could not run {argv[0]} {argv[1]}: {err}")
                return
            if completed.returncode != 0:
                logging.error(
                    f"{argv[0]} {argv[1]} exited with status {completed.returncode}"
                )
                return
|
||
|
||
|
||
# ------------------------------------------------------
|
||
# 3) Main loop: ask → verify → merge (if good) → retrain
|
||
# ------------------------------------------------------
|
||
def _latest_checkpoint(checkpoint_dir: str):
    """Return the filename of the highest-step ``nora_step_*.pt`` file, or None.

    The step is the integer between the last ``_`` and the first following
    ``.`` in the filename. Files matching the prefix/suffix whose step is not
    an integer (e.g. ``nora_step_best.pt``) are skipped with a warning rather
    than aborting the whole lookup. Returns None when the directory does not
    exist or holds no usable checkpoint.
    """
    if not os.path.isdir(checkpoint_dir):
        return None

    numbered = []
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith("nora_step_") and name.endswith(".pt")):
            continue
        try:
            step = int(name.split("_")[-1].split(".")[0])
        except ValueError:
            logging.warning(f"Ignoring checkpoint with non-numeric step: {name}")
            continue
        numbered.append((step, name))

    if not numbered:
        return None
    # Stable sort on the step only, then take the last entry.
    numbered.sort(key=lambda pair: pair[0])
    return numbered[-1][1]


def self_improvement_cycle(repo_dir: str, device: str):
    """Run one propose -> verify -> merge cycle against ``repo_dir``.

    Steps:
      1) Build the tokenizer and model from ``get_config()`` and load the
         newest ``nora_step_<N>.pt`` checkpoint if one exists.
      2) Ask the model for a ``clean_html`` helper for knowledge_retriever.py.
      3) Verify the result with ``CodeVerifier`` (syntax + ``pytest tests/``).
      4) If verification passes, overwrite the real file and commit it;
         otherwise discard the patch.

    Args:
        repo_dir: Project root handed to ``CodeVerifier``.
        device: Device the model is moved to and generation runs on.
    """
    config = get_config()
    tokenizer = CharTokenizer(vocab_path=config.vocab_path, data_dir=config.data_dir)
    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=config.d_model,
        nhead=config.nhead,
        num_layers=config.num_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
        max_seq_len=config.seq_length,
    )

    # Load latest checkpoint
    latest = _latest_checkpoint(config.checkpoint_dir)
    if latest is not None:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only point checkpoint_dir at trusted checkpoints
        # (consider weights_only=True if the checkpoint format allows it).
        state = torch.load(os.path.join(config.checkpoint_dir, latest), map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
    else:
        logging.warning("No checkpoint found; using freshly initialised weights.")
    model.to(device)
    model.eval()

    verifier = CodeVerifier(repo_dir=repo_dir, test_command="pytest tests/")

    # Example prompt: ask Nora to extend knowledge_retriever.py
    target_rel_path = "knowledge_retriever.py"
    prompt = (
        "### FILE: knowledge_retriever.py\n"
        "# Add a function clean_html(html: str) -> str that strips tags and scripts.\n"
        "# Use BeautifulSoup if available. Return plain text.\n\n"
        "### START\n"
        "def clean_html(html: str) -> str:\n"
    )
    raw_patch = propose_patch(model, tokenizer, device, prompt)

    # TODO: extract only the generated function (e.g. parse up to the next
    # blank line or rely on indentation). For now the raw output is treated
    # as the full file contents, so a merge REPLACES the whole target file.
    patch_code = raw_patch

    # Verify
    if verifier.verify_patch(target_rel_path, patch_code):
        logging.info("Patch verified. Merging into live code.")
        verifier.merge_patch(target_rel_path, patch_code)
        # Optionally: trigger incremental retraining here (e.g. call train.py with --resume)
    else:
        logging.warning("Patch failed verification. Discarding.")
|