Created NORA. She has been designed from scratch. At this point, I have determined the best hyperparameters for her training. The next step is to help her communicate on Discord and see how she handles it.
train.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""
|
||||
train.py
|
||||
|
||||
Training loop for Nora, with automatic mixed precision (AMP) to speed up on CUDA GPUs.
|
||||
Uses tqdm for progress bars, logging for metrics, and gradient clipping + LR scheduler.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.amp import GradScaler, autocast # <-- updated import
|
||||
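# Note: torch.amp is the current home of autocast/GradScaler; the older
# torch.cuda.amp entry points are deprecated in recent PyTorch releases.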


def train(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    tokenizer,
    config,
    start_step: int = 0,
):
    """
    model: NoraTransformerLM
    dataloader: DataLoader for TextDataset
    optimizer: AdamW (or Adam)
    scheduler: LR scheduler with warmup
    tokenizer: CharTokenizer
    config: namespace from config.py
    start_step: global step to resume from when loading a checkpoint
    """

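    # Config fields read below: device, epochs, max_grad_norm,
    # log_interval, save_interval, checkpoint_dir.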
    device = config.device
    model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])
    scaler = GradScaler()

    global_step = start_step
    steps_per_epoch = len(dataloader)
    total_steps = config.epochs * steps_per_epoch

    logging.info(
        f"[train] Starting training for {config.epochs} epochs, "
        f"{steps_per_epoch} steps/epoch, total approx {total_steps} steps."
    )

    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0.0
        epoch_start = time.time()

        # If you want to profile the first 100 steps, uncomment below:
        # if global_step == start_step:
        #     t0 = time.time()

        pbar = tqdm(
            enumerate(dataloader),
            total=steps_per_epoch,
            desc=f"Epoch {epoch + 1}",
            ncols=100,
            unit="step",
        )
        for step, (inputs, targets) in pbar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            # Mixed precision forward/backward (specify device_type="cuda")
            with autocast(device_type="cuda", dtype=torch.float16):
                logits = model(inputs)  # (batch, seq_len, vocab_size)
                loss = criterion(
                    logits.view(-1, tokenizer.vocab_size()),
                    targets.view(-1),
                )

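            # Unscale before clipping so the norm is computed on real fp32
            # gradients; scaler.step() skips the optimizer update if inf/NaN
            # gradients were found, and the scheduler advances once per step.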
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), config.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            epoch_loss += loss.item()
            global_step += 1

            # Log every log_interval steps
            if global_step % config.log_interval == 0:
                avg_loss = epoch_loss / (step + 1)
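                # Perplexity is just exp of the running mean cross-entropy.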
                ppl = math.exp(avg_loss)
                logging.info(
                    f"[step {global_step}/{total_steps}] "
                    f"avg_loss = {avg_loss:.4f}, ppl = {ppl:.2f}, "
                    f"lr = {scheduler.get_last_lr()[0]:.2e}"
                )

            # Save checkpoint every save_interval steps
            if global_step % config.save_interval == 0:
                from utils import save_checkpoint

                save_checkpoint(
                    model,
                    optimizer,
                    global_step,
                    config.checkpoint_dir,
                    tokenizer=tokenizer,
                )

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

            # (Optional) Profile first 100 steps
            # if global_step == start_step + 100:
            #     elapsed = time.time() - t0
            #     print(
            #         f"[profile] avg time/step over first 100 batches: "
            #         f"{elapsed/100:.4f} s"
            #     )

        epoch_time = time.time() - epoch_start
        avg_epoch_loss = epoch_loss / steps_per_epoch
        logging.info(
            f"[epoch {epoch + 1}/{config.epochs}] "
            f"avg_loss = {avg_epoch_loss:.4f}, time = {epoch_time:.1f}s"
        )

    # Final checkpoint at end of all epochs
    from utils import save_checkpoint

    save_checkpoint(model, optimizer, global_step, config.checkpoint_dir, tokenizer=tokenizer)
    logging.info("[train] Training complete.")
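

# Minimal usage sketch (hypothetical wiring: the module paths and constructor
# signatures below are assumptions, though the class names match the docstring
# above — NoraTransformerLM, TextDataset, CharTokenizer, config.py):
#
#     import config
#     from torch.utils.data import DataLoader
#     from model import NoraTransformerLM          # assumed module path
#     from data import TextDataset, CharTokenizer  # assumed module path
#
#     tokenizer = CharTokenizer(config.corpus_path)         # assumed ctor
#     dataset = TextDataset(config.corpus_path, tokenizer)  # assumed ctor
#     dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
#     model = NoraTransformerLM(vocab_size=tokenizer.vocab_size())  # assumed ctor
#     optimizer = optim.AdamW(model.parameters(), lr=config.lr)
#     scheduler = optim.lr_scheduler.LambdaLR(   # linear warmup, then constant
#         optimizer, lambda s: min((s + 1) / config.warmup_steps, 1.0)
#     )
#     train(model, dataloader, optimizer, scheduler, tokenizer, config)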