Fix: Working on improving the model code to get a better learning rate than 2.5
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import os
|
||||
|
||||
# Hyperparameters
|
||||
batch_size = 64
|
||||
@ -120,14 +119,11 @@ class GPT(nn.Module):
|
||||
loss = F.cross_entropy(logits, targets)
|
||||
return logits, loss
|
||||
|
||||
def generate(self, idx, max_new_tokens):
|
||||
def generate(self, idx, max_new_tokens, temperature=1.0):
|
||||
for _ in range(max_new_tokens):
|
||||
idx_cond = idx[:, -block_size:]
|
||||
logits, _ = self(idx_cond)
|
||||
print(f"Logits shape: {logits.shape}") # Debug print
|
||||
if logits.size(1) == 0:
|
||||
raise ValueError("Logits tensor is empty.")
|
||||
logits = logits[:, -1, :]
|
||||
logits = logits[:, -1, :] / temperature
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
idx_next = torch.multinomial(probs, num_samples=1)
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
@ -136,27 +132,19 @@ class GPT(nn.Module):
|
||||
|
||||
def encode(s, string_to_int):
    """Convert a string into a list of token ids, one id per character.

    Characters that are not keys of ``string_to_int`` are mapped to the id
    stored under the special ``"<unk>"`` key.

    Args:
        s: The text to encode; it is iterated character by character.
        string_to_int: Mapping from single characters (plus the ``"<unk>"``
            placeholder) to token ids.

    Returns:
        A list with one id per character of ``s``, in order. An empty
        string yields an empty list.

    Raises:
        KeyError: If a character is missing from ``string_to_int`` and the
            mapping defines no ``"<unk>"`` entry to fall back on.
    """
    # Look the fallback id up once, and lazily: a vocabulary without "<unk>"
    # must still encode text made only of known characters.
    missing = object()
    unk_id = string_to_int.get("<unk>", missing)

    token_ids = []
    for ch in s:
        token_id = string_to_int.get(ch, unk_id)
        if token_id is missing:
            raise KeyError(
                f"Character {ch!r} is not in the vocabulary and no '<unk>' "
                "token is defined"
            )
        token_ids.append(token_id)
    return token_ids
|
||||
|
||||
|
||||
def decode(lst, int_to_string):
    """Convert a sequence of token ids back into a string.

    Args:
        lst: Iterable of token ids; each element is used as a lookup key
            into ``int_to_string``.
        int_to_string: Mapping from token id to its string piece.

    Returns:
        The looked-up pieces concatenated in order. An empty ``lst``
        yields an empty string.

    Raises:
        KeyError: If ``int_to_string`` is a dict and an id is not one of
            its keys.
    """
    return "".join(int_to_string[token_id] for token_id in lst)
|
||||
|
||||
|
||||
def load_model(vocab_size, model_path=None):
    """Build a ``GPT`` model and optionally restore weights from a checkpoint.

    Args:
        vocab_size: Vocabulary size forwarded to the ``GPT`` constructor.
        model_path: Path of a checkpoint readable by ``torch.load``. When it
            is ``None`` or empty, no file is read and the freshly
            initialized model is returned.

    Returns:
        The ``GPT`` instance, with the checkpoint's state dict loaded when
        the file was found.
    """
    model = GPT(vocab_size)

    # Guard first: with the default of None there is nothing to look up, and
    # path helpers such as os.path.exists(None) would raise TypeError.
    if model_path:
        try:
            # map_location keeps a checkpoint saved on a GPU loadable on a
            # CPU-only machine; load_state_dict copies the values into the
            # model's own parameters afterwards.
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from a trusted source.
            state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        except FileNotFoundError:
            print("No pre-trained model found. Initialized a new model.")
        else:
            # Kept outside the try so only the file read is covered by the
            # FileNotFoundError handler.
            model.load_state_dict(state_dict)
            print("Model loaded successfully.")

    return model
|
||||
|
Reference in New Issue
Block a user