# flake8: noqa: E203
import os
import random
import re

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from gpt_model import encode, decode, load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 32  # Micro-batch size; gradients are accumulated across steps
accumulation_steps = 4  # Gradient accumulation steps per optimizer update
block_size = 256
max_iters = 100000  # Increased iteration budget
learning_rate = 3e-5  # Base learning rate (OneCycleLR peaks at 10x this value)
eval_iters = 100
dropout = 0.4  # Increased dropout to prevent overfitting
patience = 20000  # Early-stopping patience, measured in training iterations
weight_decay = 0.01  # Weight decay for regularization
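# With gradient accumulation, the effective batch per optimizer update is
# batch_size * accumulation_steps = 32 * 4 = 128 sequences
# (128 * block_size = 32,768 tokens per update).
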
# Load the vocabulary and build the character-level mappings
with open("vocab.txt", "r", encoding="utf-8") as f:
    text = f.read()
chars = sorted(list(set(text)))

# Make sure common whitespace characters are always in the vocabulary
required_chars = " \n\r\t"
for char in required_chars:
    if char not in chars:
        chars.append(char)

# Reserve a token for unknown characters
special_token = "<unk>"
if special_token not in chars:
    chars.append(special_token)

vocab_size = len(chars)
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}
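# Illustrative round trip through the mappings (encode/decode are imported
# from gpt_model; exact behaviour assumed from how they are used below):
#   ids = encode("hello", string_to_int)   # one integer id per character
#   text = decode(ids, int_to_string)      # -> "hello"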


def clean_text(text):
    text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def load_and_clean_data(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        text = f.read()
    cleaned_text = clean_text(text)
    return cleaned_text


train_data = load_and_clean_data("train_split_cleaned.txt")
val_data = load_and_clean_data("eval_split_cleaned.txt")

train_data = torch.tensor(encode(train_data, string_to_int), dtype=torch.long)
val_data = torch.tensor(encode(val_data, string_to_int), dtype=torch.long)


def get_random_chunk(data, chunk_size):
    start = random.randint(0, len(data) - chunk_size - 1)
    chunk = data[start : start + chunk_size]
    return chunk


def get_batch(data, block_size, batch_size):
    chunk_size = block_size * (batch_size + 1)
    chunk = get_random_chunk(data, chunk_size)
    x = chunk[: block_size * batch_size].view(batch_size, block_size)
    y = chunk[1 : block_size * batch_size + 1].view(batch_size, block_size)
    x, y = x.to(device), y.to(device)
    return x, y
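

# Shape note (illustrative): with batch_size=32 and block_size=256, get_batch
# returns x and y of shape (32, 256); y is x shifted one position, so
# y[i, t] is the next-character target for x[i, t].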


def load_or_initialize_model(vocab_size):
    model = load_model(vocab_size)
    if os.path.exists("phoebe_model.pt"):
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine
        model.load_state_dict(torch.load("phoebe_model.pt", map_location=device))
        print("Model loaded from phoebe_model.pt")
    else:
        print("Initialized a new model")
    return model


model = load_or_initialize_model(vocab_size).to(device)


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        data = train_data if split == "train" else val_data
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(data, block_size, batch_size)
            logits, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out
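

# estimate_loss runs under @torch.no_grad(), so no gradients are tracked during
# evaluation; it averages the loss over eval_iters random batches per split and
# puts the model back in training mode before returning.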


def train_model():
    optimizer = optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    steps_per_epoch = len(train_data) // (batch_size * block_size)
    epochs = max_iters // steps_per_epoch

    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=learning_rate * 10,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
    )

    writer = SummaryWriter(log_dir="runs/phoebe_training")

    best_val_loss = float("inf")
    patience_counter = 0

    for iter in range(max_iters):
        # Periodic evaluation, TensorBoard logging, checkpointing, early stopping
        if iter % eval_iters == 0:
            losses = estimate_loss()
            print(
                f"step {iter}: train loss {losses['train']:.3f}, "
                f"val loss {losses['val']:.3f}"
            )

            writer.add_scalar("Loss/train", losses["train"], iter)
            writer.add_scalar("Loss/val", losses["val"], iter)

            if losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                patience_counter = 0
                torch.save(model.state_dict(), "phoebe_model.pt")
                print("Model Saved!")
            else:
                patience_counter += eval_iters

            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

        # Gradient accumulation: scale the loss so gradients average over
        # accumulation_steps micro-batches before each optimizer update.
        xb, yb = get_batch(train_data, block_size, batch_size)
        logits, loss = model(xb, yb)
        loss = loss / accumulation_steps
        loss.backward()

        if (iter + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

    if patience_counter < patience:
        print("Training completed without early stopping.")
    print(f"Final loss: {loss.item()}")
    writer.close()
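

# Monitor training curves with TensorBoard, pointing at the log dir used above:
#   tensorboard --logdir runs/phoebe_training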


def check_input_chars(s, string_to_int):
    unknown_chars = [c for c in s if c not in string_to_int]
    if unknown_chars:
        print(f"Unknown characters in input: {unknown_chars}")
    return unknown_chars
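

# Illustrative: returns the characters missing from the vocabulary, e.g.
#   check_input_chars("héllo", string_to_int)  ->  ["é"]  (if "é" is not in vocab.txt)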


# Maintain conversation history
conversation_history = []


def process_message(message):
    global conversation_history

    print(f"Processing message: '{message}'")
    if not message.strip():
        print("Message is empty or invalid.")
        return "Message is empty or invalid."

    unknown_chars = check_input_chars(message, string_to_int)
    if unknown_chars:
        print(f"Message contains unknown characters: {unknown_chars}")
        return f"Message contains unknown characters: {unknown_chars}"

    # Add the new message to the conversation history
    conversation_history.append(message)
    # Limit the conversation history to the last 5 messages to avoid excessive length
    if len(conversation_history) > 5:
        conversation_history = conversation_history[-5:]

    # Concatenate the conversation history to form the input prompt
    context = " ".join(conversation_history)
    encoded_text = torch.tensor(
        [encode(context, string_to_int)], dtype=torch.long
    ).to(device)
    print(f"Encoded text shape: {encoded_text.shape}")

    if encoded_text.size(1) == 0:
        print("Message could not be processed.")
        return "Message could not be processed."

    with torch.no_grad():
        generated_tokens = model.generate(
            encoded_text, max_new_tokens=100, temperature=1.0
        )
        generated_tokens = generated_tokens[0, len(encoded_text[0]) :]

    decoded_response = decode(generated_tokens.tolist(), int_to_string)
    print(f"Generated response: '{decoded_response}'")

    if decoded_response.startswith(context):
        decoded_response = decoded_response[len(context) :].strip()

    print(f"Final response: '{decoded_response}'")

    # Add the response to the conversation history
    conversation_history.append(decoded_response)

    return decoded_response
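

# Illustrative usage once a trained checkpoint exists:
#   reply = process_message("Hello there!")
#   print(reply)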


if __name__ == "__main__":
    train_model()