Stage one of the project, done
configs/training.yaml (new file, 53 lines added)
@@ -0,0 +1,53 @@
# Training Configuration for Lyra

training:
  # Model selection
  model_config: "configs/model_125M.yaml" # Start with 125M

  # Data
  train_data_path: "data/processed/train.bin"
  val_data_path: "data/processed/val.bin"

  # Training hyperparameters
  batch_size: 8 # Adjust based on VRAM
  gradient_accumulation_steps: 4
  effective_batch_size: 32 # batch_size * grad_accum_steps

  max_steps: 100000
  warmup_steps: 2000
  eval_interval: 1000
  save_interval: 5000

  # Optimization
  learning_rate: 6.0e-4
  weight_decay: 0.1
  beta1: 0.9
  beta2: 0.95
  grad_clip: 1.0

  # Learning rate schedule
  lr_scheduler: "cosine"
  min_lr: 6.0e-5 # 10% of max lr

  # Mixed precision
  use_amp: true
  amp_dtype: "bfloat16" # bfloat16 or float16

  # Optimization techniques
  gradient_checkpointing: true
  compile_model: false # PyTorch 2.0 compilation (can cause issues)

  # Logging
  log_interval: 10
  wandb_project: "lyra-training"
  wandb_run_name: null # Auto-generated if null

  # Checkpointing
  checkpoint_dir: "models/checkpoints"
  save_optimizer_state: true
  keep_last_n_checkpoints: 3

  # Hardware
  device: "cuda"
  num_workers: 4
  pin_memory: true
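For context on how the hyperparameters above fit together, here is a minimal sketch of consuming this file. It assumes PyYAML and the key names shown in the diff; the get_lr helper is hypothetical and not part of this commit. It checks that effective_batch_size matches batch_size * gradient_accumulation_steps and derives the cosine-with-warmup schedule implied by learning_rate, min_lr, warmup_steps, and max_steps (min_lr being 10% of the peak, as the comment notes).

# Sketch only: load configs/training.yaml and reproduce the implied LR schedule.
import math
import yaml

with open("configs/training.yaml") as f:
    cfg = yaml.safe_load(f)["training"]

# effective_batch_size is documented as batch_size * grad_accum_steps
assert cfg["effective_batch_size"] == cfg["batch_size"] * cfg["gradient_accumulation_steps"]

def get_lr(step: int) -> float:
    """Linear warmup to learning_rate, then cosine decay down to min_lr."""
    if step < cfg["warmup_steps"]:
        return cfg["learning_rate"] * (step + 1) / cfg["warmup_steps"]
    progress = (step - cfg["warmup_steps"]) / max(1, cfg["max_steps"] - cfg["warmup_steps"])
    progress = min(progress, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return cfg["min_lr"] + (cfg["learning_rate"] - cfg["min_lr"]) * cosine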
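Similarly, a hedged sketch of how the optimization and mixed-precision keys might map onto PyTorch. The Linear model and random batch are stand-ins for the real model and data loader, and the overall step structure is an assumption, not the training script from this repo.

# Sketch only: one accumulated optimizer step under the AMP settings above.
import torch
import yaml

cfg = yaml.safe_load(open("configs/training.yaml"))["training"]

model = torch.nn.Linear(16, 16).to(cfg["device"])  # placeholder for the real model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=cfg["learning_rate"],
    betas=(cfg["beta1"], cfg["beta2"]),
    weight_decay=cfg["weight_decay"],
)

amp_dtype = torch.bfloat16 if cfg["amp_dtype"] == "bfloat16" else torch.float16
# Loss scaling is only needed for float16; bfloat16 trains without it.
scaler = torch.cuda.amp.GradScaler(enabled=cfg["use_amp"] and amp_dtype is torch.float16)

# Accumulate gradients over gradient_accumulation_steps micro-batches,
# then clip to grad_clip and take a single optimizer step.
optimizer.zero_grad(set_to_none=True)
for _ in range(cfg["gradient_accumulation_steps"]):
    x = torch.randn(cfg["batch_size"], 16, device=cfg["device"])  # placeholder batch
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg["use_amp"]):
        loss = model(x).pow(2).mean() / cfg["gradient_accumulation_steps"]
    scaler.scale(loss).backward()

scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
scaler.step(optimizer)
scaler.update()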
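Finally, a small sketch of the checkpoint rotation implied by checkpoint_dir, save_optimizer_state, and keep_last_n_checkpoints. The ckpt_{step}.pt naming and the save_checkpoint helper are invented for illustration; the actual training script may name and rotate checkpoints differently.

# Sketch only: save a checkpoint and keep just the newest keep_last_n_checkpoints.
from pathlib import Path
import torch
import yaml

cfg = yaml.safe_load(open("configs/training.yaml"))["training"]
ckpt_dir = Path(cfg["checkpoint_dir"])
ckpt_dir.mkdir(parents=True, exist_ok=True)

def save_checkpoint(step, model, optimizer):
    state = {"step": step, "model": model.state_dict()}
    if cfg["save_optimizer_state"]:
        state["optimizer"] = optimizer.state_dict()
    torch.save(state, ckpt_dir / f"ckpt_{step:07d}.pt")
    # Zero-padded step numbers keep lexicographic order equal to chronological
    # order, so dropping everything before the last N removes the oldest files.
    for old in sorted(ckpt_dir.glob("ckpt_*.pt"))[: -cfg["keep_last_n_checkpoints"]]:
        old.unlink()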