# flake8: noqa: E203
"""Train the character-level "phoebe" GPT model and generate replies with it.

Importing this module has side effects: it reads ``vocab.txt``, loads and
encodes the cleaned train/eval splits, and builds the model (restoring
weights from ``phoebe_model.pt`` when that file exists). Running it as a
script calls ``train_model``.
"""
import os
import random
import re

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from gpt_model import encode, decode, load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 32  # Reduced batch size for gradient accumulation
accumulation_steps = 4  # Gradient accumulation steps
block_size = 256
max_iters = 100000  # Increased iterations
learning_rate = 3e-5  # Adjust learning rate
eval_iters = 100
# NOTE(review): ``dropout`` is never passed to ``load_model`` in this file,
# so it has no visible effect here -- confirm whether gpt_model reads it.
dropout = 0.4  # Increased dropout to prevent overfitting
patience = 20000  # Increased patience for early stopping
weight_decay = 0.01  # Add weight decay for regularization

# Load the vocabulary and encoded data
with open("vocab.txt", "r", encoding="utf-8") as f:
    text = f.read()
chars = sorted(set(text))
required_chars = " \n\r\t"
for char in required_chars:
    if char not in chars:
        chars.append(char)
# NOTE(review): iterating a string only ever yields single characters, so this
# empty-string entry never matches any input character (compare
# check_input_chars). It does count toward vocab_size, which presumably fixes
# the model's embedding size -- removing it would likely break loading of
# existing checkpoints. Confirm before changing.
special_token = ""
if special_token not in chars:
    chars.append(special_token)
vocab_size = len(chars)
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}


def clean_text(text):
    """Drop unsupported characters and collapse whitespace.

    Keeps ASCII letters, digits, whitespace and ``.,;!?'"``; every run of
    whitespace (including newlines) becomes a single space, and the result
    is stripped at both ends.
    """
    text = re.sub(r"[^a-zA-Z0-9\s.,;!?\'\"]+", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def load_and_clean_data(file_path):
    """Read ``file_path`` as UTF-8 and return it passed through ``clean_text``."""
    with open(file_path, "r", encoding="utf-8") as f:
        return clean_text(f.read())


train_data = load_and_clean_data("train_split_cleaned.txt")
val_data = load_and_clean_data("eval_split_cleaned.txt")
train_data = torch.tensor(encode(train_data, string_to_int), dtype=torch.long)
val_data = torch.tensor(encode(val_data, string_to_int), dtype=torch.long)


def get_random_chunk(data, chunk_size):
    """Return a random contiguous slice of ``data`` of length ``chunk_size``.

    Raises:
        ValueError: if ``data`` has ``chunk_size`` or fewer elements.
    """
    if len(data) <= chunk_size:
        raise ValueError(
            f"data has {len(data)} elements; need more than {chunk_size}"
        )
    start = random.randint(0, len(data) - chunk_size - 1)
    return data[start : start + chunk_size]


def get_batch(data, block_size, batch_size):
    """Sample a batch of (input, next-token target) windows from ``data``.

    Each of the ``batch_size`` rows starts at an independently drawn random
    offset, so one batch covers several unrelated regions of the text rather
    than a single contiguous stretch (which would make every row of the batch
    strongly correlated with its neighbours). ``y`` is ``x`` shifted by one
    position.

    Returns:
        Tuple ``(x, y)`` of tensors of shape ``(batch_size, block_size)``
        moved to ``device``.

    Raises:
        ValueError: if ``data`` is too short for one window plus its target.
    """
    # Exclusive upper bound on start offsets; guarantees y's last index exists.
    max_start = len(data) - block_size
    if max_start <= 0:
        raise ValueError(
            f"data has {len(data)} elements; need more than {block_size}"
        )
    starts = torch.randint(max_start, (batch_size,))
    x = torch.stack([data[s : s + block_size] for s in starts])
    y = torch.stack([data[s + 1 : s + block_size + 1] for s in starts])
    return x.to(device), y.to(device)


def load_or_initialize_model(vocab_size):
    """Build the model and restore weights from ``phoebe_model.pt`` if present."""
    model = load_model(vocab_size)
    if os.path.exists("phoebe_model.pt"):
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state_dict = torch.load("phoebe_model.pt", map_location=device)
        model.load_state_dict(state_dict)
        print("Model loaded from phoebe_model.pt")
    else:
        print("Initialized a new model")
    return model


model = load_or_initialize_model(vocab_size).to(device)


@torch.no_grad()
def estimate_loss():
    """Average the model's loss over ``eval_iters`` random batches per split.

    The model is put in eval mode while measuring and then returned to
    whatever mode it was in before the call.

    Returns:
        dict mapping ``"train"`` and ``"val"`` to mean losses (Python floats).
    """
    out = {}
    was_training = model.training
    model.eval()
    for split, data in (("train", train_data), ("val", val_data)):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(data, block_size, batch_size)
            _, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train(was_training)
    return out


def train_model():
    """Train ``model`` with AdamW, gradient accumulation and a OneCycle LR.

    Every ``eval_iters`` micro-batches the train/val loss is estimated,
    printed and logged to TensorBoard under ``runs/phoebe_training``; the
    weights are saved to ``phoebe_model.pt`` whenever validation loss
    improves. Training stops early once ``patience`` micro-batches pass
    without improvement.
    """
    optimizer = optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    # The scheduler is stepped once per optimizer step (every
    # ``accumulation_steps`` micro-batches), so its horizon must be counted in
    # optimizer steps. Sizing it in micro-batches would traverse only a
    # fraction of the cycle (and the old epoch arithmetic could yield 0 epochs
    # or divide by zero on small/large datasets).
    total_steps = max(1, max_iters // accumulation_steps)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate * 10, total_steps=total_steps
    )
    writer = SummaryWriter(log_dir="runs/phoebe_training")
    best_val_loss = float("inf")
    patience_counter = 0
    last_loss = None  # unscaled loss of the most recent micro-batch (tensor)

    model.train()
    optimizer.zero_grad(set_to_none=True)
    try:
        for step in range(max_iters):
            if step % eval_iters == 0:
                losses = estimate_loss()
                print(
                    f"step {step}: train loss {losses['train']:.3f}, "
                    f"val loss {losses['val']:.3f}"
                )
                writer.add_scalar("Loss/train", losses["train"], step)
                writer.add_scalar("Loss/val", losses["val"], step)
                if losses["val"] < best_val_loss:
                    best_val_loss = losses["val"]
                    patience_counter = 0
                    torch.save(model.state_dict(), "phoebe_model.pt")
                    print("Model Saved!")
                else:
                    patience_counter += eval_iters
                    if patience_counter >= patience:
                        print("Early stopping triggered.")
                        break

            xb, yb = get_batch(train_data, block_size, batch_size)
            _, loss = model(xb, yb)
            # Keep a detached copy instead of calling .item() every step,
            # which would force a device sync on each micro-batch.
            last_loss = loss.detach()
            # Scale so the accumulated gradient is the mean over micro-batches.
            (loss / accumulation_steps).backward()
            if (step + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
    finally:
        # Flush TensorBoard logs even if training is interrupted.
        writer.close()

    if patience_counter < patience:
        print("Training completed without early stopping.")
    if last_loss is not None:
        print(f"Final loss: {last_loss.item()}")


def check_input_chars(s, string_to_int):
    """Return the characters of ``s`` missing from ``string_to_int``.

    Duplicates are kept in input order; a message is printed when any are
    found.
    """
    unknown_chars = [c for c in s if c not in string_to_int]
    if unknown_chars:
        print(f"Unknown characters in input: {unknown_chars}")
    return unknown_chars


def process_message(message, max_new_tokens=50, temperature=0.7):
    """Generate the model's continuation of ``message``.

    Args:
        message: prompt text; every character must be in the vocabulary.
        max_new_tokens: number of tokens to sample after the prompt.
        temperature: sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded continuation with the prompt removed, or a
        human-readable error string when the message is blank, contains
        unknown characters, or encodes to nothing.
    """
    print(f"Processing message: '{message}'")
    if not message.strip():
        print("Message is empty or invalid.")
        return "Message is empty or invalid."

    unknown_chars = check_input_chars(message, string_to_int)
    if unknown_chars:
        print(f"Message contains unknown characters: {unknown_chars}")
        return f"Message contains unknown characters: {unknown_chars}"

    encoded_text = torch.tensor(
        [encode(message, string_to_int)], dtype=torch.long
    ).to(device)
    print(f"Encoded text shape: {encoded_text.shape}")
    prompt_len = encoded_text.size(1)
    if prompt_len == 0:
        print("Message could not be processed.")
        return "Message could not be processed."

    # Sample in eval mode (dropout off), then restore the previous mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated_tokens = model.generate(
                encoded_text,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )
    finally:
        model.train(was_training)

    # Assumes generate() returns prompt + new tokens (as this slicing
    # implies); keep only the newly generated part.
    generated_tokens = generated_tokens[0, prompt_len:]
    decoded_response = decode(generated_tokens.tolist(), int_to_string)
    print(f"Generated response: '{decoded_response}'")
    if decoded_response.startswith(message):
        decoded_response = decoded_response[len(message) :].strip()
    print(f"Final response: '{decoded_response}'")
    return decoded_response


if __name__ == "__main__":
    train_model()