import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Dict, Any


class SelfEvolvingAttention(nn.Module):
    """
    Advanced attention mechanism that can evolve its attention patterns
    based on conversation context and emotional state.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.1,
        bias: bool = True,
        evolution_rate: float = 0.001
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.evolution_rate = evolution_rate

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        # Standard attention components
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Evolution components
        self.attention_evolution = nn.Parameter(torch.zeros(num_heads, 64, 64))
        self.emotional_attention_bias = nn.Parameter(torch.zeros(num_heads, 1, 1))
        self.context_adaptation = nn.Linear(embed_dim, num_heads)

        # Memory for attention patterns
        self.register_buffer('attention_memory', torch.zeros(num_heads, 100, 100))
        self.register_buffer('memory_pointer', torch.zeros(1, dtype=torch.long))

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

        self._init_parameters()

    def _init_parameters(self):
        """Initialize parameters with careful scaling for evolution."""
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

        if self.q_proj.bias is not None:
            nn.init.constant_(self.q_proj.bias, 0.)
            nn.init.constant_(self.k_proj.bias, 0.)
            nn.init.constant_(self.v_proj.bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)

        # Initialize evolution parameters small
        nn.init.normal_(self.attention_evolution, std=0.01)
        nn.init.zeros_(self.emotional_attention_bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        emotional_state: Optional[torch.Tensor] = None,
        evolve: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        Forward pass with attention evolution.
        Args:
            query: Query tensor [batch, seq_len, embed_dim]
            key: Key tensor [batch, seq_len, embed_dim]
            value: Value tensor [batch, seq_len, embed_dim]
            attn_mask: Attention mask; positions where the mask is 0 are ignored
            key_padding_mask: Key padding mask [batch, seq_len]; True marks padded positions
            emotional_state: Current emotional state [batch, emotion_dim]
            evolve: Whether to apply evolution this step

        Returns:
            output: Attention output [batch, seq_len, embed_dim]
            attention_weights: Attention weights [batch, num_heads, seq_len, seq_len]
            evolution_info: Information about evolution applied this step
        """
        batch_size, seq_len, _ = query.shape

        # Project to Q, K, V
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Reshape for multi-head attention (query, key and value are assumed
        # to share the same sequence length, i.e. self-attention)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute base attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Apply evolution to attention patterns
        evolution_info = {}
        if evolve and seq_len <= 64:  # Only evolve for reasonable sequence lengths
            # Get context-aware evolution weights
            context_weights = self.context_adaptation(query.mean(dim=1))  # [batch, num_heads]
            context_weights = torch.sigmoid(context_weights).unsqueeze(-1).unsqueeze(-1)

            # Apply learned evolution patterns
            evolution_matrix = self.attention_evolution[:, :seq_len, :seq_len]
            evolved_scores = scores + context_weights * evolution_matrix.unsqueeze(0)

            # Apply emotional bias if emotional state is provided
            if emotional_state is not None:
                emotional_influence = torch.sigmoid(emotional_state.mean(dim=-1, keepdim=True))
                # Broadcasts [num_heads, 1, 1] with [batch, 1, 1, 1] -> [batch, num_heads, 1, 1],
                # which already matches the score layout, so no extra unsqueeze is needed
                emotional_bias = self.emotional_attention_bias * emotional_influence.unsqueeze(-1).unsqueeze(-1)
                evolved_scores = evolved_scores + emotional_bias

            scores = evolved_scores
            evolution_info['context_weights'] = context_weights.mean().item()
            evolution_info['evolution_magnitude'] = evolution_matrix.abs().mean().item()

        # Apply masks
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))

        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf')
            )

        # Compute attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Store attention pattern in memory for evolution
        if evolve and seq_len <= 100:
            self._store_attention_pattern(attention_weights.detach())

        # Apply attention to values
        output = torch.matmul(attention_weights, v)

        # Reshape back
        output = output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.embed_dim
        )

        # Final projection
        output = self.out_proj(output)

        return output, attention_weights, evolution_info

    def _store_attention_pattern(self, attention_weights: torch.Tensor):
        """Store attention patterns for learning evolution."""
        # Note: self.memory_pointer is reserved for future use and is not updated here.
        _, _, seq_len, _ = attention_weights.shape
        memory_size = self.attention_memory.shape[1]

        if seq_len <= memory_size:
            # Average across the batch and blend into the running memory
            avg_attention = attention_weights.mean(dim=0)  # [num_heads, seq_len, seq_len]
            self.attention_memory[:, :seq_len, :seq_len] = (
                0.95 * self.attention_memory[:, :seq_len, :seq_len] +
                0.05 * avg_attention
            )

    def evolve_attention_patterns(self, feedback_signal: float):
        """
        Evolve attention patterns based on feedback.
        Args:
            feedback_signal: Positive for good responses, negative for bad
        """
        with torch.no_grad():
            # Use stored attention memory to update the evolution matrix
            memory_influence = self.attention_memory.mean(dim=0)  # Average across heads
            max_size = min(self.attention_evolution.shape[1], memory_influence.shape[0])

            # Update evolution matrix based on successful patterns
            update = feedback_signal * self.evolution_rate * memory_influence[:max_size, :max_size]
            self.attention_evolution.data[:, :max_size, :max_size] += update.unsqueeze(0)

            # Clamp to prevent explosion
            self.attention_evolution.data = torch.clamp(
                self.attention_evolution.data, -1.0, 1.0
            )

    def get_attention_diversity(self) -> float:
        """Calculate how diverse the attention patterns are (cognitive flexibility)."""
        with torch.no_grad():
            # Calculate entropy of stored attention patterns
            attention_probs = F.softmax(self.attention_memory, dim=-1)
            entropy = -torch.sum(attention_probs * torch.log(attention_probs + 1e-8), dim=-1)
            return entropy.mean().item()


class MultiHeadAttention(nn.Module):
    """
    Standard multi-head attention for comparison and fallback.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.1,
        bias: bool = True
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standard multi-head attention forward pass."""
        batch_size, seq_len, _ = query.shape

        # Project to Q, K, V
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Apply masks
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))

        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf')
            )

        # Compute attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        output = torch.matmul(attention_weights, v)

        # Reshape back
        output = output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.embed_dim
        )

        # Final projection
        output = self.out_proj(output)

        return output, attention_weights
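

# Minimal usage sketch (illustrative only, not part of the module API): the tensor
# shapes, hyperparameters, and the hard-coded feedback value below are assumptions
# chosen for demonstration; in practice the feedback signal would come from an
# external evaluation loop.
if __name__ == "__main__":
    batch_size, seq_len, embed_dim, num_heads = 2, 16, 128, 8

    attn = SelfEvolvingAttention(embed_dim=embed_dim, num_heads=num_heads)
    x = torch.randn(batch_size, seq_len, embed_dim)   # self-attention input
    emotional_state = torch.randn(batch_size, 4)      # assumed emotion_dim = 4

    # Self-attention with evolution enabled
    output, weights, info = attn(x, x, x, emotional_state=emotional_state, evolve=True)
    print(output.shape)   # torch.Size([2, 16, 128])
    print(weights.shape)  # torch.Size([2, 8, 16, 16])
    print(info)           # {'context_weights': ..., 'evolution_magnitude': ...}

    # Apply a hypothetical positive feedback signal, then inspect pattern diversity
    attn.evolve_attention_patterns(feedback_signal=1.0)
    print(attn.get_attention_diversity())

    # Plain multi-head attention as a reference/fallback
    baseline = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)
    baseline_out, baseline_weights = baseline(x, x, x)
    print(baseline_out.shape)  # torch.Size([2, 16, 128])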