Attempting to create a growing AI
parent ffcc60e205
commit 6f28a30268
19 config.py
@@ -1,19 +0,0 @@
import os
import torch
from dotenv import load_dotenv

load_dotenv()


class Config:
    model_dim = int(os.getenv("MODEL_DIM", 256))
    num_layers = int(os.getenv("NUM_LAYERS", 4))
    num_heads = int(os.getenv("HEADS", 8))
    vocab_size = int(os.getenv("VOCAB_SIZE", 30000))
    context_size = int(os.getenv("CONTEXT_SIZE", 512))
    batch_size = int(os.getenv("BATCH_SIZE", 8))
    lr = float(os.getenv("LEARNING_RATE", 1e-4))
    device = "cuda" if torch.cuda.is_available() else "cpu"


cfg = Config()
25 dashboard.py Normal file
@@ -0,0 +1,25 @@
from flask import Flask, render_template_string
from debug import DebugMonitor

debug = DebugMonitor()

app = Flask(__name__)


@app.route("/")
def home():
    return render_template_string("""
    <html>
    <head><title>Ruby Debug Dashboard</title></head>
    <body>
    <h1>🧠 Ruby Live Debug</h1>
    <p><b>Last Dream:</b> {{ debug.last_dream }}</p>
    <p><b>Last Thought:</b> {{ debug.last_thought }}</p>
    <p><b>Last Loss:</b> {{ debug.last_loss }}</p>
    <p><b>Last Reply:</b> {{ debug.last_context }}</p>
    </body>
    </html>
    """, debug=debug)


if __name__ == "__main__":
    app.run(port=5000)
29 debug.py Normal file
@@ -0,0 +1,29 @@
from datetime import datetime


class DebugMonitor:
    def __init__(self):
        self.last_dream = ""
        self.last_thought = ""
        self.last_loss = 0.0
        self.last_context = ""

    def log_dream(self, dream):
        self.last_dream = dream
        self._print("💤 Dream", dream)

    def log_thought(self, thought):
        self.last_thought = thought
        self._print("💭 Thought", thought)

    def log_loss(self, loss):
        self.last_loss = loss
        self._print("📉 Loss", f"{loss:.4f}")

    def log_context(self, context):
        self.last_context = context
        self._print("📖 Context", context)

    def _print(self, label, content):
        now = datetime.now().strftime("%H:%M:%S")
        print(f"[{now}] {label}: {content}")
29 dream.py Normal file
@@ -0,0 +1,29 @@
import torch
import time
from utils import sample_reply, update_model_vocab
from debug import DebugMonitor

debug = DebugMonitor()


def run_dream_loop(model, tokenizer, device, optimizer, train_step, interval=120):
    print("Ruby is dreaming...")
    while True:
        reply, loss = generate_dream(model, tokenizer, device, optimizer, train_step)
        print(f"[DREAM] {reply} (loss={loss:.4f})")
        time.sleep(interval)


def generate_dream(model, tokenizer, device, optimizer, train_step):
    update_model_vocab(model, tokenizer)

    prompt = "Ruby: "
    input_ids = tokenizer.encode(prompt, return_tensors=True, freeze=True).to(device)

    reply = sample_reply(model, tokenizer, input_ids)
    training_text = f"User: What do you think?\nRuby: {reply}"

    loss = train_step(model, optimizer, tokenizer, training_text, device)

    # Log the dream and its loss before returning
    debug.log_dream(reply)
    debug.log_loss(loss)
    return reply, loss
4 feedback.py Normal file
@@ -0,0 +1,4 @@
def basic_self_feedback(reply, user_response):
    if user_response and len(user_response.strip()) > 1:
        return 1.0
    return -0.5
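A minimal usage sketch (illustrative only, not part of the commit), assuming any non-trivial user follow-up counts as positive feedback:

from feedback import basic_self_feedback

reward = basic_self_feedback(reply="hi there", user_response="nice to meet you")
assert reward == 1.0    # engaged follow-up -> positive signal
reward = basic_self_feedback(reply="hi there", user_response="")
assert reward == -0.5   # empty follow-up -> mild negative signal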
29 memory.py Normal file
@@ -0,0 +1,29 @@
import json
from pathlib import Path


class MemoryBuffer:
    def __init__(self, max_len=3, path="memory.json"):
        self.path = Path(path)
        self.max_len = max_len
        self.memory = []
        self.load()

    def add(self, user_input, bot_reply):
        self.memory.append(f"User: {user_input}")
        self.memory.append(f"Bot: {bot_reply}")
        if len(self.memory) > self.max_len * 2:
            self.memory = self.memory[-self.max_len * 2:]
        self.save()

    def get_context(self):
        return self.memory.copy()

    def save(self):
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.memory, f)

    def load(self):
        if self.path.exists():
            with open(self.path, "r", encoding="utf-8") as f:
                self.memory = json.load(f)
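A quick usage sketch, assuming the working directory is writable; the buffer keeps only the last max_len exchanges:

from memory import MemoryBuffer

memory = MemoryBuffer(max_len=3, path="memory_demo.json")  # hypothetical demo path
memory.add("hello", "hi!")
memory.add("how are you?", "learning.")
print(memory.get_context())
# ['User: hello', 'Bot: hi!', 'User: how are you?', 'Bot: learning.']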
30 model.py Normal file
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn


class MiniTransformer(nn.Module):
    def __init__(self,
                 vocab_size,
                 d_model=256,
                 n_heads=4,
                 n_layers=4,
                 max_seq_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model,
                                       nhead=n_heads,
                                       batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        B, T = x.size()
        x = self.token_emb(x) + self.pos_emb[:, :T]
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        return self.head(x)
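A shape-check sketch, assuming a batch of token ids no longer than max_seq_len; logits come back as (batch, seq, vocab):

import torch
from model import MiniTransformer

model = MiniTransformer(vocab_size=100)
x = torch.randint(0, 100, (2, 16))   # (batch=2, seq=16) token ids
logits = model(x)
print(logits.shape)                  # torch.Size([2, 16, 100])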
70 personality.py Normal file
@@ -0,0 +1,70 @@
import json
from pathlib import Path
import random


class Personality:
    def __init__(self, path="personality.json"):
        self.path = Path(path)
        self.data = {
            "likes": [],
            "dislikes": [],
            "traits": [],
            "curiosities": []
        }
        self.load()

    def load(self):
        if self.path.exists():
            with open(self.path, "r", encoding="utf-8") as f:
                self.data.update(json.load(f))

    def save(self):
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.data, f, indent=2)

    def learn_topic(self, text):
        words = [w.lower() for w in text.split()]
        for word in words:
            if word.isalpha() and word not in self.data["curiosities"]:
                self.data["curiosities"].append(word)
        self.save()

    def choose_curiosity(self):
        if not self.data["curiosities"]:
            return None
        return random.choice(self.data["curiosities"])

    def observe_input(self, message: str):
        text = message.lower()

        # Learn likes
        if "i like" in text:
            word = text.split("i like", 1)[1].strip().split()[0]
            if word and word not in self.data["likes"]:
                self.data["likes"].append(word)
                self.save()

        # Learn dislikes
        if "i hate" in text or "i don't like" in text:
            for phrase in ["i hate", "i don't like"]:
                if phrase in text:
                    word = text.split(phrase, 1)[1].strip().split()[0]
                    if word and word not in self.data["dislikes"]:
                        self.data["dislikes"].append(word)
                        self.save()

        # Learn traits from compliments
        for trigger in ["you are", "you're", "ur"]:
            if trigger in text:
                fragment = text.split(trigger, 1)[1].strip().split()[0]
                if fragment and fragment not in self.data["traits"]:
                    self.data["traits"].append(fragment)
                    self.save()

    def reflect(self) -> str:
        if not self.data["likes"] and not self.data["traits"]:
            return "I'm still figuring out who I am."
        likes = ', '.join(self.data["likes"][:3]) or "nothing yet"
        traits = ', '.join(self.data["traits"][:3]) or "no traits yet"
        return f"I'm starting to think I like {likes}. People have called me {traits}."
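An illustrative sketch of the phrase-trigger heuristics, assuming fresh state (no personality file on disk yet):

from personality import Personality

p = Personality(path="personality_demo.json")  # hypothetical demo path
p.observe_input("I like trains and you are clever")
print(p.reflect())
# "I'm starting to think I like trains. People have called me clever."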
106 ruby.py Normal file
@@ -0,0 +1,106 @@
import discord
import torch
from debug import DebugMonitor
from dream import run_dream_loop
from model import MiniTransformer
from train_step import online_train_step
from tokenizer import ChildTokenizer
from feedback import basic_self_feedback
from memory import MemoryBuffer
from utils import update_model_vocab, track_loss, sample_reply, sample_thought
from personality import Personality
from dotenv import load_dotenv
import os
import logging
import threading


# Configure logging
logging.basicConfig(filename='ruby.log', level=logging.ERROR)

# Load environment variables
load_dotenv()
TOKEN = os.getenv('DISCORD_TOKEN')

# Initialize personality
personality = Personality()

# Initialize debug monitor
debug = DebugMonitor()

# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = ChildTokenizer()
memory = MemoryBuffer(max_len=3)

model = MiniTransformer(vocab_size=tokenizer.vocab_size()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Initialize Discord client
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)

# Start the dream loop in a separate thread
dream_thread = threading.Thread(
    target=run_dream_loop,
    args=(model, tokenizer, device, optimizer, online_train_step),
    daemon=True
)
dream_thread.start()


# Event handlers
@client.event
async def on_ready():
    print(f"{client.user} is ready and learning!")


@client.event
async def on_message(message):
    try:
        # Ignore bot's own messages
        if message.author == client.user:
            return

        # Get user input and memory
        user_input = message.content
        context = memory.get_context()
        full_input = ' '.join(context + [user_input])

        # 🔍 Debug: log context
        debug.log_context(full_input)

        # Ensure model matches tokenizer
        update_model_vocab(model, tokenizer)

        # Encode user input
        input_ids = tokenizer.encode(full_input, return_tensors=True, freeze=True).to(device)
        if input_ids.size(1) < 2:
            return

        # 💭 Generate internal thought
        thought = sample_thought(model, tokenizer, device, full_input)
        debug.log_thought(thought)

        # 🗣️ Generate reply from Ruby
        reply = sample_reply(model, tokenizer, input_ids)
        debug.log_context(reply)

        # ✅ Send the reply
        await message.channel.send(reply if reply.strip() else "...")

        # Add to memory
        memory.add(user_input, reply)

        # 📉 Train and log loss
        training_example = f"User: {user_input}\nRuby: {reply}"
        loss = online_train_step(model, optimizer, tokenizer, training_example, device)
        debug.log_loss(loss)

    except Exception as e:
        logging.exception("Error in on_message")
        await message.channel.send("Oops, I had a brain freeze.")


client.run(TOKEN)
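The bot reads its Discord token from the environment via load_dotenv(); a minimal .env sketch with a placeholder value (not a real token):

# .env
DISCORD_TOKEN=your-bot-token-here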
107920 tokenizer.json
File diff suppressed because it is too large
42 tokenizer.py Normal file
@@ -0,0 +1,42 @@
import torch
import logging

logger = logging.getLogger("tokenizer")
logger.setLevel(logging.INFO)

fh = logging.FileHandler("learned_chars.log")
formatter = logging.Formatter('%(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)


class ChildTokenizer:
    def __init__(self):
        self.char_to_id = {'<pad>': 0, '<unk>': 1}
        self.id_to_char = {0: '<pad>', 1: '<unk>'}
        self.next_id = 2

        # 🔤 Bootstrap with common characters
        for ch in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,!? ':;":
            self.char_to_id[ch] = self.next_id
            self.id_to_char[self.next_id] = ch
            self.next_id += 1

    def encode(self, text, return_tensors=False, freeze=False):
        ids = []
        for ch in text:
            if ch not in self.char_to_id:
                if freeze:
                    ids.append(self.char_to_id.get('<unk>', 1))
                    continue
                self.char_to_id[ch] = self.next_id
                self.id_to_char[self.next_id] = ch
                self.next_id += 1
            ids.append(self.char_to_id[ch])
        return torch.tensor([ids], dtype=torch.long) if return_tensors else ids

    def decode(self, ids):
        return ''.join([self.id_to_char.get(i, '<unk>') for i in ids])

    def vocab_size(self):
        return self.next_id
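A small sketch of the "growing" behaviour: an unfrozen encode adds unseen characters to the vocabulary, while freeze=True maps them to <unk> instead:

from tokenizer import ChildTokenizer

tok = ChildTokenizer()
before = tok.vocab_size()
tok.encode("héllo")                  # 'é' is unseen, gets a new id
assert tok.vocab_size() == before + 1
frozen = tok.encode("ø", freeze=True)
assert frozen == [1]                 # unseen char maps to <unk> when frozen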
159 train.py
@@ -1,159 +0,0 @@
import torch
import torch.nn as nn
import time
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, decoders
from config import cfg
from torch.cuda.amp import autocast, GradScaler


# 1. Tokenizer Implementation (Modified)
class RubyTokenizer:
    def __init__(self):
        self.tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
        self.tokenizer.add_special_tokens(["[PAD]", "[UNK]"])
        self.tokenizer.decoder = decoders.ByteLevel()

    def train(self, texts):
        trainer = trainers.BpeTrainer(
            special_tokens=["[PAD]", "[UNK]"],
            vocab_size=cfg.vocab_size,
            min_frequency=2,  # Modified
            show_progress=True
        )
        self.tokenizer.train_from_iterator(
            (text.split() for text in texts),  # Modified: better word handling
            trainer=trainer
        )

    def encode(self, text):
        return self.tokenizer.encode(text).ids

    @property
    def pad_id(self):
        return self.tokenizer.token_to_id("[PAD]")  # Modified


# 2. Optimized Dataset (Modified padding handling)
class CachedDataset(Dataset):
    def __init__(self):
        self.data = np.memmap("dataset_cache.bin",
                              dtype=np.int32,
                              mode="r",
                              shape=(os.path.getsize("dataset_cache.bin")//4,))

    def __len__(self):
        return len(self.data) // cfg.context_size

    def __getitem__(self, idx):
        start = idx * cfg.context_size
        return torch.from_numpy(self.data[start:start+cfg.context_size].copy())


# 3. Transformer Model (Modified padding_idx)
class Transformer(nn.Module):
    def __init__(self, pad_id):
        super().__init__()
        self.embed = nn.Embedding(
            cfg.vocab_size,
            cfg.model_dim,
            padding_idx=pad_id  # Modified
        )
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=cfg.model_dim,
                nhead=cfg.num_heads,
                dim_feedforward=cfg.model_dim*4,
                batch_first=True
            ) for _ in range(cfg.num_layers)
        ])
        self.head = nn.Linear(cfg.model_dim, cfg.vocab_size)

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.head(x)


# 4. Main Training Process (Critical fixes)
def main():
    # Initialize tokenizer
    tokenizer = RubyTokenizer()

    if not os.path.exists("dataset_cache.bin"):
        print("Creating dataset cache...")
        ds = load_dataset("openwebtext", split="train[:5%]")

        # Train and save tokenizer (Modified)
        if not os.path.exists("tokenizer.json"):
            print("Training tokenizer...")
            tokenizer.train([text for text in ds["text"] if len(text) > 100])
            tokenizer.tokenizer.save("tokenizer.json")
        else:
            tokenizer.tokenizer = Tokenizer.from_file("tokenizer.json")

        # Tokenize and cache data (Modified)
        all_tokens = []
        pad_id = tokenizer.pad_id

        for text in ds["text"]:
            tokens = tokenizer.encode(text)
            tokens = tokens[:cfg.context_size]  # Truncate after tokenization
            pad_len = cfg.context_size - len(tokens)
            all_tokens.extend(tokens + [pad_id]*pad_len)  # Modified

        memmap = np.memmap("dataset_cache.bin",
                           dtype=np.int32,
                           mode="w+",
                           shape=(len(all_tokens),))
        memmap[:] = np.array(all_tokens, dtype=np.int32)
        del memmap

    # Test tokenizer (Modified)
    test_text = "The quick brown fox jumps over the lazy dog."
    print("Tokenizer test:", tokenizer.tokenizer.encode(test_text).tokens)

    # Initialize model with pad_id (Modified)
    model = Transformer(pad_id=tokenizer.pad_id).to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    scaler = GradScaler()

    dataset = CachedDataset()
    loader = DataLoader(dataset,
                        batch_size=cfg.batch_size,
                        pin_memory=True,
                        shuffle=True)

    # Training loop (Modified loss calculation)
    start = time.time()
    for step, batch in enumerate(loader):
        batch = batch.to(cfg.device, non_blocking=True)

        inputs = batch[:, :-1]
        targets = batch[:, 1:]

        with autocast():
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(
                outputs.reshape(-1, cfg.vocab_size),
                targets.reshape(-1).long(),
                ignore_index=tokenizer.pad_id  # Modified
            )

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

        if step % 10 == 0:
            elapsed = time.time() - start
            speed = (step + 1) * cfg.batch_size / elapsed
            print(f"Step {step} | Loss: {loss.item():.4f} | Speed: {speed:.1f} samples/s")


if __name__ == "__main__":
    main()
34 train_step.py Normal file
@@ -0,0 +1,34 @@
import torch
import torch.nn.functional as F
from utils import update_model_vocab


def online_train_step(model, optimizer, tokenizer, message, device):
    # Ensure model can handle current vocab
    update_model_vocab(model, tokenizer)

    # Freeze tokenizer so it doesn't grow mid-train
    tokens = tokenizer.encode(message, return_tensors=True, freeze=True).to(device)
    if tokens.size(1) < 2:
        return 0.0

    # Truncate long input
    max_len = model.pos_emb.size(1)
    if tokens.size(1) > max_len:
        tokens = tokens[:, -max_len:]

    x = tokens[:, :-1]
    y = tokens[:, 1:]

    # HARD STOP if y exceeds model vocab
    vocab_size = model.token_emb.num_embeddings
    assert y.max().item() < vocab_size, f"y contains token > vocab_size ({y.max().item()} >= {vocab_size})"

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
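A single online step wired together, a minimal sketch assuming the pieces defined elsewhere in this commit (MiniTransformer, ChildTokenizer):

import torch
from model import MiniTransformer
from tokenizer import ChildTokenizer
from train_step import online_train_step

device = torch.device("cpu")
tokenizer = ChildTokenizer()
model = MiniTransformer(vocab_size=tokenizer.vocab_size()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

loss = online_train_step(model, optimizer, tokenizer, "User: hi\nRuby: hello", device)
print(f"loss={loss:.4f}")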
86 utils.py Normal file
@@ -0,0 +1,86 @@
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import softmax

loss_log = []


def track_loss(loss):
    loss_log.append(loss)
    if len(loss_log) % 50 == 0:
        plot_loss()


def plot_loss():
    plt.figure()
    plt.plot(loss_log)
    plt.title("Training Loss Over Time")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.savefig("loss_plot.png")
    plt.close()


def update_model_vocab(model, tokenizer):
    new_vocab = tokenizer.vocab_size()
    old_vocab = model.token_emb.num_embeddings
    d_model = model.token_emb.embedding_dim

    if new_vocab > old_vocab:
        # Resize token embedding
        old_weights = model.token_emb.weight.data
        new_emb = torch.nn.Embedding(new_vocab, d_model).to(old_weights.device)
        new_emb.weight.data[:old_vocab] = old_weights
        torch.nn.init.normal_(new_emb.weight.data[old_vocab:], mean=0.0, std=0.02)
        model.token_emb = new_emb

        # Resize output head
        old_head = model.head
        new_head = torch.nn.Linear(d_model, new_vocab).to(old_weights.device)
        new_head.weight.data[:old_vocab] = old_head.weight.data
        new_head.bias.data[:old_vocab] = old_head.bias.data
        torch.nn.init.normal_(new_head.weight.data[old_vocab:], mean=0.0, std=0.02)
        torch.nn.init.zeros_(new_head.bias.data[old_vocab:])
        model.head = new_head


def sample_reply(model, tokenizer, input_ids, max_len=40):
    model.eval()
    generated = input_ids.clone()
    device = input_ids.device

    for _ in range(max_len):
        # Truncate input to fit positional embedding
        max_seq = model.pos_emb.size(1)
        if generated.size(1) > max_seq:
            generated = generated[:, -max_seq:]

        update_model_vocab(model, tokenizer)

        try:
            logits = model(generated)
        except RuntimeError as e:
            print("CUDA crash in sample_reply — possible vocab mismatch")
            print("Generated:", generated)
            raise e

        next_token_logits = logits[0, -1, :]
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        next_token = probs.argmax(dim=-1, keepdim=True)
        next_token = next_token.unsqueeze(0)  # Shape: [1, 1]
        generated = torch.cat((generated, next_token), dim=1)

        decoded = tokenizer.decode([next_token.item()])
        if decoded in ['\n', '.', '!', '?']:
            break

    output = generated[0].tolist()[input_ids.shape[1]:]
    reply = tokenizer.decode(output).strip()
    print(f"[Reply] {repr(reply)}")
    return reply


def sample_thought(model, tokenizer, device, context_text, max_len=60):
    prompt = f"[thinking] {context_text}"
    input_ids = tokenizer.encode(prompt, return_tensors=True).to(device)
    return sample_reply(model, tokenizer, input_ids, max_len=max_len)
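A sketch of the vocab-growth path in isolation, assuming the model and tokenizer from this commit: the tokenizer learns a new character, then the embedding and output head are resized to match before the next forward pass.

import torch
from model import MiniTransformer
from tokenizer import ChildTokenizer
from utils import update_model_vocab

tokenizer = ChildTokenizer()
model = MiniTransformer(vocab_size=tokenizer.vocab_size())

tokenizer.encode("ñ")                 # tokenizer learns a new character
update_model_vocab(model, tokenizer)  # embedding and head grow to match
assert model.token_emb.num_embeddings == tokenizer.vocab_size()
assert model.head.out_features == tokenizer.vocab_size()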