""" Multi-head attention with KV-cache and optional Flash Attention """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple import math try: from flash_attn import flash_attn_func FLASH_ATTENTION_AVAILABLE = True except ImportError: FLASH_ATTENTION_AVAILABLE = False class MultiHeadAttention(nn.Module): """ Multi-head attention with support for: - Grouped-query attention (GQA) - KV-cache for fast inference - Flash Attention (when available) - RoPE/ALiBi positional encoding """ def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads assert self.hidden_size % self.num_heads == 0, \ f"hidden_size must be divisible by num_heads" # Projections self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.dropout = nn.Dropout(config.attention_dropout) # Flash attention flag self.use_flash = config.use_flash_attention and FLASH_ATTENTION_AVAILABLE def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ Repeat key/value tensors for grouped-query attention This is equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1) but is more efficient """ if n_rep == 1: return hidden_states batch, num_kv_heads, seq_len, head_dim = hidden_states.shape hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_kv_heads, n_rep, seq_len, head_dim ) return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: """ Args: hidden_states: [batch, seq_len, hidden_size] attention_mask: [batch, 1, seq_len, seq_len] or [batch, 1, 1, seq_len] position_embeddings: Optional (cos, sin) for RoPE past_key_value: Optional cached (key, value) for inference use_cache: Whether to return key/value for caching Returns: (output, past_key_value if use_cache else None) """ batch_size, seq_len, _ = hidden_states.shape # Project to Q, K, V query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) # Reshape for multi-head attention query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) key = key.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value = value.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # Apply rotary embeddings if provided if position_embeddings is not None: cos, sin = position_embeddings query, key = self._apply_rotary_pos_emb(query, key, cos, sin) # Use cached key/value if available if past_key_value is not None: key = torch.cat([past_key_value[0], key], dim=2) value = torch.cat([past_key_value[1], value], dim=2) # Store for next iteration if caching if use_cache: 
            past_key_value = (key, value)
        else:
            past_key_value = None

        # Repeat K/V heads for grouped-query attention
        key = self._repeat_kv(key, self.num_key_value_groups)
        value = self._repeat_kv(value, self.num_key_value_groups)

        # Compute attention
        if self.use_flash and self.training:
            # Flash Attention path (training only). flash_attn_func applies its
            # own causal mask and the additive attention_mask is not passed in,
            # so padding masks are not honored on this path.
            # Flash Attention expects [batch, seq_len, num_heads, head_dim]
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            attn_output = flash_attn_func(
                query, key, value,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
                causal=True
            )
            # Back to [batch, num_heads, seq_len, head_dim] for the shared output path
            attn_output = attn_output.transpose(1, 2)
        else:
            # Standard scaled dot-product attention
            attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # Apply additive attention mask (causal and/or padding)
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask

            # Softmax in float32 for numerical stability, then cast back
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
            attn_weights = self.dropout(attn_weights)

            attn_output = torch.matmul(attn_weights, value)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value

    def _apply_rotary_pos_emb(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings to query and key."""

        def rotate_half(x):
            # Rotate-half trick: (x1, x2) -> (-x2, x1) on the last dimension
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)

        query_rot = (query * cos) + (rotate_half(query) * sin)
        key_rot = (key * cos) + (rotate_half(key) * sin)
        return query_rot, key_rot


def create_causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create a causal attention mask for autoregressive generation.

    Args:
        seq_len: Sequence length
        device: Device to create the tensor on
        dtype: Floating-point data type of the mask

    Returns:
        Additive causal mask [1, 1, seq_len, seq_len] (0 on/below the diagonal, -inf above)
    """
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=dtype), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask.unsqueeze(0).unsqueeze(0)


def create_attention_mask_from_padding(
    input_ids: torch.Tensor,
    pad_token_id: int
) -> torch.Tensor:
    """
    Create an additive attention mask from padding tokens.

    Args:
        input_ids: [batch, seq_len]
        pad_token_id: ID of the padding token

    Returns:
        Attention mask [batch, 1, 1, seq_len] in float32 (cast to the model's
        dtype before adding it to the attention logits if necessary)
    """
    # Padding mask [batch, seq_len]: 1.0 where the token should be attended to
    padding_mask = (input_ids != pad_token_id).float()

    # Expand to attention-mask format [batch, 1, 1, seq_len]
    attention_mask = padding_mask.unsqueeze(1).unsqueeze(2)

    # Convert to an additive mask (0 for attend, large negative for masked)
    attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min

    return attention_mask
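

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API): the real
# config class is not defined in this file, so a SimpleNamespace stands in for
# it below, using the field names that __init__ reads. The rope_cos_sin helper
# is likewise an assumption about how the (cos, sin) pair might be produced; it
# uses the standard RoPE inverse-frequency formula and a broadcastable
# [1, 1, seq_len, head_dim] layout, which is what _apply_rotary_pos_emb expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=256,
        num_attention_heads=8,
        num_key_value_heads=2,      # GQA: 4 query heads share each KV head
        attention_dropout=0.0,
        use_flash_attention=False,  # force the standard attention path for the demo
    )
    attn = MultiHeadAttention(config).eval()

    def rope_cos_sin(seq_len: int, head_dim: int, base: float = 10000.0):
        """Return (cos, sin) of shape [1, 1, seq_len, head_dim] for RoPE."""
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(seq_len).float()
        freqs = torch.outer(t, inv_freq)          # [seq_len, head_dim // 2]
        emb = torch.cat([freqs, freqs], dim=-1)   # [seq_len, head_dim]
        return emb.cos()[None, None], emb.sin()[None, None]

    batch, prompt_len = 2, 16
    hidden = torch.randn(batch, prompt_len, config.hidden_size)

    # Prefill: run the whole prompt once and keep the KV-cache.
    cos, sin = rope_cos_sin(prompt_len, attn.head_dim)
    causal = create_causal_mask(prompt_len, hidden.device, hidden.dtype)
    with torch.no_grad():
        out, cache = attn(hidden, attention_mask=causal,
                          position_embeddings=(cos, sin), use_cache=True)
    print(out.shape, cache[0].shape)  # [2, 16, 256] and [2, 2, 16, 32]

    # Decode: feed a single new token. Its query may attend to every cached
    # position, so no causal mask is needed for a one-row query.
    new_hidden = torch.randn(batch, 1, config.hidden_size)
    cos_all, sin_all = rope_cos_sin(prompt_len + 1, attn.head_dim)
    step_pos = (cos_all[:, :, -1:], sin_all[:, :, -1:])  # cos/sin for position 16
    with torch.no_grad():
        out, cache = attn(new_hidden, position_embeddings=step_pos,
                          past_key_value=cache, use_cache=True)
    print(out.shape, cache[0].shape)  # [2, 1, 256] and [2, 2, 17, 32]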