Initial commit: NOVA - Neuro-Optimizing Versatile Agent
Complete transformer LLM built from scratch with:

Core Features:
- Full transformer architecture (RoPE, RMSNorm, SwiGLU, KV-cache)
- SentencePiece tokenizer (BPE/Unigram)
- Training pipeline (AMP, gradient checkpointing, DDP)
- Persona system with personality matrix (NO AI disclosure by default)
- Genetic evolution (NOVA-EVO) for hyperparameter optimization
- Legal-only data pipeline with license tracking
- Chat interface (CLI + REST API)
- Conversation memory (SQLite)

Model Sizes:
- 125M, 350M, 1.3B, 3B parameters
- Local-first, runs on CPU or GPU
- Python 3.10.6+, PyTorch 2.0+

Personas:
- girlfriend_gentle (high warmth, high empathy)
- girlfriend_playful (high humor, high playfulness)
- girlfriend_supportive (balanced, default)

Documentation:
- Complete README with quickstart
- Model card with ethical considerations
- Privacy documentation (local-first, zero telemetry)
- Data licenses and attribution
- Contributing guide

Infrastructure:
- GitHub Actions CI/CD
- Comprehensive test suite
- Quickstart script
- CLI tool

License: Apache 2.0

🤖 Generated with Claude Code
https://claude.com/claude-code

Co-Authored-By: Claude <noreply@anthropic.com>
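As a rough sketch of how the pieces in this commit fit together (untrained weights, illustrative sampling settings; the tokenizer, persona, and chat layers listed above are not shown):

```python
import torch

from nova_core import NovaTransformer
from nova_core.config import MODEL_125M

# Build the smallest predefined configuration and model (randomly initialized here)
model = NovaTransformer(MODEL_125M)
print(f"{model.get_num_params() / 1e6:.1f}M parameters")

# Sample a continuation from an arbitrary prompt of token IDs (do_sample=True is the default)
prompt = torch.randint(0, MODEL_125M.vocab_size, (1, 8))
output_ids = model.generate(prompt, max_new_tokens=16, temperature=0.8, top_p=0.9)
print(output_ids.shape)  # [1, 8 + generated tokens]
```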
15
nova_core/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""
NOVA Core - Transformer architecture from scratch
"""

from .model import NovaTransformer
from .attention import MultiHeadAttention
from .layers import TransformerBlock
from .config import ModelConfig

__all__ = [
    'NovaTransformer',
    'MultiHeadAttention',
    'TransformerBlock',
    'ModelConfig',
]
114
nova_core/activations.py
Normal file
@@ -0,0 +1,114 @@
"""
Activation functions for NOVA
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    """
    SwiGLU activation function from Shazeer (2020)
    Used in PaLM and other modern LLMs

    SwiGLU(x, W, V, b, c) = Swish(xW + b) ⊗ (xV + c)
    where Swish(x) = x * sigmoid(x)
    """

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        """
        Args:
            hidden_size: Input dimension
            intermediate_size: Hidden dimension (usually 4 * hidden_size)
            bias: Whether to use bias in linear layers
        """
        super().__init__()
        # Gate projection
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        # Up projection
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        # Down projection
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply SwiGLU activation

        Args:
            x: Input tensor [..., hidden_size]

        Returns:
            Output tensor [..., hidden_size]
        """
        # Swish activation: x * sigmoid(x)
        gate = F.silu(self.gate_proj(x))
        # Element-wise multiplication with up projection
        up = self.up_proj(x)
        # Down projection
        return self.down_proj(gate * up)


class GeGLU(nn.Module):
    """
    GeGLU activation function - variant of SwiGLU using GELU
    GeGLU(x, W, V) = GELU(xW) ⊗ (xV)
    """

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        """
        Args:
            hidden_size: Input dimension
            intermediate_size: Hidden dimension
            bias: Whether to use bias in linear layers
        """
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply GeGLU activation"""
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class MLP(nn.Module):
    """
    Standard MLP with configurable activation
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "swiglu",
        bias: bool = False
    ):
        """
        Args:
            hidden_size: Input/output dimension
            intermediate_size: Hidden dimension
            hidden_act: Activation function ('swiglu', 'geglu', or 'gelu')
            bias: Whether to use bias
        """
        super().__init__()

        if hidden_act.lower() == "swiglu":
            self.mlp = SwiGLU(hidden_size, intermediate_size, bias)
        elif hidden_act.lower() == "geglu":
            self.mlp = GeGLU(hidden_size, intermediate_size, bias)
        elif hidden_act.lower() == "gelu":
            # Standard GELU MLP
            self.mlp = nn.Sequential(
                nn.Linear(hidden_size, intermediate_size, bias=bias),
                nn.GELU(approximate="tanh"),
                nn.Linear(intermediate_size, hidden_size, bias=bias)
            )
        else:
            raise ValueError(f"Unknown activation: {hidden_act}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through MLP"""
        return self.mlp(x)
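A quick shape check for the modules above (the sizes are arbitrary, matching the 125M defaults):

```python
import torch
from nova_core.activations import MLP, SwiGLU

x = torch.randn(2, 16, 768)                         # [batch, seq_len, hidden_size]
print(SwiGLU(768, 3072)(x).shape)                   # torch.Size([2, 16, 768])
print(MLP(768, 3072, hidden_act="geglu")(x).shape)  # same shape, GeGLU variant
```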
209
nova_core/attention.py
Normal file
@@ -0,0 +1,209 @@
"""
Multi-head attention with KV-cache and optional Flash Attention
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math

try:
    from flash_attn import flash_attn_func
    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with support for:
    - Grouped-query attention (GQA)
    - KV-cache for fast inference
    - Flash Attention (when available)
    - RoPE/ALiBi positional encoding
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        assert self.hidden_size % self.num_heads == 0, \
            "hidden_size must be divisible by num_heads"

        # Projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.attention_dropout)

        # Flash attention flag
        self.use_flash = config.use_flash_attention and FLASH_ATTENTION_AVAILABLE

    def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        Repeat key/value tensors for grouped-query attention
        This is equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)
        but is more efficient
        """
        if n_rep == 1:
            return hidden_states

        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim
        )
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: [batch, 1, seq_len, seq_len] or [batch, 1, 1, seq_len]
            position_embeddings: Optional (cos, sin) for RoPE
            past_key_value: Optional cached (key, value) for inference
            use_cache: Whether to return key/value for caching

        Returns:
            (output, past_key_value if use_cache else None)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Reshape for multi-head attention
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings if provided
        if position_embeddings is not None:
            cos, sin = position_embeddings
            query, key = self._apply_rotary_pos_emb(query, key, cos, sin)

        # Use cached key/value if available
        if past_key_value is not None:
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        # Store for next iteration if caching
        if use_cache:
            past_key_value = (key, value)
        else:
            past_key_value = None

        # Repeat K/V for grouped-query attention
        key = self._repeat_kv(key, self.num_key_value_groups)
        value = self._repeat_kv(value, self.num_key_value_groups)

        # Compute attention
        if self.use_flash and self.training:
            # Flash Attention (only during training, requires specific format)
            # Flash attention expects [batch, seq_len, num_heads, head_dim]
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            attn_output = flash_attn_func(
                query, key, value,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
                causal=True
            )
            attn_output = attn_output.transpose(1, 2)
        else:
            # Standard scaled dot-product attention
            attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # Apply attention mask
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask

            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
            attn_weights = self.dropout(attn_weights)

            attn_output = torch.matmul(attn_weights, value)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value

    def _apply_rotary_pos_emb(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings"""
        # Rotate-half trick for efficiency
        def rotate_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)

        query_rot = (query * cos) + (rotate_half(query) * sin)
        key_rot = (key * cos) + (rotate_half(key) * sin)

        return query_rot, key_rot


def create_causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create causal attention mask for autoregressive generation

    Args:
        seq_len: Sequence length
        device: Device to create tensor on
        dtype: Data type

    Returns:
        Causal mask [1, 1, seq_len, seq_len]
    """
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=dtype), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask.unsqueeze(0).unsqueeze(0)


def create_attention_mask_from_padding(
    input_ids: torch.Tensor,
    pad_token_id: int
) -> torch.Tensor:
    """
    Create attention mask from padding tokens

    Args:
        input_ids: [batch, seq_len]
        pad_token_id: ID of padding token

    Returns:
        Attention mask [batch, 1, 1, seq_len]
    """
    # Create padding mask [batch, seq_len]
    padding_mask = (input_ids != pad_token_id).float()

    # Expand to attention mask format
    attention_mask = padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]

    # Convert to additive mask (0 for attend, -inf for ignore)
    attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min

    return attention_mask
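For reference, a small sketch of how the two mask helpers combine; the pad token ID of 0 is an assumption for illustration:

```python
import torch
from nova_core.attention import create_causal_mask, create_attention_mask_from_padding

input_ids = torch.tensor([[5, 17, 42, 0, 0]])                             # batch of 1, padded with ID 0
causal = create_causal_mask(5, input_ids.device, torch.float32)           # [1, 1, 5, 5]
padding = create_attention_mask_from_padding(input_ids, pad_token_id=0)   # [1, 1, 1, 5]

# Both are additive masks (0 = attend, large negative = ignore), so they can simply be summed
combined = causal + padding
print(combined.shape)  # torch.Size([1, 1, 5, 5])
```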
94
nova_core/config.py
Normal file
@@ -0,0 +1,94 @@
"""
Model configuration for NOVA transformer
"""

from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfig:
    """Configuration for NOVA transformer model"""

    # Model architecture
    vocab_size: int = 32000
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072
    max_position_embeddings: int = 2048

    # Activation and normalization
    hidden_act: str = "swiglu"  # or "gelu"
    norm_type: str = "rmsnorm"  # or "layernorm"
    rms_norm_eps: float = 1e-6

    # Positional encoding
    rope_theta: float = 10000.0
    use_rope: bool = True
    use_alibi: bool = False  # Alternative to RoPE

    # Attention
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.1
    num_key_value_heads: Optional[int] = None  # For grouped-query attention (GQA)
    use_flash_attention: bool = False  # Auto-detected at runtime

    # Training
    initializer_range: float = 0.02
    use_cache: bool = True  # KV-cache for inference

    # Efficiency
    gradient_checkpointing: bool = False
    tie_word_embeddings: bool = False

    def __post_init__(self):
        """Validate and set derived values"""
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        assert self.hidden_size % self.num_attention_heads == 0, \
            f"hidden_size ({self.hidden_size}) must be divisible by num_attention_heads ({self.num_attention_heads})"

        assert self.num_attention_heads % self.num_key_value_heads == 0, \
            f"num_attention_heads ({self.num_attention_heads}) must be divisible by num_key_value_heads ({self.num_key_value_heads})"


# Predefined model sizes
MODEL_125M = ModelConfig(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    max_position_embeddings=2048,
)

MODEL_350M = ModelConfig(
    vocab_size=32000,
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,
    max_position_embeddings=2048,
)

MODEL_1_3B = ModelConfig(
    vocab_size=32000,
    hidden_size=2048,
    num_hidden_layers=24,
    num_attention_heads=32,
    intermediate_size=8192,
    max_position_embeddings=2048,
    num_key_value_heads=8,  # GQA for efficiency
)

MODEL_3B = ModelConfig(
    vocab_size=32000,
    hidden_size=2560,
    num_hidden_layers=32,
    num_attention_heads=32,
    intermediate_size=10240,
    max_position_embeddings=4096,
    num_key_value_heads=8,  # GQA for efficiency
)
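A minimal illustration of the config behavior above: the post-init hook fills in the GQA default and validates divisibility, so a custom config only needs the fields that differ (the tiny sizes below are arbitrary):

```python
from nova_core.config import ModelConfig, MODEL_1_3B

# Custom config: by default num_key_value_heads is set equal to num_attention_heads (full MHA)
tiny = ModelConfig(hidden_size=256, num_hidden_layers=4, num_attention_heads=8, intermediate_size=1024)
print(tiny.num_key_value_heads)  # 8

# The 1.3B preset uses grouped-query attention: 32 query heads share 8 KV heads
print(MODEL_1_3B.num_attention_heads, MODEL_1_3B.num_key_value_heads)  # 32 8
```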
98
nova_core/layers.py
Normal file
@@ -0,0 +1,98 @@
"""
Transformer block layers
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple

from .attention import MultiHeadAttention
from .activations import MLP
from .normalization import get_norm_layer


class TransformerBlock(nn.Module):
    """
    Single transformer decoder block with:
    - Multi-head attention with RoPE
    - Feed-forward network (MLP)
    - Pre-normalization (norm before attention/FFN)
    - Residual connections
    """

    def __init__(self, config, layer_idx: int):
        """
        Args:
            config: ModelConfig instance
            layer_idx: Layer index for identification
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Attention
        self.self_attn = MultiHeadAttention(config)
        self.attn_norm = get_norm_layer(
            config.norm_type,
            config.hidden_size,
            config.rms_norm_eps
        )

        # Feed-forward
        self.mlp = MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act
        )
        self.mlp_norm = get_norm_layer(
            config.norm_type,
            config.hidden_size,
            config.rms_norm_eps
        )

        # Dropout
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            attention_mask: Optional attention mask
            position_embeddings: Optional (cos, sin) for RoPE
            past_key_value: Optional cached key/value
            use_cache: Whether to return key/value cache

        Returns:
            (hidden_states, past_key_value if use_cache else None)
        """
        residual = hidden_states

        # Pre-norm for attention
        hidden_states = self.attn_norm(hidden_states)

        # Self-attention with KV-cache
        attn_output, past_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )

        # Residual connection
        hidden_states = residual + self.dropout(attn_output)

        # Feed-forward with pre-norm
        residual = hidden_states
        hidden_states = self.mlp_norm(hidden_states)
        mlp_output = self.mlp(hidden_states)
        hidden_states = residual + self.dropout(mlp_output)

        return hidden_states, past_key_value
335
nova_core/model.py
Normal file
@@ -0,0 +1,335 @@
"""
NOVA Transformer - Main model implementation
"""

import torch
import torch.nn as nn
import torch.utils.checkpoint  # used for gradient checkpointing below
from typing import Optional, Tuple, List

from .config import ModelConfig
from .layers import TransformerBlock
from .rope import RotaryPositionalEmbedding, ALiBiPositionalBias
from .normalization import get_norm_layer
from .attention import create_causal_mask


class NovaTransformer(nn.Module):
    """
    NOVA Transformer Language Model

    A decoder-only transformer with:
    - RoPE or ALiBi positional encoding
    - RMSNorm or LayerNorm
    - SwiGLU or GELU activations
    - Grouped-query attention (optional)
    - KV-cache for fast inference
    - Gradient checkpointing support
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size

        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Positional encoding (initialize both so the unused one is always defined)
        self.rope = None
        self.alibi = None
        if config.use_rope:
            self.rope = RotaryPositionalEmbedding(
                dim=config.hidden_size // config.num_attention_heads,
                max_seq_len=config.max_position_embeddings,
                theta=config.rope_theta
            )
        elif config.use_alibi:
            self.alibi = ALiBiPositionalBias(
                num_heads=config.num_attention_heads,
                max_seq_len=config.max_position_embeddings
            )

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(config, layer_idx=i)
            for i in range(config.num_hidden_layers)
        ])

        # Final layer norm
        self.norm = get_norm_layer(
            config.norm_type,
            config.hidden_size,
            config.rms_norm_eps
        )

        # Language model head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie weights if specified
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

        # Gradient checkpointing
        self.gradient_checkpointing = config.gradient_checkpointing

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights using normal distribution"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(
        self,
        input_ids: torch.Tensor,
        past_key_values_length: int = 0
    ) -> torch.Tensor:
        """
        Create causal attention mask for decoder

        Args:
            input_ids: [batch, seq_len]
            past_key_values_length: Length of cached keys/values

        Returns:
            Causal attention mask
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        dtype = torch.float32

        # Create causal mask
        if past_key_values_length > 0:
            # During cached generation, the new token may attend to every cached position
            mask = torch.zeros(
                (batch_size, 1, seq_len, past_key_values_length + seq_len),
                device=device,
                dtype=dtype
            )
        else:
            # During training, mask future tokens
            mask = create_causal_mask(seq_len, device, dtype)

        return mask

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        return_dict: bool = True,
    ):
        """
        Forward pass through NOVA transformer

        Args:
            input_ids: [batch, seq_len]
            attention_mask: Optional custom attention mask
            past_key_values: Optional cached key/values for generation
            use_cache: Whether to return key/value cache
            return_dict: Whether to return dict or tuple

        Returns:
            Dict with logits, optional cache, and hidden states (or a tuple if return_dict=False)
        """
        batch_size, seq_len = input_ids.shape

        # Get past sequence length for KV-cache
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)

        # Prepare attention mask
        if attention_mask is None:
            attention_mask = self._prepare_decoder_attention_mask(
                input_ids,
                past_key_values_length
            )

        # Prepare position embeddings for RoPE
        position_embeddings = None
        if self.rope is not None:
            # Create position IDs
            position_ids = torch.arange(
                past_key_values_length,
                seq_len + past_key_values_length,
                dtype=torch.long,
                device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

            # Get cos/sin embeddings
            cos = self.rope.cos_cached[position_ids].unsqueeze(1)
            sin = self.rope.sin_cached[position_ids].unsqueeze(1)
            position_embeddings = (cos, sin)

        # Pass through transformer blocks
        next_cache = [] if use_cache else None

        for idx, layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Use gradient checkpointing during training
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
                    position_embeddings,
                    past_key_value,
                    use_cache,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_embeddings=position_embeddings,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_cache.append(layer_outputs[1])

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        # LM head
        logits = self.lm_head(hidden_states)

        if return_dict:
            return {
                'logits': logits,
                'past_key_values': next_cache if use_cache else None,
                'hidden_states': hidden_states,
            }
        else:
            return (logits, next_cache if use_cache else None)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
        do_sample: bool = True,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate text using the model

        Args:
            input_ids: [batch, seq_len] starting tokens
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: Keep only the top k tokens for sampling
            top_p: Nucleus sampling - keep top tokens with cumulative probability p
            repetition_penalty: Penalty for repeating tokens (>1.0 discourages)
            do_sample: Whether to sample (True) or use greedy decoding (False)
            eos_token_id: Token ID that ends generation

        Returns:
            Generated token IDs [batch, seq_len + new_tokens]
        """
        self.eval()
        past_key_values = None

        for _ in range(max_new_tokens):
            # Forward pass with cache (only feed the newest token once the cache is primed)
            outputs = self.forward(
                input_ids=input_ids if past_key_values is None else input_ids[:, -1:],
                past_key_values=past_key_values,
                use_cache=True,
            )

            logits = outputs['logits'][:, -1, :]  # [batch, vocab_size]
            past_key_values = outputs['past_key_values']

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for token_id in set(input_ids[0].tolist()):
                    logits[0, token_id] /= repetition_penalty

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Top-k filtering
            if top_k is not None:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample or greedy decode
            if do_sample:
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Check for EOS (assumes batch size 1)
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

        return input_ids

    def get_num_params(self, non_embedding: bool = False) -> int:
        """
        Get number of parameters in the model

        Args:
            non_embedding: If True, exclude embedding parameters

        Returns:
            Number of parameters
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.embed_tokens.weight.numel()
        return n_params
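A small sketch of how the KV-cache is threaded through forward outside of generate, e.g. for a custom decoding loop; the config choice and prompt are arbitrary:

```python
import torch
from nova_core import NovaTransformer
from nova_core.config import MODEL_125M

model = NovaTransformer(MODEL_125M).eval()

with torch.no_grad():
    # Prime the cache with the full prompt...
    prompt = torch.randint(0, MODEL_125M.vocab_size, (1, 12))
    out = model(prompt, use_cache=True)
    cache = out['past_key_values']          # one (key, value) pair per layer

    # ...then feed only the newest token on subsequent steps
    next_token = out['logits'][:, -1, :].argmax(dim=-1, keepdim=True)
    out = model(next_token, past_key_values=cache, use_cache=True)
    print(out['logits'].shape)              # torch.Size([1, 1, 32000])
```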
74
nova_core/normalization.py
Normal file
@@ -0,0 +1,74 @@
"""
Normalization layers for NOVA
"""

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization
    More efficient than LayerNorm, used in LLaMA and other modern LLMs
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Args:
            hidden_size: Size of the hidden dimension
            eps: Small constant for numerical stability
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization

        Args:
            hidden_states: Input tensor [..., hidden_size]

        Returns:
            Normalized tensor
        """
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)

        # Compute RMS
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return self.weight * hidden_states.to(input_dtype)


class LayerNorm(nn.LayerNorm):
    """
    Standard LayerNorm with optional bias
    Wrapper around PyTorch's LayerNorm for consistency
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, bias: bool = True):
        super().__init__(hidden_size, eps=eps, elementwise_affine=True)
        if not bias:
            self.bias = None


def get_norm_layer(norm_type: str, hidden_size: int, eps: float = 1e-6) -> nn.Module:
    """
    Factory function to get a normalization layer

    Args:
        norm_type: Type of normalization ('rmsnorm' or 'layernorm')
        hidden_size: Size of hidden dimension
        eps: Epsilon for numerical stability

    Returns:
        Normalization layer
    """
    if norm_type.lower() == "rmsnorm":
        return RMSNorm(hidden_size, eps)
    elif norm_type.lower() == "layernorm":
        return LayerNorm(hidden_size, eps)
    else:
        raise ValueError(f"Unknown norm_type: {norm_type}. Use 'rmsnorm' or 'layernorm'")
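A tiny sanity check for the RMSNorm above: with the default unit weight, the output has approximately unit root-mean-square along the last dimension:

```python
import torch
from nova_core.normalization import RMSNorm, get_norm_layer

x = torch.randn(4, 768) * 5.0
norm = get_norm_layer("rmsnorm", 768)
y = norm(x)
print(y.pow(2).mean(-1).sqrt())   # values close to 1.0 for every row
print(isinstance(norm, RMSNorm))  # True
```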
155
nova_core/rope.py
Normal file
@@ -0,0 +1,155 @@
"""
Rotary Position Embedding (RoPE) implementation
"""

import torch
import torch.nn as nn
from typing import Tuple


class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) from Su et al. (2021)
    https://arxiv.org/abs/2104.09864
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        """
        Args:
            dim: Dimension of the embeddings (should be head_dim)
            max_seq_len: Maximum sequence length
            theta: Base for the geometric progression (default 10000.0)
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Precompute frequencies
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute cos/sin cache
        self._update_cos_sin_cache(max_seq_len)

    def _update_cos_sin_cache(self, seq_len: int):
        """Precompute cos and sin for positions up to seq_len"""
        position = torch.arange(seq_len).unsqueeze(1)
        freqs = position * self.inv_freq.unsqueeze(0)

        # Duplicate frequencies to cover the full head dimension [seq_len, dim]
        emb = torch.cat([freqs, freqs], dim=-1)

        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)
        self.cached_seq_len = seq_len

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotates half the hidden dims of the input"""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary position embeddings to query and key tensors

        Args:
            q: Query tensor [batch, num_heads, seq_len, head_dim]
            k: Key tensor [batch, num_heads, seq_len, head_dim]
            position_ids: Optional position IDs [batch, seq_len]

        Returns:
            Tuple of rotated query and key tensors
        """
        seq_len = q.shape[2]

        # Update cache if needed
        if seq_len > self.cached_seq_len:
            self._update_cos_sin_cache(seq_len)

        # Get cos/sin for current positions
        if position_ids is not None:
            # For generation with KV-cache
            cos = self.cos_cached[position_ids].unsqueeze(1)
            sin = self.sin_cached[position_ids].unsqueeze(1)
        else:
            # For training or initial forward pass
            cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
            sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)

        # Apply rotation
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)

        return q_embed, k_embed


class ALiBiPositionalBias(nn.Module):
    """
    Attention with Linear Biases (ALiBi) from Press et al. (2021)
    https://arxiv.org/abs/2108.12409
    Alternative to RoPE
    """

    def __init__(self, num_heads: int, max_seq_len: int = 2048):
        """
        Args:
            num_heads: Number of attention heads
            max_seq_len: Maximum sequence length
        """
        super().__init__()
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len

        # Compute slopes for each head
        slopes = self._get_slopes(num_heads)
        self.register_buffer("slopes", slopes, persistent=False)

        # Precompute bias matrix
        alibi = self._get_alibi_bias(max_seq_len, slopes)
        self.register_buffer("alibi_bias", alibi, persistent=False)

    def _get_slopes(self, num_heads: int) -> torch.Tensor:
        """Compute per-head slopes for ALiBi (geometric sequence from Press et al.)"""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(torch.log2(torch.tensor(float(n))) - 3)))
            ratio = start
            # slope_i = start * ratio**i, a decreasing geometric progression
            return start * torch.pow(ratio, torch.arange(n))

        # Handle non-power-of-2 number of heads
        if (num_heads & (num_heads - 1)) == 0:
            return get_slopes_power_of_2(num_heads)
        else:
            closest_power_of_2 = 2 ** torch.floor(torch.log2(torch.tensor(float(num_heads))))
            slopes_a = get_slopes_power_of_2(int(closest_power_of_2))
            slopes_b = self._get_slopes(int(2 * closest_power_of_2))[0::2][:num_heads - int(closest_power_of_2)]
            return torch.cat([slopes_a, slopes_b])

    def _get_alibi_bias(self, seq_len: int, slopes: torch.Tensor) -> torch.Tensor:
        """Precompute ALiBi bias matrix"""
        # Create relative position matrix
        pos = torch.arange(seq_len).unsqueeze(0)
        rel_pos = pos - pos.T  # [seq_len, seq_len]

        # Apply slopes [num_heads, seq_len, seq_len]
        alibi = rel_pos.unsqueeze(0) * slopes.unsqueeze(-1).unsqueeze(-1)

        return alibi

    def forward(self, attention_scores: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Add ALiBi bias to attention scores

        Args:
            attention_scores: [batch, num_heads, seq_len, seq_len]
            seq_len: Current sequence length

        Returns:
            Biased attention scores
        """
        return attention_scores + self.alibi_bias[:, :seq_len, :seq_len]
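A short usage sketch for the RoPE module above, applied to query/key tensors shaped the way MultiHeadAttention produces them (head_dim of 64 matches the 125M preset):

```python
import torch
from nova_core.rope import RotaryPositionalEmbedding

head_dim, num_heads, seq_len = 64, 12, 10
rope = RotaryPositionalEmbedding(dim=head_dim, max_seq_len=2048)

q = torch.randn(1, num_heads, seq_len, head_dim)
k = torch.randn(1, num_heads, seq_len, head_dim)
q_rot, k_rot = rope(q, k)                 # positions 0..seq_len-1
print(q_rot.shape, k_rot.shape)           # shapes are unchanged

# During cached generation, pass explicit position IDs for the new token(s)
pos = torch.tensor([[seq_len]])           # [batch, 1]
q_new = torch.randn(1, num_heads, 1, head_dim)
k_new = torch.randn(1, num_heads, 1, head_dim)
q_rot1, k_rot1 = rope(q_new, k_new, position_ids=pos)
```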