Fix: Moved the Files around due to imports not working right

Feat: Phoebe replies but it's gibbish
This is a version break because of the file structure change.
This commit is contained in:
Dan
2024-05-15 20:13:35 -04:00
parent 12071fbf61
commit 75f1116b3b
4 changed files with 116 additions and 20 deletions

35
phoebe/discord_bot.py Normal file
View File

@@ -0,0 +1,35 @@
import discord
import os
from dotenv import load_dotenv
from train_gpt_model import process_message
from gpt_model import load_model
load_dotenv()
# Discord bot token
TOKEN = os.getenv("DISCORD_TOKEN")
# Initialize Discord client
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
@client.event
async def on_ready():
print(f"We have logged in as {client.user}")
load_model(5641, "phoebe_model.pt")
@client.event
async def on_message(message):
if message.author == client.user:
return
# Process the message and get a response
response = process_message(message.content)
# Send the response back to the Discord channel
await message.channel.send(response)
client.run(TOKEN)

View File

@@ -1,6 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import os
# Hyperparameters # Hyperparameters
batch_size = 64 batch_size = 64
@@ -123,6 +124,9 @@ class GPT(nn.Module):
for _ in range(max_new_tokens): for _ in range(max_new_tokens):
idx_cond = idx[:, -block_size:] idx_cond = idx[:, -block_size:]
logits, _ = self(idx_cond) logits, _ = self(idx_cond)
print(f"Logits shape: {logits.shape}") # Debug print
if logits.size(1) == 0:
raise ValueError("Logits tensor is empty.")
logits = logits[:, -1, :] logits = logits[:, -1, :]
probs = F.softmax(logits, dim=-1) probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1) idx_next = torch.multinomial(probs, num_samples=1)
@@ -131,8 +135,28 @@ class GPT(nn.Module):
def encode(s, string_to_int): def encode(s, string_to_int):
return [string_to_int[c] for c in s] # Replace unknown characters with a special token (e.g., "<unk>")
encoded = []
for c in s:
if c in string_to_int:
encoded.append(string_to_int[c])
else:
print(f"Unknown character encountered during encoding: {c}")
encoded.append(string_to_int["<unk>"])
return encoded
def decode(lst, int_to_string): def decode(lst, int_to_string):
return "".join([int_to_string[i] for i in lst]) return "".join([int_to_string[i] for i in lst])
def load_model(vocab_size, model_path="phoebe_model.pt"):
model = GPT(vocab_size)
if os.path.exists(model_path):
model.load_state_dict(
torch.load(model_path, map_location=torch.device("cpu"))
)
print("Model loaded successfully.")
else:
print("No pre-trained model found. Initialized a new model.")
return model

View File

@@ -1,16 +1,16 @@
import torch import torch
import mmap import mmap
import random import random
from gpt_model import GPT, encode from gpt_model import GPT, encode, decode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters # Hyperparameters
batch_size = 64 batch_size = 64
block_size = 256 block_size = 256
max_iters = 200 max_iters = 500
learning_rate = 2e-5 learning_rate = 2e-5
eval_iters = 100 eval_iters = 250
dropout = 0.2 dropout = 0.2
chars = "" chars = ""
@@ -18,12 +18,18 @@ with open("vocab.txt", "r", encoding="utf-8") as f:
text = f.read() text = f.read()
chars = sorted(list(set(text))) chars = sorted(list(set(text)))
# Ensure that space and other special characters are included
# Ensure that space and other special characters are included # Ensure that space and other special characters are included
required_chars = " \n\r\t" required_chars = " \n\r\t"
for char in required_chars: for char in required_chars:
if char not in chars: if char not in chars:
chars.append(char) chars.append(char)
# Add a special token for unknown characters
special_token = "<unk>"
if special_token not in chars:
chars.append(special_token)
vocab_size = len(chars) vocab_size = len(chars)
string_to_int = {ch: i for i, ch in enumerate(chars)} string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)} int_to_string = {i: ch for i, ch in enumerate(chars)}
@@ -73,21 +79,52 @@ def estimate_loss():
return out return out
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) def train_model():
for iter in range(max_iters): optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
if iter % eval_iters == 0: for iter in range(max_iters):
losses = estimate_loss() if iter % eval_iters == 0:
print( losses = estimate_loss()
f"step {iter}: train loss {losses['train']:.3f}, " print(
f"val loss {losses['val']:.3f}" f"step {iter}: train loss {losses['train']:.3f}, "
) f"val loss {losses['val']:.3f}"
xb, yb = get_batch("train") )
logits, loss = model(xb, yb) xb, yb = get_batch("train")
optimizer.zero_grad(set_to_none=True) logits, loss = model(xb, yb)
loss.backward() optimizer.zero_grad(set_to_none=True)
optimizer.step() loss.backward()
optimizer.step()
print(loss.item()) print(loss.item())
torch.save(model.state_dict(), "phoebe_model.pt")
print("Model Saved!")
torch.save(model.state_dict(), "phoebe_model.pt")
print("Model Saved!") def check_input_chars(s, string_to_int):
unknown_chars = [c for c in s if c not in string_to_int]
if unknown_chars:
print(f"Unknown characters in input: {unknown_chars}")
return unknown_chars
def process_message(message):
if not message.strip():
return "Message is empty or invalid."
# Check for unknown characters
unknown_chars = check_input_chars(message, string_to_int)
if unknown_chars:
return f"Message contains unknown characters: {unknown_chars}"
encoded_text = torch.tensor(
[encode(message, string_to_int)], dtype=torch.long
).to(device)
print(f"Encoded text shape: {encoded_text.shape}") # Debug print
if encoded_text.size(1) == 0:
return "Message could not be processed."
response = model.generate(encoded_text, max_new_tokens=50)
decoded_response = decode(response[0].tolist(), int_to_string)
return decoded_response
# train_model()

Binary file not shown.