Attempting to create a growing AI
This commit is contained in:
parent ffcc60e205
commit 6f28a30268
19 config.py
@@ -1,19 +0,0 @@
import os
import torch
from dotenv import load_dotenv

load_dotenv()


class Config:
    model_dim = int(os.getenv("MODEL_DIM", 256))
    num_layers = int(os.getenv("NUM_LAYERS", 4))
    num_heads = int(os.getenv("HEADS", 8))
    vocab_size = int(os.getenv("VOCAB_SIZE", 30000))
    context_size = int(os.getenv("CONTEXT_SIZE", 512))
    batch_size = int(os.getenv("BATCH_SIZE", 8))
    lr = float(os.getenv("LEARNING_RATE", 1e-4))
    device = "cuda" if torch.cuda.is_available() else "cpu"


cfg = Config()
25 dashboard.py Normal file
@@ -0,0 +1,25 @@
from flask import Flask, render_template_string
from debug import DebugMonitor

debug = DebugMonitor()

app = Flask(__name__)


@app.route("/")
def home():
    return render_template_string("""
        <html>
        <head><title>Ruby Debug Dashboard</title></head>
        <body>
            <h1>🧠 Ruby Live Debug</h1>
            <p><b>Last Dream:</b> {{ debug.last_dream }}</p>
            <p><b>Last Thought:</b> {{ debug.last_thought }}</p>
            <p><b>Last Loss:</b> {{ debug.last_loss }}</p>
            <p><b>Last Reply:</b> {{ debug.last_context }}</p>
        </body>
        </html>
    """, debug=debug)

if __name__ == "__main__":
    app.run(port=5000)
29 debug.py Normal file
@@ -0,0 +1,29 @@
from datetime import datetime


class DebugMonitor:
    def __init__(self):
        self.last_dream = ""
        self.last_thought = ""
        self.last_loss = 0.0
        self.last_context = ""

    def log_dream(self, dream):
        self.last_dream = dream
        self._print("💤 Dream", dream)

    def log_thought(self, thought):
        self.last_thought = thought
        self._print("💭 Thought", thought)

    def log_loss(self, loss):
        self.last_loss = loss
        self._print("📉 Loss", f"{loss:.4f}")

    def log_context(self, context):
        self.last_context = context
        self._print("📖 Context", context)

    def _print(self, label, content):
        now = datetime.now().strftime("%H:%M:%S")
        print(f"[{now}] {label}: {content}")
29 dream.py Normal file
@@ -0,0 +1,29 @@
import torch
import time
from utils import sample_reply, update_model_vocab
from debug import DebugMonitor

debug = DebugMonitor()


def run_dream_loop(model, tokenizer, device, optimizer, train_step, interval=120):
    print("Ruby is dreaming...")
    while True:
        reply, loss = generate_dream(model, tokenizer, device, optimizer, train_step)
        print(f"[DREAM] {reply} (loss={loss:.4f})")
        time.sleep(interval)


def generate_dream(model, tokenizer, device, optimizer, train_step):
    update_model_vocab(model, tokenizer)

    prompt = "Ruby: "
    input_ids = tokenizer.encode(prompt, return_tensors=True, freeze=True).to(device)

    reply = sample_reply(model, tokenizer, input_ids)
    training_text = f"User: What do you think?\nRuby: {reply}"

    loss = train_step(model, optimizer, tokenizer, training_text, device)
    debug.log_dream(reply)
    debug.log_loss(loss)
    return reply, loss
4 feedback.py Normal file
@@ -0,0 +1,4 @@
def basic_self_feedback(reply, user_response):
    if user_response and len(user_response.strip()) > 1:
        return 1.0
    return -0.5
29 memory.py Normal file
@@ -0,0 +1,29 @@
import json
from pathlib import Path


class MemoryBuffer:
    def __init__(self, max_len=3, path="memory.json"):
        self.path = Path(path)
        self.max_len = max_len
        self.memory = []
        self.load()

    def add(self, user_input, bot_reply):
        self.memory.append(f"User: {user_input}")
        self.memory.append(f"Bot: {bot_reply}")
        if len(self.memory) > self.max_len * 2:
            self.memory = self.memory[-self.max_len * 2:]
        self.save()

    def get_context(self):
        return self.memory.copy()

    def save(self):
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.memory, f)

    def load(self):
        if self.path.exists():
            with open(self.path, "r", encoding="utf-8") as f:
                self.memory = json.load(f)
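An illustrative usage sketch (not part of the commit) of the MemoryBuffer above: it keeps only the last max_len exchanges and persists them to memory.json.

# Hypothetical example, assuming memory.py from this commit.
from memory import MemoryBuffer

mem = MemoryBuffer(max_len=3, path="memory.json")
mem.add("hi", "hello there")
print(mem.get_context())  # ['User: hi', 'Bot: hello there']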
30 model.py Normal file
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn


class MiniTransformer(nn.Module):
    def __init__(self,
                 vocab_size,
                 d_model=256,
                 n_heads=4,
                 n_layers=4,
                 max_seq_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model,
                                       nhead=n_heads,
                                       batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        B, T = x.size()
        x = self.token_emb(x) + self.pos_emb[:, :T]
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        return self.head(x)
70 personality.py Normal file
@@ -0,0 +1,70 @@
import json
from pathlib import Path
import random


class Personality:
    def __init__(self, path="personality.json"):
        self.path = Path(path)
        self.data = {
            "likes": [],
            "dislikes": [],
            "traits": [],
            "curiosities": []
        }
        self.load()

    def load(self):
        if self.path.exists():
            with open(self.path, "r", encoding="utf-8") as f:
                self.data.update(json.load(f))

    def save(self):
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.data, f, indent=2)

    def learn_topic(self, text):
        words = [w.lower() for w in text.split()]
        for word in words:
            if word.isalpha() and word not in self.data["curiosities"]:
                self.data["curiosities"].append(word)
                self.save()

    def choose_curiosity(self):
        if not self.data["curiosities"]:
            return None
        return random.choice(self.data["curiosities"])

    def observe_input(self, message: str):
        text = message.lower()

        # Learn likes (guard against the phrase appearing at the end of the message)
        if "i like" in text:
            word = (text.split("i like", 1)[1].strip().split() or [""])[0]
            if word and word not in self.data["likes"]:
                self.data["likes"].append(word)
                self.save()

        # Learn dislikes
        if "i hate" in text or "i don't like" in text:
            for phrase in ["i hate", "i don't like"]:
                if phrase in text:
                    word = (text.split(phrase, 1)[1].strip().split() or [""])[0]
                    if word and word not in self.data["dislikes"]:
                        self.data["dislikes"].append(word)
                        self.save()

        # Learn traits from compliments
        for trigger in ["you are", "you're", "ur"]:
            if trigger in text:
                fragment = (text.split(trigger, 1)[1].strip().split() or [""])[0]
                if fragment and fragment not in self.data["traits"]:
                    self.data["traits"].append(fragment)
                    self.save()

    def reflect(self) -> str:
        if not self.data["likes"] and not self.data["traits"]:
            return "I'm still figuring out who I am."
        likes = ', '.join(self.data["likes"][:3]) or "nothing yet"
        traits = ', '.join(self.data["traits"][:3]) or "no traits yet"
        return f"I'm starting to think I like {likes}. People have called me {traits}."
106 ruby.py Normal file
@@ -0,0 +1,106 @@
import discord
import torch
from debug import DebugMonitor
from dream import run_dream_loop
from model import MiniTransformer
from train_step import online_train_step
from tokenizer import ChildTokenizer
from feedback import basic_self_feedback
from memory import MemoryBuffer
from utils import update_model_vocab, track_loss, sample_reply, sample_thought
from personality import Personality
from dotenv import load_dotenv
import os
import logging
import threading


# Configure logging
logging.basicConfig(filename='ruby.log', level=logging.ERROR)

# Load environment variables
load_dotenv()
TOKEN = os.getenv('DISCORD_TOKEN')

# Initialize personality
personality = Personality()

# Initialize debug monitor
debug = DebugMonitor()

# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = ChildTokenizer()
memory = MemoryBuffer(max_len=3)

model = MiniTransformer(vocab_size=tokenizer.vocab_size()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Initialize Discord client
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)

# Start the dream loop in a separate thread
dream_thread = threading.Thread(
    target=run_dream_loop,
    args=(model, tokenizer, device, optimizer, online_train_step),
    daemon=True
)
dream_thread.start()


# Event handlers
@client.event
async def on_ready():
    print(f"{client.user} is ready and learning!")


@client.event
async def on_message(message):
    try:
        # Ignore bot's own messages
        if message.author == client.user:
            return

        # Get user input and memory
        user_input = message.content
        context = memory.get_context()
        full_input = ' '.join(context + [user_input])

        # 🔍 Debug: log context
        debug.log_context(full_input)

        # Ensure model matches tokenizer
        update_model_vocab(model, tokenizer)

        # Encode user input
        input_ids = tokenizer.encode(full_input, return_tensors=True, freeze=True).to(device)
        if input_ids.size(1) < 2:
            return

        # 💭 Generate internal thought
        thought = sample_thought(model, tokenizer, device, full_input)
        debug.log_thought(thought)

        # 🗣️ Generate reply from Ruby
        reply = sample_reply(model, tokenizer, input_ids)
        debug.log_context(reply)

        # ✅ Send the reply
        await message.channel.send(reply if reply.strip() else "...")

        # Add to memory
        memory.add(user_input, reply)

        # 📉 Train and log loss
        training_example = f"User: {user_input}\nRuby: {reply}"
        loss = online_train_step(model, optimizer, tokenizer, training_example, device)
        debug.log_loss(loss)

    except Exception as e:
        logging.exception("Error in on_message")
        await message.channel.send("Oops, I had a brain freeze.")


client.run(TOKEN)
107920 tokenizer.json
File diff suppressed because it is too large
42 tokenizer.py Normal file
@@ -0,0 +1,42 @@
import torch
import logging

logger = logging.getLogger("tokenizer")
logger.setLevel(logging.INFO)

fh = logging.FileHandler("learned_chars.log")
formatter = logging.Formatter('%(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)


class ChildTokenizer:
    def __init__(self):
        self.char_to_id = {'<pad>': 0, '<unk>': 1}
        self.id_to_char = {0: '<pad>', 1: '<unk>'}
        self.next_id = 2

        # 🔤 Bootstrap with common characters
        for ch in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,!? ':;":
            self.char_to_id[ch] = self.next_id
            self.id_to_char[self.next_id] = ch
            self.next_id += 1

    def encode(self, text, return_tensors=False, freeze=False):
        ids = []
        for ch in text:
            if ch not in self.char_to_id:
                if freeze:
                    ids.append(self.char_to_id.get('<unk>', 1))
                    continue
                self.char_to_id[ch] = self.next_id
                self.id_to_char[self.next_id] = ch
                self.next_id += 1
                logger.info(ch)  # record each newly learned character
            ids.append(self.char_to_id[ch])
        return torch.tensor([ids], dtype=torch.long) if return_tensors else ids

    def decode(self, ids):
        return ''.join([self.id_to_char.get(i, '<unk>') for i in ids])

    def vocab_size(self):
        return self.next_id
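A brief usage sketch (illustrative, not part of the commit) of the ChildTokenizer above: unseen characters grow the vocabulary unless freeze=True, in which case they map to <unk>.

# Hypothetical example, assuming tokenizer.py from this commit.
from tokenizer import ChildTokenizer

tok = ChildTokenizer()
ids = tok.encode("héllo")                    # 'é' is learned, vocab grows by one
frozen = tok.encode("ça va", freeze=True)    # 'ç' maps to <unk>, vocab unchanged
print(tok.decode(ids), tok.vocab_size())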
159 train.py
@@ -1,159 +0,0 @@
import torch
import torch.nn as nn
import time
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, decoders
from config import cfg
from torch.cuda.amp import autocast, GradScaler


# 1. Tokenizer Implementation (Modified)
class RubyTokenizer:
    def __init__(self):
        self.tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
        self.tokenizer.add_special_tokens(["[PAD]", "[UNK]"])
        self.tokenizer.decoder = decoders.ByteLevel()

    def train(self, texts):
        trainer = trainers.BpeTrainer(
            special_tokens=["[PAD]", "[UNK]"],
            vocab_size=cfg.vocab_size,
            min_frequency=2,  # Modified
            show_progress=True
        )
        self.tokenizer.train_from_iterator(
            (text.split() for text in texts),  # Modified: better word handling
            trainer=trainer
        )

    def encode(self, text):
        return self.tokenizer.encode(text).ids

    @property
    def pad_id(self):
        return self.tokenizer.token_to_id("[PAD]")  # Modified


# 2. Optimized Dataset (Modified padding handling)
class CachedDataset(Dataset):
    def __init__(self):
        self.data = np.memmap("dataset_cache.bin",
                              dtype=np.int32,
                              mode="r",
                              shape=(os.path.getsize("dataset_cache.bin")//4,))

    def __len__(self):
        return len(self.data) // cfg.context_size

    def __getitem__(self, idx):
        start = idx * cfg.context_size
        return torch.from_numpy(self.data[start:start+cfg.context_size].copy())


# 3. Transformer Model (Modified padding_idx)
class Transformer(nn.Module):
    def __init__(self, pad_id):
        super().__init__()
        self.embed = nn.Embedding(
            cfg.vocab_size,
            cfg.model_dim,
            padding_idx=pad_id  # Modified
        )
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=cfg.model_dim,
                nhead=cfg.num_heads,
                dim_feedforward=cfg.model_dim*4,
                batch_first=True
            ) for _ in range(cfg.num_layers)
        ])
        self.head = nn.Linear(cfg.model_dim, cfg.vocab_size)

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.head(x)


# 4. Main Training Process (Critical fixes)
def main():
    # Initialize tokenizer
    tokenizer = RubyTokenizer()

    if not os.path.exists("dataset_cache.bin"):
        print("Creating dataset cache...")
        ds = load_dataset("openwebtext", split="train[:5%]")

        # Train and save tokenizer (Modified)
        if not os.path.exists("tokenizer.json"):
            print("Training tokenizer...")
            tokenizer.train([text for text in ds["text"] if len(text) > 100])
            tokenizer.tokenizer.save("tokenizer.json")
        else:
            tokenizer.tokenizer = Tokenizer.from_file("tokenizer.json")

        # Tokenize and cache data (Modified)
        all_tokens = []
        pad_id = tokenizer.pad_id

        for text in ds["text"]:
            tokens = tokenizer.encode(text)
            tokens = tokens[:cfg.context_size]  # Truncate after tokenization
            pad_len = cfg.context_size - len(tokens)
            all_tokens.extend(tokens + [pad_id]*pad_len)  # Modified

        memmap = np.memmap("dataset_cache.bin",
                           dtype=np.int32,
                           mode="w+",
                           shape=(len(all_tokens),))
        memmap[:] = np.array(all_tokens, dtype=np.int32)
        del memmap

    # Test tokenizer (Modified)
    test_text = "The quick brown fox jumps over the lazy dog."
    print("Tokenizer test:", tokenizer.tokenizer.encode(test_text).tokens)

    # Initialize model with pad_id (Modified)
    model = Transformer(pad_id=tokenizer.pad_id).to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    scaler = GradScaler()

    dataset = CachedDataset()
    loader = DataLoader(dataset,
                        batch_size=cfg.batch_size,
                        pin_memory=True,
                        shuffle=True)

    # Training loop (Modified loss calculation)
    start = time.time()
    for step, batch in enumerate(loader):
        batch = batch.to(cfg.device, non_blocking=True)

        inputs = batch[:, :-1]
        targets = batch[:, 1:]

        with autocast():
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(
                outputs.reshape(-1, cfg.vocab_size),
                targets.reshape(-1).long(),
                ignore_index=tokenizer.pad_id  # Modified
            )

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

        if step % 10 == 0:
            elapsed = time.time() - start
            speed = (step + 1) * cfg.batch_size / elapsed
            print(f"Step {step} | Loss: {loss.item():.4f} | Speed: {speed:.1f} samples/s")


if __name__ == "__main__":
    main()
34 train_step.py Normal file
@@ -0,0 +1,34 @@
import torch
import torch.nn.functional as F
from utils import update_model_vocab


def online_train_step(model, optimizer, tokenizer, message, device):
    # Ensure model can handle current vocab
    update_model_vocab(model, tokenizer)

    # Freeze tokenizer so it doesn't grow mid-train
    tokens = tokenizer.encode(message, return_tensors=True, freeze=True).to(device)
    if tokens.size(1) < 2:
        return 0.0

    # Truncate long input
    max_len = model.pos_emb.size(1)
    if tokens.size(1) > max_len:
        tokens = tokens[:, -max_len:]

    x = tokens[:, :-1]
    y = tokens[:, 1:]

    # HARD STOP if y exceeds model vocab
    vocab_size = model.token_emb.num_embeddings
    assert y.max().item() < vocab_size, f"y contains token > vocab_size ({y.max().item()} >= {vocab_size})"

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
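An illustrative sketch (not part of the commit) of a single online training step, wiring together the model.py, tokenizer.py and train_step.py shown above.

# Hypothetical example, assuming the modules from this commit.
import torch
from model import MiniTransformer
from tokenizer import ChildTokenizer
from train_step import online_train_step

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = ChildTokenizer()
model = MiniTransformer(vocab_size=tok.vocab_size()).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

loss = online_train_step(model, opt, tok, "User: hi\nRuby: hello", device)
print(f"loss={loss:.4f}")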
86 utils.py Normal file
@@ -0,0 +1,86 @@
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import softmax

loss_log = []


def track_loss(loss):
    loss_log.append(loss)
    if len(loss_log) % 50 == 0:
        plot_loss()


def plot_loss():
    plt.figure()
    plt.plot(loss_log)
    plt.title("Training Loss Over Time")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.savefig("loss_plot.png")
    plt.close()


def update_model_vocab(model, tokenizer):
    new_vocab = tokenizer.vocab_size()
    old_vocab = model.token_emb.num_embeddings
    d_model = model.token_emb.embedding_dim

    if new_vocab > old_vocab:
        # Resize token embedding
        old_weights = model.token_emb.weight.data
        new_emb = torch.nn.Embedding(new_vocab, d_model).to(old_weights.device)
        new_emb.weight.data[:old_vocab] = old_weights
        torch.nn.init.normal_(new_emb.weight.data[old_vocab:], mean=0.0, std=0.02)
        model.token_emb = new_emb

        # Resize output head
        old_head = model.head
        new_head = torch.nn.Linear(d_model, new_vocab).to(old_weights.device)
        new_head.weight.data[:old_vocab] = old_head.weight.data
        new_head.bias.data[:old_vocab] = old_head.bias.data
        torch.nn.init.normal_(new_head.weight.data[old_vocab:], mean=0.0, std=0.02)
        torch.nn.init.zeros_(new_head.bias.data[old_vocab:])
        model.head = new_head


def sample_reply(model, tokenizer, input_ids, max_len=40):
    model.eval()
    generated = input_ids.clone()
    device = input_ids.device

    for _ in range(max_len):
        # Truncate input to fit positional embedding
        max_seq = model.pos_emb.size(1)
        if generated.size(1) > max_seq:
            generated = generated[:, -max_seq:]

        update_model_vocab(model, tokenizer)

        try:
            logits = model(generated)
        except RuntimeError as e:
            print("CUDA crash in sample_reply — possible vocab mismatch")
            print("Generated:", generated)
            raise e

        next_token_logits = logits[0, -1, :]
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        next_token = probs.argmax(dim=-1, keepdim=True)
        next_token = next_token.unsqueeze(0)  # Shape: [1, 1]
        generated = torch.cat((generated, next_token), dim=1)

        decoded = tokenizer.decode([next_token.item()])
        if decoded in ['\n', '.', '!', '?']:
            break

    output = generated[0].tolist()[input_ids.shape[1]:]
    reply = tokenizer.decode(output).strip()
    print(f"[Reply] {repr(reply)}")
    return reply


def sample_thought(model, tokenizer, device, context_text, max_len=60):
    prompt = f"[thinking] {context_text}"
    input_ids = tokenizer.encode(prompt, return_tensors=True).to(device)
    return sample_reply(model, tokenizer, input_ids, max_len=max_len)