"""
|
|
main.py
|
|
|
|
Orchestrates the entire Nora project:
|
|
- Parses arguments
|
|
- Builds or loads tokenizer
|
|
- Constructs dataset & dataloader
|
|
- Instantiates the model
|
|
- Sets up optimizer, scheduler
|
|
- Calls train()
|
|
"""

import logging
import os

import torch

from config import get_config
from data_loader import get_dataloader
from model import NoraTransformerLM
from tokenizer import CharTokenizer
from train import train
from utils import load_checkpoint, setup_logging

# cuDNN autotuning: benchmark mode picks the fastest kernels when input shapes
# stay fixed, which holds here since every batch is (batch_size, seq_length).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


def main():
    args = get_config()

    # 1) Logging setup
    os.makedirs(args.checkpoint_dir, exist_ok=True)  # the log file lives here, so create the dir first
    log_file = os.path.join(args.checkpoint_dir, "train.log")
    setup_logging(log_file)

    logging.info(f"[main] Using device: {args.device}")
    logging.info(f"[main] Config: {args}")

    # 2) Tokenizer: load the vocab if it exists, else build it from data_dir
    tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)

    # 3) DataLoader
    dataloader = get_dataloader(
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        batch_size=args.batch_size,
        shuffle=True,
    )
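    # NOTE: get_dataloader is assumed (see data_loader.py) to yield batches of
    # token ids shaped (batch_size, seq_length) for next-token prediction.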

    # 4) Model instantiation
    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=args.d_model,
        nhead=args.nhead,
        num_layers=args.num_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        max_seq_len=args.seq_length,
    )
    model.to(args.device)  # the config carries a device; move the model there before training

    # 5) Optimizer & scheduler (linear warmup + inverse-sqrt decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)

    def lr_lambda(current_step):
        # Linear warmup for the first warmup_steps, then 1/sqrt(step) decay.
        # Both branches equal 1.0 at current_step == warmup_steps, so the
        # schedule is continuous at the transition.
        if current_step < args.warmup_steps:
            return float(current_step) / float(max(1, args.warmup_steps))
        return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
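    # For example, with warmup_steps = 1000 the multiplier rises linearly from
    # 0.0 to 1.0 over steps 0..1000, then falls as sqrt(1000 / step):
    # ~0.50 at step 4000 and ~0.32 at step 10000.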

    # 6) Check for an existing checkpoint to resume from
    start_step = 0
    ckpts = os.listdir(args.checkpoint_dir) if os.path.isdir(args.checkpoint_dir) else []
    ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
    # Sort numerically by step: a plain lexicographic sort would rank
    # "nora_step_1000.pt" before "nora_step_900.pt" and resume from the wrong file.
    ckpts.sort(key=lambda f: int(f[len("nora_step_"):-len(".pt")]))
    if ckpts:
        latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
        logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
        start_step = load_checkpoint(latest_ckpt, model, optimizer)
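        # NOTE: load_checkpoint (from utils) is assumed to restore the model and
        # optimizer state in place and return the global step to resume from.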

    # 7) Begin training
    try:
        train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
    except Exception:
        logging.exception("[main] Exception during training")
        raise  # bare raise preserves the original traceback


if __name__ == "__main__":
    main()