"""Activation functions for NOVA."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class _GatedLinearUnit(nn.Module):
    """
    Shared implementation of a gated feed-forward block.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))`` where ``act`` is
    supplied by the subclass through :meth:`_activate`. Not intended to be
    instantiated directly.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        """
        Args:
            hidden_size: Input (and output) dimension
            intermediate_size: Width of the gate / up projections
            bias: Whether to use bias in linear layers
        """
        super().__init__()

        # The three projections are created in a fixed order (gate, up, down)
        # so that seeded initialisation and state_dict key order stay stable.
        # Gate projection
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        # Up projection
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        # Down projection
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def _activate(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gate non-linearity; must be overridden by subclasses."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the gated feed-forward transformation

        Args:
            x: Input tensor [..., hidden_size]

        Returns:
            Output tensor [..., hidden_size]
        """
        # Activated gate branch
        gate = self._activate(self.gate_proj(x))
        # Linear (un-activated) branch
        up = self.up_proj(x)
        # Element-wise product of both branches, projected back down
        return self.down_proj(gate * up)


class SwiGLU(_GatedLinearUnit):
    """
    SwiGLU activation function from Shazeer (2020)
    Used in PaLM and other modern LLMs

    SwiGLU(x, W, V, b, c) = Swish(xW + b) ⊗ (xV + c)
    where Swish(x) = x * sigmoid(x)

    The result is additionally passed through ``down_proj`` so the output
    has the same trailing dimension as the input.

    Args:
        hidden_size: Input dimension
        intermediate_size: Hidden dimension (usually 4 * hidden_size)
        bias: Whether to use bias in linear layers
    """

    def _activate(self, x: torch.Tensor) -> torch.Tensor:
        """Swish activation: x * sigmoid(x)"""
        return F.silu(x)


class GeGLU(_GatedLinearUnit):
    """
    GeGLU activation function - variant of SwiGLU using GELU

    GeGLU(x, W, V) = GELU(xW) ⊗ (xV)

    The result is additionally passed through ``down_proj`` so the output
    has the same trailing dimension as the input.

    Args:
        hidden_size: Input dimension
        intermediate_size: Hidden dimension
        bias: Whether to use bias in linear layers
    """

    def _activate(self, x: torch.Tensor) -> torch.Tensor:
        """GELU activation using the tanh approximation"""
        return F.gelu(x, approximate="tanh")


# Gated variants selectable by name in MLP (keys are lower-case).
_GATED_ACTIVATIONS = {
    "swiglu": SwiGLU,
    "geglu": GeGLU,
}


class MLP(nn.Module):
    """
    Standard MLP with configurable activation
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "swiglu",
        bias: bool = False
    ):
        """
        Args:
            hidden_size: Input/output dimension
            intermediate_size: Hidden dimension
            hidden_act: Activation function ('swiglu', 'geglu', or 'gelu'),
                matched case-insensitively
            bias: Whether to use bias

        Raises:
            ValueError: If ``hidden_act`` is not one of the supported names
        """
        super().__init__()

        # Normalise once instead of lower-casing in every comparison.
        act_name = hidden_act.lower()

        if act_name in _GATED_ACTIVATIONS:
            self.mlp = _GATED_ACTIVATIONS[act_name](hidden_size, intermediate_size, bias)
        elif act_name == "gelu":
            # Standard (non-gated) GELU MLP
            self.mlp = nn.Sequential(
                nn.Linear(hidden_size, intermediate_size, bias=bias),
                nn.GELU(approximate="tanh"),
                nn.Linear(intermediate_size, hidden_size, bias=bias)
            )
        else:
            # Report the name exactly as the caller passed it.
            raise ValueError(f"Unknown activation: {hidden_act}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through MLP"""
        return self.mlp(x)