# RubyOld/train.py (160 lines, 5.1 KiB, Python)

import os
import time

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset

from config import cfg
# 1. Tokenizer Implementation (Modified)
class RubyTokenizer:
    """Byte-level BPE tokenizer wrapper with ``[PAD]`` and ``[UNK]`` special tokens.

    The underlying ``tokenizers.Tokenizer`` is kept in ``self.tokenizer`` so
    callers can save it or replace it with one loaded from disk.
    """

    SPECIAL_TOKENS = ("[PAD]", "[UNK]")

    def __init__(self):
        self.tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
        self.tokenizer.add_special_tokens(list(self.SPECIAL_TOKENS))
        # The ByteLevel decoder only inverts a ByteLevel pre-tokenizer. This pairing
        # maps every input byte (spaces, newlines, non-ASCII) onto a vocabulary
        # symbol and splits text into words, so BPE never sees a whole document
        # as one "word" and whitespace never degrades to [UNK].
        self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        self.tokenizer.decoder = decoders.ByteLevel()

    def train(self, texts):
        """Learn BPE merges from *texts*, an iterable of raw strings.

        The target vocabulary size is ``cfg.vocab_size``; pairs that occur
        fewer than twice are not merged.
        """
        trainer = trainers.BpeTrainer(
            special_tokens=list(self.SPECIAL_TOKENS),
            vocab_size=cfg.vocab_size,
            min_frequency=2,
            show_progress=True,
            # Seed the vocabulary with all 256 byte symbols so any input is
            # encodable without falling back to [UNK].
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        )
        # Feed raw strings, not str.split() output: the pre-tokenizer does the
        # word splitting and keeps the whitespace that split() would discard.
        self.tokenizer.train_from_iterator(texts, trainer=trainer)

    def encode(self, text):
        """Return the list of token ids for *text*."""
        return self.tokenizer.encode(text).ids

    @property
    def pad_id(self):
        """Id of the ``[PAD]`` token, or ``None`` if the vocabulary lacks it."""
        return self.tokenizer.token_to_id("[PAD]")
# 2. Dataset: fixed-length token sequences read from the on-disk cache (padding is applied when the cache is written)
class CachedDataset(Dataset):
    """Fixed-length int32 token sequences read from a flat on-disk cache.

    The cache is a raw array of int32 ids; item ``i`` holds ids
    ``[i * context_size, (i + 1) * context_size)``. Trailing ids that do not
    fill a whole item are not exposed.

    Args:
        path: cache file to memory-map; defaults to ``"dataset_cache.bin"``.
        context_size: ids per item; ``None`` means ``cfg.context_size``.
    """

    def __init__(self, path="dataset_cache.bin", context_size=None):
        self.context_size = cfg.context_size if context_size is None else context_size
        itemsize = np.dtype(np.int32).itemsize
        self.data = np.memmap(
            path,
            dtype=np.int32,
            mode="r",
            shape=(os.path.getsize(path) // itemsize,),
        )

    def __len__(self):
        return len(self.data) // self.context_size

    def __getitem__(self, idx):
        n_items = len(self)
        if idx < 0:
            idx += n_items
        # Raising IndexError is what lets plain `for item in dataset` stop;
        # slicing past the end would otherwise keep returning empty tensors.
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of length {n_items}")
        start = idx * self.context_size
        # .copy() gives the tensor its own writable buffer instead of a view
        # into the read-only memmap.
        return torch.from_numpy(self.data[start:start + self.context_size].copy())
# 3. Transformer Model (Modified padding_idx)
class Transformer(nn.Module):
    """Causal language model built from ``nn.TransformerEncoderLayer`` blocks.

    token ids -> embedding -> self-attention blocks under a causal mask ->
    linear head giving one vocabulary-sized logit vector per position.

    Args:
        pad_id: forwarded to ``nn.Embedding(padding_idx=...)``; may be ``None``.
        vocab_size, model_dim, num_heads, num_layers: optional keyword
            overrides; each one left as ``None`` is read from the matching
            ``cfg`` attribute.
    """

    def __init__(self, pad_id, *, vocab_size=None, model_dim=None,
                 num_heads=None, num_layers=None):
        super().__init__()
        if vocab_size is None:
            vocab_size = cfg.vocab_size
        if model_dim is None:
            model_dim = cfg.model_dim
        if num_heads is None:
            num_heads = cfg.num_heads
        if num_layers is None:
            num_layers = cfg.num_layers

        self.embed = nn.Embedding(vocab_size, model_dim, padding_idx=pad_id)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=model_dim,
                nhead=num_heads,
                dim_feedforward=model_dim * 4,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])
        # NOTE(review): no positional encoding is applied; token order is only
        # visible through the causal mask. A learned position embedding would
        # likely help, but it changes the state_dict, so it is not added here.
        self.head = nn.Linear(model_dim, vocab_size)

    @staticmethod
    def _causal_mask(seq_len, device):
        """Return a (seq_len, seq_len) bool mask; True blocks attention to later positions."""
        return torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
            diagonal=1,
        )

    def forward(self, x):
        """Map token ids ``x`` to logits of shape ``(*x.shape, vocab_size)``.

        The causal mask keeps position ``t`` from attending to positions
        ``> t``. Training targets are the inputs shifted by one. Without the
        mask, each position could read the very token it is asked to predict,
        and the loss would fall without the model learning to generate.
        """
        mask = self._causal_mask(x.size(-1), x.device)
        h = self.embed(x)
        for block in self.blocks:
            h = block(h, src_mask=mask)
        return self.head(h)
# 4. Main Training Process (Critical fixes)
_TOKENIZER_PATH = "tokenizer.json"
_CACHE_PATH = "dataset_cache.bin"


def _write_token_cache(tokenizer, texts, path, context_size):
    """Encode each text into one fixed-length row of int32 ids and stream it to *path*.

    Every row is truncated to ``context_size`` ids and right-padded with
    ``tokenizer.pad_id``, so the file holds ``len(texts) * context_size`` ids.
    Rows go straight into a pre-sized memmap rather than first being gathered
    in a Python list of ints, which needs far more RAM than the int32 file.

    Raises:
        ValueError: if *texts* is empty (an empty file cannot be memory-mapped).
    """
    n_rows = len(texts)
    if n_rows == 0:
        raise ValueError("no texts to cache")
    pad_id = tokenizer.pad_id
    cache = np.memmap(path, dtype=np.int32, mode="w+", shape=(n_rows * context_size,))
    rows = cache.reshape(n_rows, context_size)
    for i, text in enumerate(texts):
        ids = tokenizer.encode(text)[:context_size]
        rows[i, :len(ids)] = ids
        rows[i, len(ids):] = pad_id
    cache.flush()
    del rows, cache


def _prepare_tokenizer_and_cache():
    """Return the tokenizer matching the token cache, building either one if missing.

    Raises:
        RuntimeError: if the cache exists but the tokenizer file does not, so
            the vocabulary the cache was encoded with cannot be recovered.
    """
    tokenizer = RubyTokenizer()
    has_tokenizer = os.path.exists(_TOKENIZER_PATH)
    if has_tokenizer:
        # Load whenever present, not only when rebuilding the cache. pad_id and
        # the tokenizer test must use the vocabulary the cache was encoded with.
        tokenizer.tokenizer = Tokenizer.from_file(_TOKENIZER_PATH)

    if os.path.exists(_CACHE_PATH):
        if not has_tokenizer:
            raise RuntimeError(
                f"{_CACHE_PATH} exists but {_TOKENIZER_PATH} is missing; "
                f"delete {_CACHE_PATH} to rebuild both consistently."
            )
        return tokenizer

    print("Creating dataset cache...")
    ds = load_dataset("openwebtext", split="train[:5%]")
    texts = ds["text"]  # fetch the column once and reuse it for both passes
    if not has_tokenizer:
        print("Training tokenizer...")
        tokenizer.train(text for text in texts if len(text) > 100)
        tokenizer.tokenizer.save(_TOKENIZER_PATH)
    _write_token_cache(tokenizer, texts, _CACHE_PATH, cfg.context_size)
    return tokenizer


def _train_one_epoch(model, loader, opt, scaler, pad_id):
    """Run one mixed-precision pass over *loader*, printing loss and throughput every 10 steps."""
    model.train()
    seen = 0
    start = time.perf_counter()
    for step, batch in enumerate(loader):
        batch = batch.to(cfg.device, non_blocking=True)
        # Next-token objective: logits at position t are scored against token t+1.
        inputs, targets = batch[:, :-1], batch[:, 1:]
        with autocast():
            logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1).long(),
                ignore_index=pad_id,  # positions whose target is padding add no loss
            )
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()
        seen += batch.size(0)  # actual count, so a short final batch is not over-counted
        if step % 10 == 0:
            elapsed = time.perf_counter() - start
            speed = seen / elapsed
            print(f"Step {step} | Loss: {loss.item():.4f} | Speed: {speed:.1f} samples/s")


def main():
    """Prepare the tokenizer and token cache, then train the Transformer for one epoch.

    Reads or creates ``tokenizer.json`` and ``dataset_cache.bin`` in the
    working directory; both are reused when present.
    """
    tokenizer = _prepare_tokenizer_and_cache()

    # Sanity check: show how a fixed sentence is split into tokens.
    test_text = "The quick brown fox jumps over the lazy dog."
    print("Tokenizer test:", tokenizer.tokenizer.encode(test_text).tokens)

    pad_id = tokenizer.pad_id
    model = Transformer(pad_id=pad_id).to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    scaler = GradScaler()
    loader = DataLoader(
        CachedDataset(),
        batch_size=cfg.batch_size,
        pin_memory=True,
        shuffle=True,
    )
    _train_one_epoch(model, loader, opt, scaler, pad_id)


if __name__ == "__main__":
    main()