Compare commits
No commits in common. "Dev-NewModel" and "main" have entirely different histories.
Dev-NewModel ... main
1 .gitignore (vendored)
@@ -169,4 +169,3 @@ cython_debug/
 
 # PyPI configuration file
 .pypirc
-/dataset_cache.bin

25 dashboard.py
@@ -1,25 +0,0 @@
from flask import Flask, render_template_string
from debug import DebugMonitor

debug = DebugMonitor()

app = Flask(__name__)


@app.route("/")
def home():
    return render_template_string("""
    <html>
    <head><title>Ruby Debug Dashboard</title></head>
    <body>
        <h1>🧠 Ruby Live Debug</h1>
        <p><b>Last Dream:</b> {{ debug.last_dream }}</p>
        <p><b>Last Thought:</b> {{ debug.last_thought }}</p>
        <p><b>Last Loss:</b> {{ debug.last_loss }}</p>
        <p><b>Last Reply:</b> {{ debug.last_context }}</p>
    </body>
    </html>
    """, debug=debug)

if __name__ == "__main__":
    app.run(port=5000)

29 debug.py
@@ -1,29 +0,0 @@
from datetime import datetime


class DebugMonitor:
    def __init__(self):
        self.last_dream = ""
        self.last_thought = ""
        self.last_loss = 0.0
        self.last_context = ""

    def log_dream(self, dream):
        self.last_dream = dream
        self._print("💤 Dream", dream)

    def log_thought(self, thought):
        self.last_thought = thought
        self._print("💭 Thought", thought)

    def log_loss(self, loss):
        self.last_loss = loss
        self._print("📉 Loss", f"{loss:.4f}")

    def log_context(self, context):
        self.last_context = context
        self._print("📖 Context", context)

    def _print(self, label, content):
        now = datetime.now().strftime("%H:%M:%S")
        print(f"[{now}] {label}: {content}")

29 dream.py
@@ -1,29 +0,0 @@
import torch
import time
from utils import sample_reply, update_model_vocab
from debug import DebugMonitor

debug = DebugMonitor()


def run_dream_loop(model, tokenizer, device, optimizer, train_step, interval=120):
    print("Ruby is dreaming...")
    while True:
        reply, loss = generate_dream(model, tokenizer, device, optimizer, train_step)
        print(f"[DREAM] {reply} (loss={loss:.4f})")
        time.sleep(interval)


def generate_dream(model, tokenizer, device, optimizer, train_step):
    update_model_vocab(model, tokenizer)

    prompt = "Ruby: "
    input_ids = tokenizer.encode(prompt, return_tensors=True, freeze=True).to(device)

    reply = sample_reply(model, tokenizer, input_ids)
    training_text = f"User: What do you think?\nRuby: {reply}"

    loss = train_step(model, optimizer, tokenizer, training_text, device)
    # Log the dream and its loss, then return them.
    debug.log_dream(reply)
    debug.log_loss(loss)
    return reply, loss

4 feedback.py
@@ -1,4 +0,0 @@
def basic_self_feedback(reply, user_response):
    if user_response and len(user_response.strip()) > 1:
        return 1.0
    return -0.5

70 main.py (new file)
@@ -0,0 +1,70 @@
import discord
import requests
import json
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Replace with your bot token
BOT_TOKEN = os.getenv('DISCORD_TOKEN')

# Ollama configuration
OLLAMA_API_URL = 'http://192.168.1.159:11434/api/generate'  # Adjust if your Ollama setup is different

# Set up the Discord client
intents = discord.Intents.default()
intents.messages = True
intents.message_content = True

client = discord.Client(intents=intents)


# Function to query Ollama
def query_ollama(prompt):
    payload = {
        "prompt": prompt,
        "model": "nollama/mythomax-l2-13b:Q4_K_M"  # Replace with your Ollama model
    }
    try:
        response = requests.post(OLLAMA_API_URL, json=payload, stream=True)
        if response.status_code == 200:
            collected_response = ""
            # Stream and parse each line of JSON from the response
            for line in response.iter_lines(decode_unicode=True):
                if line.strip():  # Skip empty lines
                    try:
                        data = json.loads(line)  # Parse each line as JSON
                        collected_response += data.get("response", "")
                        if data.get("done", False):
                            break
                    except json.JSONDecodeError as e:
                        print(f"Error decoding JSON line: {line}, Error: {e}")
            return collected_response.strip() or "No response from model."
        else:
            return f"Error: {response.status_code} - {response.text}"
    except requests.RequestException as e:
        return f"Error connecting to Ollama: {str(e)}"


# Event for when the bot is ready
@client.event
async def on_ready():
    print(f'We have logged in as {client.user}')


# Event for when a message is sent
@client.event
async def on_message(message):
    # Ignore the bot's own messages
    if message.author == client.user:
        return

    # Respond to all messages except those in DMs
    if not isinstance(message.channel, discord.DMChannel):
        response = query_ollama(message.content.strip())
        await message.channel.send(response)

# Run the bot
client.run(BOT_TOKEN)

29 memory.py
@@ -1,29 +0,0 @@
import json
from pathlib import Path


class MemoryBuffer:
    def __init__(self, max_len=3, path="memory.json"):
        self.path = Path(path)
        self.max_len = max_len
        self.memory = []
        self.load()

    def add(self, user_input, bot_reply):
        self.memory.append(f"User: {user_input}")
        self.memory.append(f"Bot: {bot_reply}")
        if len(self.memory) > self.max_len * 2:
            self.memory = self.memory[-self.max_len * 2:]
        self.save()

    def get_context(self):
        return self.memory.copy()

    def save(self):
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.memory, f)

    def load(self):
        if self.path.exists():
            with open(self.path, "r", encoding="utf-8") as f:
                self.memory = json.load(f)

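Aside, not part of the diff: a minimal usage sketch of the MemoryBuffer removed above, assuming the class as defined; memory_demo.json is an illustrative path so the real memory.json is untouched.

from memory import MemoryBuffer

buf = MemoryBuffer(max_len=3, path="memory_demo.json")
buf.add("hi ruby", "hello!")
buf.add("how are you?", "still learning.")
print(buf.get_context())
# ['User: hi ruby', 'Bot: hello!', 'User: how are you?', 'Bot: still learning.']
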
30 model.py
@@ -1,30 +0,0 @@
import torch
import torch.nn as nn


class MiniTransformer(nn.Module):
    def __init__(self,
                 vocab_size,
                 d_model=256,
                 n_heads=4,
                 n_layers=4,
                 max_seq_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model,
                                       nhead=n_heads,
                                       batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        B, T = x.size()
        x = self.token_emb(x) + self.pos_emb[:, :T]
        for layer in self.layers:
            x = layer(x)
        x = self.ln(x)
        return self.head(x)

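Aside, not part of the diff: a quick shape check for the MiniTransformer above, assuming the class as defined; the forward pass maps [B, T] token ids to [B, T, vocab_size] logits as long as T does not exceed max_seq_len.

import torch
from model import MiniTransformer

model = MiniTransformer(vocab_size=64)
x = torch.randint(0, 64, (2, 16))  # batch of 2 sequences, 16 tokens each
logits = model(x)
print(logits.shape)                # torch.Size([2, 16, 64])
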
70 personality.py
@@ -1,70 +0,0 @@
import json
from pathlib import Path
import random


class Personality:
    def __init__(self, path="personality.json"):
        self.path = Path(path)
        self.data = {
            "likes": [],
            "dislikes": [],
            "traits": [],
            "curiosities": []
        }
        self.load()

    def load(self):
        if self.path.exists():
            with open(self.path, "r", encoding="utf-8") as f:
                self.data.update(json.load(f))

    def save(self):
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.data, f, indent=2)

    def learn_topic(self, text):
        words = [w.lower() for w in text.split()]
        for word in words:
            if word.isalpha() and word not in self.data["curiosities"]:
                self.data["curiosities"].append(word)
        self.save()

    def choose_curiosity(self):
        if not self.data["curiosities"]:
            return None
        return random.choice(self.data["curiosities"])

    def observe_input(self, message: str):
        text = message.lower()

        # Learn likes
        if "i like" in text:
            word = text.split("i like", 1)[1].strip().split()[0]
            if word and word not in self.data["likes"]:
                self.data["likes"].append(word)
                self.save()

        # Learn dislikes
        if "i hate" in text or "i don't like" in text:
            for phrase in ["i hate", "i don't like"]:
                if phrase in text:
                    word = text.split(phrase, 1)[1].strip().split()[0]
                    if word and word not in self.data["dislikes"]:
                        self.data["dislikes"].append(word)
                        self.save()

        # Learn traits from compliments
        for trigger in ["you are", "you're", "ur"]:
            if trigger in text:
                fragment = text.split(trigger, 1)[1].strip().split()[0]
                if fragment and fragment not in self.data["traits"]:
                    self.data["traits"].append(fragment)
                    self.save()

    def reflect(self) -> str:
        if not self.data["likes"] and not self.data["traits"]:
            return "I'm still figuring out who I am."
        likes = ', '.join(self.data["likes"][:3]) or "nothing yet"
        traits = ', '.join(self.data["traits"][:3]) or "no traits yet"
        return f"I'm starting to think I like {likes}. People have called me {traits}."

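Aside, not part of the diff: a usage sketch of the Personality class above, assuming the file as defined; personality_demo.json is an illustrative path so the real personality.json is left alone.

from personality import Personality

p = Personality(path="personality_demo.json")
p.observe_input("I like trains and you're clever")
print(p.reflect())
# "I'm starting to think I like trains. People have called me clever."
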
106 ruby.py
@@ -1,106 +0,0 @@
import discord
import torch
from debug import DebugMonitor
from dream import run_dream_loop
from model import MiniTransformer
from train_step import online_train_step
from tokenizer import ChildTokenizer
from feedback import basic_self_feedback
from memory import MemoryBuffer
from utils import update_model_vocab, track_loss, sample_reply, sample_thought
from personality import Personality
from dotenv import load_dotenv
import os
import logging
import threading


# Configure logging
logging.basicConfig(filename='ruby.log', level=logging.ERROR)

# Load environment variables
load_dotenv()
TOKEN = os.getenv('DISCORD_TOKEN')

# Initialize personality
personality = Personality()

# Initialize debug monitor
debug = DebugMonitor()

# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = ChildTokenizer()
memory = MemoryBuffer(max_len=3)

model = MiniTransformer(vocab_size=tokenizer.vocab_size()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Initialize Discord client
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)

# Start the dream loop in a separate thread
dream_thread = threading.Thread(
    target=run_dream_loop,
    args=(model, tokenizer, device, optimizer, online_train_step),
    daemon=True
)
dream_thread.start()


# Event handlers
@client.event
async def on_ready():
    print(f"{client.user} is ready and learning!")


@client.event
async def on_message(message):
    try:
        # Ignore bot's own messages
        if message.author == client.user:
            return

        # Get user input and memory
        user_input = message.content
        context = memory.get_context()
        full_input = ' '.join(context + [user_input])

        # 🔍 Debug: log context
        debug.log_context(full_input)

        # Ensure model matches tokenizer
        update_model_vocab(model, tokenizer)

        # Encode user input
        input_ids = tokenizer.encode(full_input, return_tensors=True, freeze=True).to(device)
        if input_ids.size(1) < 2:
            return

        # 💭 Generate internal thought
        thought = sample_thought(model, tokenizer, device, full_input)
        debug.log_thought(thought)

        # 🗣️ Generate reply from Ruby
        reply = sample_reply(model, tokenizer, input_ids)
        debug.log_context(reply)

        # ✅ Send the reply
        await message.channel.send(reply if reply.strip() else "...")

        # Add to memory
        memory.add(user_input, reply)

        # 📉 Train and log loss
        training_example = f"User: {user_input}\nRuby: {reply}"
        loss = online_train_step(model, optimizer, tokenizer, training_example, device)
        debug.log_loss(loss)

    except Exception as e:
        logging.exception("Error in on_message")
        await message.channel.send("Oops, I had a brain freeze.")


client.run(TOKEN)

42 tokenizer.py
@@ -1,42 +0,0 @@
import torch
import logging

logger = logging.getLogger("tokenizer")
logger.setLevel(logging.INFO)

fh = logging.FileHandler("learned_chars.log")
formatter = logging.Formatter('%(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)


class ChildTokenizer:
    def __init__(self):
        self.char_to_id = {'<pad>': 0, '<unk>': 1}
        self.id_to_char = {0: '<pad>', 1: '<unk>'}
        self.next_id = 2

        # 🔤 Bootstrap with common characters
        for ch in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,!? ':;":
            self.char_to_id[ch] = self.next_id
            self.id_to_char[self.next_id] = ch
            self.next_id += 1

    def encode(self, text, return_tensors=False, freeze=False):
        ids = []
        for ch in text:
            if ch not in self.char_to_id:
                if freeze:
                    ids.append(self.char_to_id.get('<unk>', 1))
                    continue
                self.char_to_id[ch] = self.next_id
                self.id_to_char[self.next_id] = ch
                self.next_id += 1
            ids.append(self.char_to_id[ch])
        return torch.tensor([ids], dtype=torch.long) if return_tensors else ids

    def decode(self, ids):
        return ''.join([self.id_to_char.get(i, '<unk>') for i in ids])

    def vocab_size(self):
        return self.next_id

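Aside, not part of the diff: a round-trip sketch for the ChildTokenizer above, assuming the class as defined; it shows that encode() grows the vocabulary on unseen characters unless freeze=True, in which case unknowns map to <unk>.

from tokenizer import ChildTokenizer

tok = ChildTokenizer()
ids = tok.encode("Hi, Ruby!")       # every character is in the bootstrap set
print(tok.decode(ids))              # Hi, Ruby!

before = tok.vocab_size()
tok.encode("émoji", freeze=True)    # 'é' is unseen, but freeze=True keeps the vocab fixed
print(tok.vocab_size() == before)   # True: 'é' was encoded as <unk>
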
34 train_step.py
@@ -1,34 +0,0 @@
import torch
import torch.nn.functional as F
from utils import update_model_vocab


def online_train_step(model, optimizer, tokenizer, message, device):
    # Ensure model can handle current vocab
    update_model_vocab(model, tokenizer)

    # Freeze tokenizer so it doesn't grow mid-train
    tokens = tokenizer.encode(message, return_tensors=True, freeze=True).to(device)
    if tokens.size(1) < 2:
        return 0.0

    # Truncate long input
    max_len = model.pos_emb.size(1)
    if tokens.size(1) > max_len:
        tokens = tokens[:, -max_len:]

    x = tokens[:, :-1]
    y = tokens[:, 1:]

    # HARD STOP if y exceeds model vocab
    vocab_size = model.token_emb.num_embeddings
    assert y.max().item() < vocab_size, f"y contains token > vocab_size ({y.max().item()} >= {vocab_size})"

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

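Aside, not part of the diff: a minimal wiring sketch for online_train_step, assuming the model, tokenizer, and train_step modules from this branch; the single exchange string is illustrative.

import torch
from model import MiniTransformer
from tokenizer import ChildTokenizer
from train_step import online_train_step

device = torch.device("cpu")
tokenizer = ChildTokenizer()
model = MiniTransformer(vocab_size=tokenizer.vocab_size()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

loss = online_train_step(model, optimizer, tokenizer, "User: hi\nRuby: hello", device)
print(f"loss={loss:.4f}")  # one cross-entropy value from a single gradient step
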
86 utils.py
@@ -1,86 +0,0 @@
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import softmax

loss_log = []


def track_loss(loss):
    loss_log.append(loss)
    if len(loss_log) % 50 == 0:
        plot_loss()


def plot_loss():
    plt.figure()
    plt.plot(loss_log)
    plt.title("Training Loss Over Time")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.savefig("loss_plot.png")
    plt.close()


def update_model_vocab(model, tokenizer):
    new_vocab = tokenizer.vocab_size()
    old_vocab = model.token_emb.num_embeddings
    d_model = model.token_emb.embedding_dim

    if new_vocab > old_vocab:
        # Resize token embedding
        old_weights = model.token_emb.weight.data
        new_emb = torch.nn.Embedding(new_vocab, d_model).to(old_weights.device)
        new_emb.weight.data[:old_vocab] = old_weights
        torch.nn.init.normal_(new_emb.weight.data[old_vocab:], mean=0.0, std=0.02)
        model.token_emb = new_emb

        # Resize output head
        old_head = model.head
        new_head = torch.nn.Linear(d_model, new_vocab).to(old_weights.device)
        new_head.weight.data[:old_vocab] = old_head.weight.data
        new_head.bias.data[:old_vocab] = old_head.bias.data
        torch.nn.init.normal_(new_head.weight.data[old_vocab:], mean=0.0, std=0.02)
        torch.nn.init.zeros_(new_head.bias.data[old_vocab:])
        model.head = new_head


def sample_reply(model, tokenizer, input_ids, max_len=40):
    model.eval()
    generated = input_ids.clone()
    device = input_ids.device

    for _ in range(max_len):
        # Truncate input to fit positional embedding
        max_seq = model.pos_emb.size(1)
        if generated.size(1) > max_seq:
            generated = generated[:, -max_seq:]

        update_model_vocab(model, tokenizer)

        try:
            logits = model(generated)
        except RuntimeError as e:
            print("CUDA crash in sample_reply — possible vocab mismatch")
            print("Generated:", generated)
            raise e

        next_token_logits = logits[0, -1, :]
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        next_token = probs.argmax(dim=-1, keepdim=True)
        next_token = next_token.unsqueeze(0)  # Shape: [1, 1]
        generated = torch.cat((generated, next_token), dim=1)

        decoded = tokenizer.decode([next_token.item()])
        if decoded in ['\n', '.', '!', '?']:
            break

    output = generated[0].tolist()[input_ids.shape[1]:]
    reply = tokenizer.decode(output).strip()
    print(f"[Reply] {repr(reply)}")
    return reply


def sample_thought(model, tokenizer, device, context_text, max_len=60):
    prompt = f"[thinking] {context_text}"
    input_ids = tokenizer.encode(prompt, return_tensors=True).to(device)
    return sample_reply(model, tokenizer, input_ids, max_len=max_len)

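Aside, not part of the diff: a sketch of how update_model_vocab above keeps the model in sync with a growing tokenizer, assuming the modules from this branch; after an unfrozen encode() adds a character, both the embedding table and the output head are resized to match.

from model import MiniTransformer
from tokenizer import ChildTokenizer
from utils import update_model_vocab

tokenizer = ChildTokenizer()
model = MiniTransformer(vocab_size=tokenizer.vocab_size())

tokenizer.encode("é")                # unfrozen encode adds 'é' to the vocabulary
update_model_vocab(model, tokenizer)
print(model.token_emb.num_embeddings == tokenizer.vocab_size())  # True
print(model.head.out_features == tokenizer.vocab_size())         # True
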