"""Transformer block layers."""
import torch
import torch.nn as nn
from typing import Optional, Tuple

from .attention import MultiHeadAttention
from .activations import MLP
from .normalization import get_norm_layer


class TransformerBlock(nn.Module):
    """
    One transformer decoder block.

    The block chains two sub-layers, each wrapped the same way:
    normalize the input, run the sub-layer, apply dropout to its output,
    then add the result back onto the un-normalized input (pre-norm
    residual). The first sub-layer is multi-head self-attention, the
    second is the feed-forward MLP. A single dropout module is shared by
    both sub-layers.
    """

    def __init__(self, config, layer_idx: int):
        """
        Build the attention, MLP, normalization and dropout sub-modules.

        Args:
            config: ModelConfig instance. The fields read here are
                ``norm_type``, ``hidden_size``, ``rms_norm_eps``,
                ``intermediate_size``, ``hidden_act`` and
                ``hidden_dropout``; the whole object is also handed to
                ``MultiHeadAttention``.
            layer_idx: Index of this block, stored for identification.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        def build_norm():
            # Both norms use the same settings; config is read at call time.
            return get_norm_layer(
                config.norm_type,
                config.hidden_size,
                config.rms_norm_eps,
            )

        # Sub-modules are created in a fixed order (attention, its norm,
        # MLP, its norm, dropout) so parameter registration is stable.
        self.self_attn = MultiHeadAttention(config)
        self.attn_norm = build_norm()

        self.mlp = MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.mlp_norm = build_norm()

        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Run the attention sub-layer followed by the MLP sub-layer.

        Args:
            hidden_states: Input activations, expected to be
                [batch, seq_len, hidden_size].
            attention_mask: Optional mask, forwarded to the attention module.
            position_embeddings: Optional (cos, sin) pair for RoPE,
                forwarded to the attention module.
            past_key_value: Optional cached key/value, forwarded to the
                attention module.
            use_cache: Forwarded to the attention module to request a
                key/value cache.

        Returns:
            A pair ``(hidden_states, cache)``. ``cache`` is the second
            value returned by the attention module, passed through
            unchanged (intended to be None when ``use_cache`` is False —
            that is decided by the attention module, not here).
        """
        # --- Attention sub-layer: norm -> attention -> dropout -> add ---
        shortcut = hidden_states
        normed = self.attn_norm(hidden_states)
        attn_output, present_key_value = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = shortcut + self.dropout(attn_output)

        # --- Feed-forward sub-layer: norm -> MLP -> dropout -> add ---
        ffn_output = self.mlp(self.mlp_norm(hidden_states))
        block_output = hidden_states + self.dropout(ffn_output)

        return block_output, present_key_value