NOVA/nova_core/attention.py
Dani a7f091aa45 Initial commit: NOVA - Neuro-Optimizing Versatile Agent
Complete transformer LLM built from scratch with:

Core Features:
- Full transformer architecture (RoPE, RMSNorm, SwiGLU, KV-cache)
- SentencePiece tokenizer (BPE/Unigram)
- Training pipeline (AMP, gradient checkpointing, DDP)
- Persona system with personality matrix (NO AI disclosure by default)
- Genetic evolution (NOVA-EVO) for hyperparameter optimization
- Legal-only data pipeline with license tracking
- Chat interface (CLI + REST API)
- Conversation memory (SQLite)

Model Sizes:
- 125M, 350M, 1.3B, 3B parameters
- Local-first, runs on CPU or GPU
- Python 3.10.6+, PyTorch 2.0+

Personas:
- girlfriend_gentle (high warmth, high empathy)
- girlfriend_playful (high humor, high playfulness)
- girlfriend_supportive (balanced, default)

Documentation:
- Complete README with quickstart
- Model card with ethical considerations
- Privacy documentation (local-first, zero telemetry)
- Data licenses and attribution
- Contributing guide

Infrastructure:
- GitHub Actions CI/CD
- Comprehensive test suite
- Quickstart script
- CLI tool

License: Apache 2.0

🤖 Generated with Claude Code
https://claude.com/claude-code

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-12 20:56:37 -04:00

210 lines · 7.5 KiB · Python

"""
Multi-head attention with KV-cache and optional Flash Attention
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math

try:
    from flash_attn import flash_attn_func
    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with support for:
    - Grouped-query attention (GQA)
    - KV-cache for fast inference
    - Flash Attention (when available)
    - RoPE/ALiBi positional encoding
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        assert self.hidden_size % self.num_heads == 0, \
            f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"

        # Projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.attention_dropout)

        # Flash attention flag
        self.use_flash = config.use_flash_attention and FLASH_ATTENTION_AVAILABLE
    def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        Repeat key/value tensors for grouped-query attention.

        This is equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)
        but is more efficient.
        """
        if n_rep == 1:
            return hidden_states

        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim
        )
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
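
    # Example (hypothetical sizes): with num_heads=8 and num_key_value_heads=2,
    # n_rep = 4, so a K/V tensor of shape [batch, 2, seq_len, head_dim] is
    # expanded to [batch, 8, seq_len, head_dim]; expand() only creates a view,
    # and memory is materialized by the final reshape.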
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: [batch, 1, seq_len, seq_len] or [batch, 1, 1, seq_len]
            position_embeddings: Optional (cos, sin) for RoPE
            past_key_value: Optional cached (key, value) for inference
            use_cache: Whether to return key/value for caching

        Returns:
            (output, past_key_value if use_cache else None)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Reshape for multi-head attention
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings if provided
        if position_embeddings is not None:
            cos, sin = position_embeddings
            query, key = self._apply_rotary_pos_emb(query, key, cos, sin)

        # Use cached key/value if available
        if past_key_value is not None:
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        # Store for next iteration if caching
        if use_cache:
            past_key_value = (key, value)
        else:
            past_key_value = None

        # Repeat K/V for grouped-query attention
        key = self._repeat_kv(key, self.num_key_value_groups)
        value = self._repeat_kv(value, self.num_key_value_groups)

        # Compute attention
        if self.use_flash and self.training:
            # Flash Attention (only during training, requires specific format)
            # Flash attention expects [batch, seq_len, num_heads, head_dim]
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            attn_output = flash_attn_func(
                query, key, value,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
                causal=True
            )
            attn_output = attn_output.transpose(1, 2)
        else:
            # Standard scaled dot-product attention
            attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # Apply attention mask
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask

            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
            attn_weights = self.dropout(attn_weights)

            attn_output = torch.matmul(attn_weights, value)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value
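
    # Decoding usage sketch (assumed calling pattern, not taken from NOVA's
    # generation loop): during incremental decoding the caller passes the cache
    # returned by the previous step together with only the newest token, so
    # seq_len == 1 while the cache grows along dim=2:
    #   out, cache = attn(x_step, position_embeddings=(cos, sin),
    #                     past_key_value=cache, use_cache=True)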
    def _apply_rotary_pos_emb(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings"""
        # Rotate half trick for efficiency
        def rotate_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)

        query_rot = (query * cos) + (rotate_half(query) * sin)
        key_rot = (key * cos) + (rotate_half(key) * sin)
        return query_rot, key_rot
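

# The cos/sin tables consumed by _apply_rotary_pos_emb are produced elsewhere in
# nova_core; the helper below is only an illustrative sketch (assumed shapes and
# base=10000.0) showing tables that broadcast against [batch, heads, seq, head_dim]
# and match the rotate-half convention used above.
def _example_rope_tables(seq_len: int, head_dim: int, base: float = 10000.0):
    # Inverse frequencies for each pair of channels: [head_dim / 2]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.outer(positions, inv_freq)       # [seq_len, head_dim / 2]
    emb = torch.cat([freqs, freqs], dim=-1)        # [seq_len, head_dim]
    # Shape [1, 1, seq_len, head_dim] so the tables broadcast over batch and heads
    return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
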

def create_causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create causal attention mask for autoregressive generation.

    Args:
        seq_len: Sequence length
        device: Device to create tensor on
        dtype: Data type

    Returns:
        Causal mask [1, 1, seq_len, seq_len]
    """
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=dtype), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask.unsqueeze(0).unsqueeze(0)
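
# For example, create_causal_mask(3, device, dtype) produces (last two dims):
#   [[0, -inf, -inf],
#    [0,    0, -inf],
#    [0,    0,    0]]
# which is added to the attention scores so each position attends only to itself
# and earlier positions.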

def create_attention_mask_from_padding(
    input_ids: torch.Tensor,
    pad_token_id: int
) -> torch.Tensor:
    """
    Create attention mask from padding tokens.

    Args:
        input_ids: [batch, seq_len]
        pad_token_id: ID of padding token

    Returns:
        Attention mask [batch, 1, 1, seq_len]
    """
    # Create padding mask [batch, seq_len]
    padding_mask = (input_ids != pad_token_id).float()

    # Expand to attention mask format
    attention_mask = padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]

    # Convert to additive mask (0 for attend, -inf for ignore)
    attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min

    return attention_mask
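

# Quick smoke-test sketch. Assumption: NOVA's real config object is defined
# elsewhere in the repo; only the attributes this module actually reads are
# stubbed in here, and the sizes are arbitrary example values.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=256,
        num_attention_heads=8,
        num_key_value_heads=4,   # GQA: every 2 query heads share one K/V head
        attention_dropout=0.0,
        use_flash_attention=False,
    )
    attn = MultiHeadAttention(config)
    attn.eval()

    x = torch.randn(2, 16, config.hidden_size)
    mask = create_causal_mask(16, x.device, x.dtype)
    out, cache = attn(x, attention_mask=mask, use_cache=True)
    print(out.shape)       # torch.Size([2, 16, 256])
    print(cache[0].shape)  # torch.Size([2, 4, 16, 32]) -- cached keys, pre-repeat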