""" Model configuration for NOVA transformer """ from dataclasses import dataclass from typing import Optional @dataclass class ModelConfig: """Configuration for NOVA transformer model""" # Model architecture vocab_size: int = 32000 hidden_size: int = 768 num_hidden_layers: int = 12 num_attention_heads: int = 12 intermediate_size: int = 3072 max_position_embeddings: int = 2048 # Activation and normalization hidden_act: str = "swiglu" # or "gelu" norm_type: str = "rmsnorm" # or "layernorm" rms_norm_eps: float = 1e-6 # Positional encoding rope_theta: float = 10000.0 use_rope: bool = True use_alibi: bool = False # Alternative to RoPE # Attention attention_dropout: float = 0.0 hidden_dropout: float = 0.1 num_key_value_heads: Optional[int] = None # For grouped-query attention (GQA) use_flash_attention: bool = False # Auto-detected at runtime # Training initializer_range: float = 0.02 use_cache: bool = True # KV-cache for inference # Efficiency gradient_checkpointing: bool = False tie_word_embeddings: bool = False def __post_init__(self): """Validate and set derived values""" if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads assert self.hidden_size % self.num_attention_heads == 0, \ f"hidden_size ({self.hidden_size}) must be divisible by num_attention_heads ({self.num_attention_heads})" assert self.num_attention_heads % self.num_key_value_heads == 0, \ f"num_attention_heads ({self.num_attention_heads}) must be divisible by num_key_value_heads ({self.num_key_value_heads})" # Predefined model sizes MODEL_125M = ModelConfig( vocab_size=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, max_position_embeddings=2048, ) MODEL_350M = ModelConfig( vocab_size=32000, hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, intermediate_size=4096, max_position_embeddings=2048, ) MODEL_1_3B = ModelConfig( vocab_size=32000, hidden_size=2048, num_hidden_layers=24, num_attention_heads=32, intermediate_size=8192, max_position_embeddings=2048, num_key_value_heads=8, # GQA for efficiency ) MODEL_3B = ModelConfig( vocab_size=32000, hidden_size=2560, num_hidden_layers=32, num_attention_heads=32, intermediate_size=10240, max_position_embeddings=4096, num_key_value_heads=8, # GQA for efficiency )