"""Training script for the Phoebe character-level GPT model (phoebe/train_gpt_model.py)."""
import re
import torch
import torch.optim as optim
import random
import os
from gpt_model import encode, decode, load_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters
batch_size = 64
block_size = 256
max_iters = 5000
learning_rate = 1e-5 # Adjusted learning rate
eval_iters = 100
dropout = 0.2
patience = 500 # Number of iterations to wait for improvement before stopping
# Load the vocabulary and encoded data
with open("vocab.txt", "r", encoding="utf-8") as f:
text = f.read()
chars = sorted(list(set(text)))
# Ensure that space and other special characters are included
required_chars = " \n\r\t"
for char in required_chars:
if char not in chars:
chars.append(char)
# Add a special token for unknown characters
special_token = "<unk>"
if special_token not in chars:
chars.append(special_token)
vocab_size = len(chars)
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}
def clean_text(text):
"""Remove special characters and unwanted symbols from the text."""
text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
# Load and preprocess training and validation data from cleaned .txt files
def load_and_clean_data(file_path):
with open(file_path, "r", encoding="utf-8") as f:
text = f.read()
cleaned_text = clean_text(text)
return cleaned_text
train_data = load_and_clean_data("train_split_cleaned.txt")
val_data = load_and_clean_data("eval_split_cleaned.txt")
train_data = torch.tensor(encode(train_data, string_to_int), dtype=torch.long)
val_data = torch.tensor(encode(val_data, string_to_int), dtype=torch.long)
def get_random_chunk(data, chunk_size):
start = random.randint(0, len(data) - chunk_size - 1)
chunk = data[start : start + chunk_size]
return chunk
def get_batch(data, block_size, batch_size):
chunk_size = block_size * (batch_size + 1)
chunk = get_random_chunk(data, chunk_size)
x = chunk[: block_size * batch_size].view(batch_size, block_size)
y = chunk[1 : block_size * batch_size + 1].view(batch_size, block_size)
x, y = x.to(device), y.to(device)
return x, y
def load_or_initialize_model(vocab_size):
model = load_model(vocab_size)
if os.path.exists("phoebe_model.pt"):
model.load_state_dict(torch.load("phoebe_model.pt"))
print("Model loaded from phoebe_model.pt")
else:
print("Initialized a new model")
return model
model = load_or_initialize_model(vocab_size).to(device)
@torch.no_grad()
def estimate_loss():
out = {}
model.eval()
for split in ["train", "val"]:
data = train_data if split == "train" else val_data
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
X, Y = get_batch(data, block_size, batch_size)
logits, loss = model(X, Y)
losses[k] = loss.item()
out[split] = losses.mean().item()
model.train()
return out
def train_model():
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
best_val_loss = float("inf")
patience_counter = 0
for iter in range(max_iters):
if iter % eval_iters == 0:
losses = estimate_loss()
print(
f"step {iter}: train loss {losses['train']:.3f}, "
f"val loss {losses['val']:.3f}"
)
# Check for improvement in validation loss
if losses["val"] < best_val_loss:
best_val_loss = losses["val"]
patience_counter = 0
torch.save(model.state_dict(), "phoebe_model.pt")
print("Model Saved!")
else:
patience_counter += eval_iters
# Early stopping
if patience_counter >= patience:
print("Early stopping triggered.")
break
xb, yb = get_batch(train_data, block_size, batch_size)
logits, loss = model(xb, yb)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
scheduler.step()
if patience_counter < patience:
print("Training completed without early stopping.")
print(f"Final loss: {loss.item()}")
def check_input_chars(s, string_to_int):
unknown_chars = [c for c in s if c not in string_to_int]
if unknown_chars:
print(f"Unknown characters in input: {unknown_chars}")
return unknown_chars
def process_message(message):
print(f"Processing message: '{message}'") # Debug print
if not message.strip():
print("Message is empty or invalid.") # Debug print
return "Message is empty or invalid."
# Check for unknown characters
unknown_chars = check_input_chars(message, string_to_int)
if unknown_chars:
print(f"Message contains unknown characters: {unknown_chars}")
return f"Message contains unknown characters: {unknown_chars}"
encoded_text = torch.tensor(
[encode(message, string_to_int)], dtype=torch.long
).to(device)
print(f"Encoded text shape: {encoded_text.shape}") # Debug print
if encoded_text.size(1) == 0:
print("Message could not be processed.") # Debug print
return "Message could not be processed."
response = model.generate(encoded_text, max_new_tokens=50, temperature=0.7)
decoded_response = decode(response[0].tolist(), int_to_string)
print(f"Generated response: '{decoded_response}'") # Debug print
return decoded_response
if __name__ == "__main__":
train_model()