Added another learning source for Nora. Also added the requirements.
pretrain.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""
pretrain.py

Orchestrates the entire Nora project:
- Parses arguments
- Builds or loads tokenizer
- Constructs dataset & dataloader
- Instantiates the model
- Sets up optimizer, scheduler
- Calls train()
"""

import os
import logging

import torch

from config import get_config
from tokenizer import CharTokenizer
from data_loader import get_dataloader
from model import NoraTransformerLM
from train import train
from utils import setup_logging, load_checkpoint, save_checkpoint

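# cuDNN autotuning: benchmark=True lets cuDNN pick the fastest kernels for the
# observed input shapes, which helps when batch size and sequence length stay
# fixed across steps (as they do here).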
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


def pretrain():
    args = get_config()

    # 1) Logging setup
    # Make sure the checkpoint directory exists before writing train.log into it
    # (in case setup_logging does not create it).
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    log_file = os.path.join(args.checkpoint_dir, "train.log")
    setup_logging(log_file)

    logging.info(f"[pretrain] Using device: {args.device}")
    logging.info(f"[pretrain] Config: {args}")

    # 2) Tokenizer: if vocab exists, load; else build from data_dir
    tokenizer = CharTokenizer(vocab_path=args.vocab_path, data_dir=args.data_dir)

    # 3) DataLoader
    dataloader = get_dataloader(
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        seq_length=args.seq_length,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # 4) Model instantiation
    model = NoraTransformerLM(
        vocab_size=tokenizer.vocab_size(),
        d_model=args.d_model,
        nhead=args.nhead,
        num_layers=args.num_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        max_seq_len=args.seq_length,
    )

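    # Note: the model is left on the CPU here; moving it to args.device is
    # assumed to happen inside train().
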
    # 5) Optimizer & scheduler (linear warmup + inverse-sqrt decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)

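    # The multiplier below ramps linearly from 0 to 1 over warmup_steps (so lr
    # peaks at args.lr), then decays proportionally to step ** -0.5, the
    # schedule used in the original Transformer paper.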
    def lr_lambda(current_step):
        # Linear warmup for the first warmup_steps, then decay with 1/sqrt(step)
        if current_step < args.warmup_steps:
            return float(current_step) / float(max(1, args.warmup_steps))
        return (args.warmup_steps ** 0.5) * float(current_step ** -0.5)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # 6) Check for an existing checkpoint to resume from
    start_step = 0
    ckpts = os.listdir(args.checkpoint_dir) if os.path.isdir(args.checkpoint_dir) else []
    ckpts = [f for f in ckpts if f.startswith("nora_step_") and f.endswith(".pt")]
    # Sort numerically by step so that e.g. nora_step_10000.pt outranks
    # nora_step_9000.pt (a plain lexicographic sort would get this wrong).
    ckpts.sort(key=lambda f: int(f[len("nora_step_"):-len(".pt")]))
    if ckpts:
        latest_ckpt = os.path.join(args.checkpoint_dir, ckpts[-1])
        logging.info(f"[main] Found existing checkpoint: {latest_ckpt}; resuming from it.")
        start_step = load_checkpoint(latest_ckpt, model, optimizer)

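    # Periodic saving via utils.save_checkpoint is assumed to happen inside
    # train(); it is imported above but not called directly in this file.
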
    # 7) Begin training
    try:
        train(model, dataloader, optimizer, scheduler, tokenizer, args, start_step=start_step)
    except Exception:
        logging.exception("[main] Exception during training")
        raise


if __name__ == "__main__":
    pretrain()
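
# Typical invocation (the flag names below are assumptions inferred from the
# attributes read off get_config(); check config.py for the actual CLI options):
#   python pretrain.py --data_dir data/ --checkpoint_dir checkpoints/ \
#       --batch_size 16 --seq_length 256 --lr 3e-4 --warmup_steps 4000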