""" Rotary Position Embedding (RoPE) implementation """ import torch import torch.nn as nn from typing import Tuple class RotaryPositionalEmbedding(nn.Module): """ Rotary Position Embedding (RoPE) from Su et al. (2021) https://arxiv.org/abs/2104.09864 """ def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0): """ Args: dim: Dimension of the embeddings (should be head_dim) max_seq_len: Maximum sequence length theta: Base for the geometric progression (default 10000.0) """ super().__init__() self.dim = dim self.max_seq_len = max_seq_len self.theta = theta # Precompute frequencies inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Precompute cos/sin cache self._update_cos_sin_cache(max_seq_len) def _update_cos_sin_cache(self, seq_len: int): """Precompute cos and sin for positions up to seq_len""" position = torch.arange(seq_len).unsqueeze(1) freqs = position * self.inv_freq.unsqueeze(0) # Create rotation matrix [seq_len, dim/2] emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer("cos_cached", emb.cos(), persistent=False) self.register_buffer("sin_cached", emb.sin(), persistent=False) self.cached_seq_len = seq_len def rotate_half(self, x: torch.Tensor) -> torch.Tensor: """Rotates half the hidden dims of the input""" x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def forward( self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary position embeddings to query and key tensors Args: q: Query tensor [batch, num_heads, seq_len, head_dim] k: Key tensor [batch, num_heads, seq_len, head_dim] position_ids: Optional position IDs [batch, seq_len] Returns: Tuple of rotated query and key tensors """ seq_len = q.shape[2] # Update cache if needed if seq_len > self.cached_seq_len: self._update_cos_sin_cache(seq_len) # Get cos/sin for current positions if position_ids is not None: # For generation with KV-cache cos = self.cos_cached[position_ids].unsqueeze(1) sin = self.sin_cached[position_ids].unsqueeze(1) else: # For training or initial forward pass cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0) sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0) # Apply rotation q_embed = (q * cos) + (self.rotate_half(q) * sin) k_embed = (k * cos) + (self.rotate_half(k) * sin) return q_embed, k_embed class ALiBiPositionalBias(nn.Module): """ Attention with Linear Biases (ALiBi) from Press et al. 
(2021) https://arxiv.org/abs/2108.12409 Alternative to RoPE """ def __init__(self, num_heads: int, max_seq_len: int = 2048): """ Args: num_heads: Number of attention heads max_seq_len: Maximum sequence length """ super().__init__() self.num_heads = num_heads self.max_seq_len = max_seq_len # Compute slopes for each head slopes = self._get_slopes(num_heads) self.register_buffer("slopes", slopes, persistent=False) # Precompute bias matrix alibi = self._get_alibi_bias(max_seq_len, slopes) self.register_buffer("alibi_bias", alibi, persistent=False) def _get_slopes(self, num_heads: int) -> torch.Tensor: """Compute slopes for ALiBi""" def get_slopes_power_of_2(n): start = 2 ** (-(2 ** -(torch.log2(torch.tensor(n)) - 3))) ratio = start return torch.pow(2, torch.arange(n)) * ratio # Handle non-power-of-2 number of heads if (num_heads & (num_heads - 1)) == 0: return get_slopes_power_of_2(num_heads) else: closest_power_of_2 = 2 ** torch.floor(torch.log2(torch.tensor(num_heads))) slopes_a = get_slopes_power_of_2(int(closest_power_of_2)) slopes_b = self._get_slopes(int(2 * closest_power_of_2))[0::2][:num_heads - int(closest_power_of_2)] return torch.cat([slopes_a, slopes_b]) def _get_alibi_bias(self, seq_len: int, slopes: torch.Tensor) -> torch.Tensor: """Precompute ALiBi bias matrix""" # Create relative position matrix pos = torch.arange(seq_len).unsqueeze(0) rel_pos = pos - pos.T # [seq_len, seq_len] # Apply slopes [num_heads, seq_len, seq_len] alibi = rel_pos.unsqueeze(0) * slopes.unsqueeze(-1).unsqueeze(-1) return alibi def forward(self, attention_scores: torch.Tensor, seq_len: int) -> torch.Tensor: """ Add ALiBi bias to attention scores Args: attention_scores: [batch, num_heads, seq_len, seq_len] seq_len: Current sequence length Returns: Biased attention scores """ return attention_scores + self.alibi_bias[:, :seq_len, :seq_len]
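# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module API): shows how
# the two position-encoding schemes plug into an attention computation. The
# batch/head/sequence sizes below are arbitrary assumptions for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, num_heads, seq_len, head_dim = 2, 8, 16, 64

    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_heads, seq_len, head_dim)

    # RoPE rotates q/k directly (no separate position embeddings); attention
    # scores are then computed from the rotated tensors as usual.
    rope = RotaryPositionalEmbedding(dim=head_dim, max_seq_len=128)
    q_rot, k_rot = rope(q, k)
    scores_rope = q_rot @ k_rot.transpose(-2, -1) / head_dim ** 0.5
    print("RoPE scores:", scores_rope.shape)  # [2, 8, 16, 16]

    # ALiBi leaves q/k untouched and instead adds a per-head linear bias to
    # the raw attention scores before the softmax.
    alibi = ALiBiPositionalBias(num_heads=num_heads, max_seq_len=128)
    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
    scores_alibi = alibi(scores, seq_len=seq_len)
    print("ALiBi scores:", scores_alibi.shape)  # [2, 8, 16, 16]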