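"""Train a small character-level GPT on the text files under ./books.

Saves a weights-only checkpoint to model.pth after every epoch.
"""
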
import torch
from torch.utils.data import DataLoader

from core.dataset import CharDataset
from core.model import GPT, GPTConfig


def train():
    # hyperparameters
    books_dir = './books'   # directory of training text files
    block_size = 128        # context length: characters per training sample
    batch_size = 32
    epochs = 10
    lr = 3e-4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # dataset & model
    dataset = CharDataset(books_dir, block_size)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    config = GPTConfig(
        vocab_size=dataset.vocab_size,  # character-level vocab, derived from the data
        block_size=block_size,
        n_layer=6,
        n_head=6,
        n_embd=384,
    )
    model = GPT(config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
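    # NOTE: AdamW's default weight_decay (1e-2) applies to every parameter
    # here, including biases and LayerNorm weights; larger training scripts
    # usually exempt those from decay.
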
    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            _, loss = model(xb, yb)  # model is expected to return (logits, loss)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(loader)
        print(f'Epoch {epoch}/{epochs} - avg loss: {avg_loss:.4f}')
        # save a checkpoint each epoch (overwrites model.pth; resuming training
        # would also require saving the optimizer state)
        torch.save(model.state_dict(), 'model.pth')


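def load_checkpoint(device='cpu'):
    # Hypothetical helper, not part of the original script: a sketch of
    # reloading the checkpoint for inference. model.pth stores only the
    # weights, so the config must match the one used in train() exactly.
    dataset = CharDataset('./books', 128)
    config = GPTConfig(vocab_size=dataset.vocab_size, block_size=128,
                       n_layer=6, n_head=6, n_embd=384)
    model = GPT(config)
    model.load_state_dict(torch.load('model.pth', map_location=device))
    model.eval()
    return model

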
if __name__ == '__main__':
    train()