Complete transformer LLM built from scratch with:

Core Features:
- Full transformer architecture (RoPE, RMSNorm, SwiGLU, KV-cache)
- SentencePiece tokenizer (BPE/Unigram)
- Training pipeline (AMP, gradient checkpointing, DDP)
- Persona system with personality matrix (NO AI disclosure by default)
- Genetic evolution (NOVA-EVO) for hyperparameter optimization
- Legal-only data pipeline with license tracking
- Chat interface (CLI + REST API)
- Conversation memory (SQLite)

Model Sizes:
- 125M, 350M, 1.3B, 3B parameters
- Local-first, runs on CPU or GPU
- Python 3.10.6+, PyTorch 2.0+

Personas:
- girlfriend_gentle (high warmth, high empathy)
- girlfriend_playful (high humor, high playfulness)
- girlfriend_supportive (balanced, default)

Documentation:
- Complete README with quickstart
- Model card with ethical considerations
- Privacy documentation (local-first, zero telemetry)
- Data licenses and attribution
- Contributing guide

Infrastructure:
- GitHub Actions CI/CD
- Comprehensive test suite
- Quickstart script
- CLI tool

License: Apache 2.0

🤖 Generated with Claude Code
https://claude.com/claude-code

Co-Authored-By: Claude <noreply@anthropic.com>
"""
|
|
Multi-head attention with KV-cache and optional Flash Attention
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from typing import Optional, Tuple
|
|
import math
|
|
|
|
try:
|
|
from flash_attn import flash_attn_func
|
|
FLASH_ATTENTION_AVAILABLE = True
|
|
except ImportError:
|
|
FLASH_ATTENTION_AVAILABLE = False
|
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with support for:
    - Grouped-query attention (GQA)
    - KV-cache for fast inference
    - Flash Attention (when available)
    - RoPE/ALiBi positional encoding
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        assert self.hidden_size % self.num_heads == 0, (
            f"hidden_size ({self.hidden_size}) must be divisible by "
            f"num_heads ({self.num_heads})"
        )

        # Projections (K/V use fewer heads than Q when GQA is enabled)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.attention_dropout)

        # Flash attention flag: only used if requested in config AND the package imported
        self.use_flash = config.use_flash_attention and FLASH_ATTENTION_AVAILABLE

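    # Illustrative numbers (hypothetical, not from the original config): with
    # hidden_size=768, num_attention_heads=12 and num_key_value_heads=4, head_dim
    # is 64, so q_proj maps 768 -> 12 * 64 = 768 while k_proj/v_proj map
    # 768 -> 4 * 64 = 256. GQA therefore shrinks the K/V projections and the
    # KV-cache by num_heads / num_key_value_heads = 3x versus standard MHA.
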
    def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        Repeat key/value tensors for grouped-query attention.

        This is equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)
        but is more efficient.
        """
        if n_rep == 1:
            return hidden_states

        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim
        )
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
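
    # Shape walk-through (hypothetical sizes, for illustration): a key tensor of
    # shape [batch=2, num_kv_heads=4, seq_len=16, head_dim=64] expanded with
    # n_rep=3 becomes [2, 4, 3, 16, 64] and reshapes to [2, 12, 16, 64], so each
    # group of 3 query heads attends to a copy of the same K/V head. expand()
    # only creates a view; the final reshape is what materialises memory.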

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: [batch, 1, seq_len, seq_len] or [batch, 1, 1, seq_len]
            position_embeddings: Optional (cos, sin) for RoPE
            past_key_value: Optional cached (key, value) for inference
            use_cache: Whether to return key/value for caching

        Returns:
            (output, past_key_value if use_cache else None)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Reshape for multi-head attention: [batch, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings if provided
        if position_embeddings is not None:
            cos, sin = position_embeddings
            query, key = self._apply_rotary_pos_emb(query, key, cos, sin)

        # Prepend cached key/value if available
        if past_key_value is not None:
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        # Store for the next iteration if caching
        if use_cache:
            past_key_value = (key, value)
        else:
            past_key_value = None

        # Repeat K/V for grouped-query attention
        key = self._repeat_kv(key, self.num_key_value_groups)
        value = self._repeat_kv(value, self.num_key_value_groups)

        # Compute attention
        if self.use_flash and self.training:
            # Flash Attention (only during training, requires specific format)
            # flash_attn_func expects [batch, seq_len, num_heads, head_dim]
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            attn_output = flash_attn_func(
                query, key, value,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
                causal=True
            )
            attn_output = attn_output.transpose(1, 2)
        else:
            # Standard scaled dot-product attention
            attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # Apply additive attention mask (0 for attend, large negative for ignore)
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask

            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
            attn_weights = self.dropout(attn_weights)

            attn_output = torch.matmul(attn_weights, value)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value

    def _apply_rotary_pos_emb(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings"""
        # Rotate-half trick for efficiency
        def rotate_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)

        query_rot = (query * cos) + (rotate_half(query) * sin)
        key_rot = (key * cos) + (rotate_half(key) * sin)

        return query_rot, key_rot
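
    # For reference (standard RoPE identity, not extra code from the original file):
    # pairing the first and second halves of each head dimension as (x1, x2), the
    # update (x * cos) + (rotate_half(x) * sin) computes
    #     [x1 * cos - x2 * sin,  x2 * cos + x1 * sin]
    # i.e. a 2-D rotation of each pair by the position-dependent angle encoded in
    # (cos, sin), which makes attention scores depend only on relative position.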


def create_causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create causal attention mask for autoregressive generation

    Args:
        seq_len: Sequence length
        device: Device to create tensor on
        dtype: Data type

    Returns:
        Causal mask [1, 1, seq_len, seq_len]
    """
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=dtype), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask.unsqueeze(0).unsqueeze(0)
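
# Illustration (not part of the original file): for seq_len = 3 the additive mask is
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
# broadcast as [1, 1, 3, 3], so adding it to the attention scores blocks each query
# position from attending to later (future) key positions.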


def create_attention_mask_from_padding(
    input_ids: torch.Tensor,
    pad_token_id: int
) -> torch.Tensor:
    """
    Create attention mask from padding tokens

    Args:
        input_ids: [batch, seq_len]
        pad_token_id: ID of padding token

    Returns:
        Attention mask [batch, 1, 1, seq_len]
    """
    # Create padding mask [batch, seq_len]: 1.0 for real tokens, 0.0 for padding
    padding_mask = (input_ids != pad_token_id).float()

    # Expand to attention mask format
    attention_mask = padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]

    # Convert to additive mask (0 for attend, large negative for ignore)
    attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min

    return attention_mask
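

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The real project presumably builds `config` from its own config class; the
# SimpleNamespace below is a stand-in carrying just the attributes this file reads.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=256,
        num_attention_heads=8,
        num_key_value_heads=4,      # GQA: every 2 query heads share one KV head
        attention_dropout=0.0,
        use_flash_attention=False,  # use the standard attention path for the demo
    )
    attn = MultiHeadAttention(config)

    # Prefill: process a 16-token prompt and prime the KV-cache
    x = torch.randn(2, 16, config.hidden_size)
    mask = create_causal_mask(seq_len=16, device=x.device, dtype=x.dtype)
    out, cache = attn(x, attention_mask=mask, use_cache=True)

    # Decode step: a single new token attends to the 16 cached positions plus
    # itself, so no causal mask is needed for the 1-token query
    step = torch.randn(2, 1, config.hidden_size)
    out_step, cache = attn(step, past_key_value=cache, use_cache=True)

    print(out.shape)        # torch.Size([2, 16, 256])
    print(out_step.shape)   # torch.Size([2, 1, 256])
    print(cache[0].shape)   # torch.Size([2, 4, 17, 32]) -- cached keys per KV head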