""" Normalization layers for NOVA """ import torch import torch.nn as nn class RMSNorm(nn.Module): """ Root Mean Square Layer Normalization More efficient than LayerNorm, used in LLaMA and other modern LLMs """ def __init__(self, hidden_size: int, eps: float = 1e-6): """ Args: hidden_size: Size of the hidden dimension eps: Small constant for numerical stability """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Apply RMS normalization Args: hidden_states: Input tensor [..., hidden_size] Returns: Normalized tensor """ input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) # Compute RMS variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) return self.weight * hidden_states.to(input_dtype) class LayerNorm(nn.LayerNorm): """ Standard LayerNorm with optional bias Wrapper around PyTorch's LayerNorm for consistency """ def __init__(self, hidden_size: int, eps: float = 1e-6, bias: bool = True): super().__init__(hidden_size, eps=eps, elementwise_affine=True) if not bias: self.bias = None def get_norm_layer(norm_type: str, hidden_size: int, eps: float = 1e-6) -> nn.Module: """ Factory function to get normalization layer Args: norm_type: Type of normalization ('rmsnorm' or 'layernorm') hidden_size: Size of hidden dimension eps: Epsilon for numerical stability Returns: Normalization layer """ if norm_type.lower() == "rmsnorm": return RMSNorm(hidden_size, eps) elif norm_type.lower() == "layernorm": return LayerNorm(hidden_size, eps) else: raise ValueError(f"Unknown norm_type: {norm_type}. Use 'rmsnorm' or 'layernorm'")