Initial commit: NOVA - Neuro-Optimizing Versatile Agent
Complete transformer LLM built from scratch.

Core Features:
- Full transformer architecture (RoPE, RMSNorm, SwiGLU, KV-cache)
- SentencePiece tokenizer (BPE/Unigram)
- Training pipeline (AMP, gradient checkpointing, DDP)
- Persona system with personality matrix (NO AI disclosure by default)
- Genetic evolution (NOVA-EVO) for hyperparameter optimization
- Legal-only data pipeline with license tracking
- Chat interface (CLI + REST API)
- Conversation memory (SQLite)

Model Sizes:
- 125M, 350M, 1.3B, 3B parameters
- Local-first, runs on CPU or GPU
- Python 3.10.6+, PyTorch 2.0+

Personas:
- girlfriend_gentle (high warmth, high empathy)
- girlfriend_playful (high humor, high playfulness)
- girlfriend_supportive (balanced, default)

Documentation:
- Complete README with quickstart
- Model card with ethical considerations
- Privacy documentation (local-first, zero telemetry)
- Data licenses and attribution
- Contributing guide

Infrastructure:
- GitHub Actions CI/CD
- Comprehensive test suite
- Quickstart script
- CLI tool

License: Apache 2.0

🤖 Generated with Claude Code (https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
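For orientation, a minimal usage sketch for the attention module added below. It is not part of the commit: `Cfg` is a hypothetical stand-in for the real NOVA config (not shown here) exposing only the fields MultiHeadAttention reads, the sizes are illustrative, and RoPE embeddings and the KV-cache are omitted for brevity.

import torch
from nova_core.attention import MultiHeadAttention, create_causal_mask

class Cfg:
    # Hypothetical stand-in config; only fields read by MultiHeadAttention.
    hidden_size = 768
    num_attention_heads = 12
    num_key_value_heads = 4
    attention_dropout = 0.0
    use_flash_attention = False

attn = MultiHeadAttention(Cfg())
x = torch.randn(1, 16, Cfg.hidden_size)           # [batch, seq_len, hidden_size]
mask = create_causal_mask(16, x.device, x.dtype)  # [1, 1, 16, 16]
out, _ = attn(x, attention_mask=mask)             # RoPE omitted; no KV cache
print(out.shape)                                  # torch.Size([1, 16, 768])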
nova_core/attention.py (209 lines, new file)

@@ -0,0 +1,209 @@
"""
Multi-head attention with KV-cache and optional Flash Attention
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math

try:
    from flash_attn import flash_attn_func
    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with support for:
    - Grouped-query attention (GQA)
    - KV-cache for fast inference
    - Flash Attention (when available)
    - RoPE/ALiBi positional encoding
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        assert self.hidden_size % self.num_heads == 0, \
            f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"

        # Projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.attention_dropout)

        # Flash attention flag
        self.use_flash = config.use_flash_attention and FLASH_ATTENTION_AVAILABLE
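
        # Illustrative sizing note (numbers not from this commit): with
        # num_attention_heads=12 and num_key_value_heads=4, the K/V projections
        # and the KV-cache hold 4 heads instead of 12, roughly a 3x memory
        # saving over standard multi-head attention.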

    def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        Repeat key/value tensors for grouped-query attention.
        This is equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)
        but is more efficient.
        """
        if n_rep == 1:
            return hidden_states

        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim
        )
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: [batch, 1, seq_len, seq_len] or [batch, 1, 1, seq_len]
            position_embeddings: Optional (cos, sin) for RoPE
            past_key_value: Optional cached (key, value) for inference
            use_cache: Whether to return key/value for caching

        Returns:
            (output, past_key_value if use_cache else None)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Reshape for multi-head attention
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings if provided
        if position_embeddings is not None:
            cos, sin = position_embeddings
            query, key = self._apply_rotary_pos_emb(query, key, cos, sin)

        # Use cached key/value if available
        if past_key_value is not None:
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        # Store for next iteration if caching
        if use_cache:
            past_key_value = (key, value)
        else:
            past_key_value = None

        # Repeat K/V for grouped-query attention
        key = self._repeat_kv(key, self.num_key_value_groups)
        value = self._repeat_kv(value, self.num_key_value_groups)

        # Compute attention
        if self.use_flash and self.training:
            # Flash Attention (only during training, requires specific format)
            # Flash attention expects [batch, seq_len, num_heads, head_dim]
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            attn_output = flash_attn_func(
                query, key, value,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
                causal=True
            )
            attn_output = attn_output.transpose(1, 2)
        else:
            # Standard scaled dot-product attention
            attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # Apply attention mask
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask

            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
            attn_weights = self.dropout(attn_weights)

            attn_output = torch.matmul(attn_weights, value)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value
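
    # Cached decoding sketch (illustrative, not part of this commit):
    #   out, kv = attn(prompt_hidden, attention_mask=causal_mask, use_cache=True)
    #   out, kv = attn(next_token_hidden, past_key_value=kv, use_cache=True)
    # Each step projects only the new positions; cached keys/values are
    # concatenated along dim=2, so the cache grows by one step per generated
    # token. In the full model, position_embeddings (cos, sin) would also be
    # passed so RoPE is applied at the correct positions.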

    def _apply_rotary_pos_emb(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings"""
        # Rotate half trick for efficiency
        def rotate_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)

        query_rot = (query * cos) + (rotate_half(query) * sin)
        key_rot = (key * cos) + (rotate_half(key) * sin)

        return query_rot, key_rot
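
    # The rotate-half form applies, per position m and frequency θ_i, the 2-D rotation
    #   x1' = x1*cos(m*θ_i) - x2*sin(m*θ_i)
    #   x2' = x2*cos(m*θ_i) + x1*sin(m*θ_i)
    # where the pair (x1, x2) comes from the two halves of the head dimension.
    # The exact cos/sin layout is defined by the RoPE module, which is not in this file.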


def create_causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create causal attention mask for autoregressive generation

    Args:
        seq_len: Sequence length
        device: Device to create tensor on
        dtype: Data type

    Returns:
        Causal mask [1, 1, seq_len, seq_len]
    """
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=dtype), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask.unsqueeze(0).unsqueeze(0)
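
# Example (illustrative): create_causal_mask(3, device, torch.float32)[0, 0] is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so position i can only attend to positions <= i.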


def create_attention_mask_from_padding(
    input_ids: torch.Tensor,
    pad_token_id: int
) -> torch.Tensor:
    """
    Create attention mask from padding tokens

    Args:
        input_ids: [batch, seq_len]
        pad_token_id: ID of padding token

    Returns:
        Attention mask [batch, 1, 1, seq_len]
    """
    # Create padding mask [batch, seq_len]
    padding_mask = (input_ids != pad_token_id).float()

    # Expand to attention mask format
    attention_mask = padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]

    # Convert to additive mask (0 for attend, -inf for ignore)
    attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min

    return attention_mask
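
# Example (illustrative): for input_ids = [[5, 9, 2, pad, pad]] the additive mask
# is 0 at the three real tokens and a large negative value (torch.finfo(dtype).min)
# at the two pad positions; the [batch, 1, 1, seq_len] shape broadcasts against
# [batch, num_heads, seq_len, seq_len] attention scores.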