""" NOVA Transformer - Main model implementation """ import torch import torch.nn as nn from typing import Optional, Tuple, List import math from .config import ModelConfig from .layers import TransformerBlock from .rope import RotaryPositionalEmbedding, ALiBiPositionalBias from .normalization import get_norm_layer from .attention import create_causal_mask class NovaTransformer(nn.Module): """ NOVA Transformer Language Model A decoder-only transformer with: - RoPE or ALiBi positional encoding - RMSNorm or LayerNorm - SwiGLU or GELU activations - Grouped-query attention (optional) - KV-cache for fast inference - Gradient checkpointing support """ def __init__(self, config: ModelConfig): super().__init__() self.config = config self.vocab_size = config.vocab_size self.hidden_size = config.hidden_size # Token embeddings self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) # Positional encoding if config.use_rope: self.rope = RotaryPositionalEmbedding( dim=config.hidden_size // config.num_attention_heads, max_seq_len=config.max_position_embeddings, theta=config.rope_theta ) elif config.use_alibi: self.alibi = ALiBiPositionalBias( num_heads=config.num_attention_heads, max_seq_len=config.max_position_embeddings ) else: self.rope = None self.alibi = None # Transformer blocks self.layers = nn.ModuleList([ TransformerBlock(config, layer_idx=i) for i in range(config.num_hidden_layers) ]) # Final layer norm self.norm = get_norm_layer( config.norm_type, config.hidden_size, config.rms_norm_eps ) # Language model head self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Tie weights if specified if config.tie_word_embeddings: self.lm_head.weight = self.embed_tokens.weight # Gradient checkpointing self.gradient_checkpointing = config.gradient_checkpointing # Initialize weights self.apply(self._init_weights) def _init_weights(self, module): """Initialize weights using normal distribution""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def _prepare_decoder_attention_mask( self, input_ids: torch.Tensor, past_key_values_length: int = 0 ) -> torch.Tensor: """ Create causal attention mask for decoder Args: input_ids: [batch, seq_len] past_key_values_length: Length of cached keys/values Returns: Causal attention mask """ batch_size, seq_len = input_ids.shape device = input_ids.device dtype = torch.float32 # Create causal mask if past_key_values_length > 0: # During generation, only mask the new token mask = torch.zeros( (batch_size, 1, seq_len, past_key_values_length + seq_len), device=device, dtype=dtype ) else: # During training, mask future tokens mask = create_causal_mask(seq_len, device, dtype) return mask def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: bool = False, return_dict: bool = True, ): """ Forward pass through NOVA transformer Args: input_ids: [batch, seq_len] attention_mask: Optional custom attention mask past_key_values: Optional cached key/values for generation use_cache: Whether to return key/value cache return_dict: Whether to return dict or tuple Returns: ModelOutput with logits and 
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        return_dict: bool = True,
    ):
        """
        Forward pass through the NOVA transformer.

        Args:
            input_ids: [batch, seq_len]
            attention_mask: Optional custom attention mask
            past_key_values: Optional cached key/values for generation
            use_cache: Whether to return the key/value cache
            return_dict: Whether to return a dict or a tuple

        Returns:
            Dict (or tuple) with logits and, if requested, the key/value cache
        """
        batch_size, seq_len = input_ids.shape

        # Get past sequence length for the KV-cache
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)

        # Prepare attention mask
        if attention_mask is None:
            attention_mask = self._prepare_decoder_attention_mask(
                input_ids, past_key_values_length
            )

        # Prepare position embeddings for RoPE
        position_embeddings = None
        if self.rope is not None:
            # Create position IDs, offset by the cached length
            position_ids = torch.arange(
                past_key_values_length,
                seq_len + past_key_values_length,
                dtype=torch.long,
                device=input_ids.device,
            )
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

            # Look up cos/sin embeddings
            cos = self.rope.cos_cached[position_ids].unsqueeze(1)
            sin = self.rope.sin_cached[position_ids].unsqueeze(1)
            position_embeddings = (cos, sin)

        # Pass through transformer blocks
        next_cache = [] if use_cache else None
        for idx, layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Use gradient checkpointing during training
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
                    position_embeddings,
                    past_key_value,
                    use_cache,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_embeddings=position_embeddings,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_cache.append(layer_outputs[1])

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        # LM head
        logits = self.lm_head(hidden_states)

        if return_dict:
            return {
                'logits': logits,
                'past_key_values': next_cache if use_cache else None,
                'hidden_states': hidden_states,
            }
        else:
            return (logits, next_cache if use_cache else None)
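    # Illustrative sketch (comments only, not exercised in this module): how
    # `forward` is typically driven with the KV-cache during incremental
    # decoding. The variable names below are hypothetical.
    #
    #   out = model(prompt_ids, use_cache=True)                 # prefill, [B, T]
    #   cache = out['past_key_values']
    #   next_id = out['logits'][:, -1, :].argmax(-1, keepdim=True)
    #   out = model(next_id, past_key_values=cache, use_cache=True)  # decode, [B, 1]
    #
    # After prefill, each decode step feeds only the newest token; the cached
    # keys/values supply the earlier context. `generate` below follows this loop.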
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
        do_sample: bool = True,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate text autoregressively.

        Args:
            input_ids: [batch, seq_len] starting tokens
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: Keep only the top k tokens for sampling
            top_p: Nucleus sampling - keep the smallest set of tokens whose
                cumulative probability exceeds p
            repetition_penalty: Penalty for repeating tokens (>1.0 discourages repeats)
            do_sample: Whether to sample (True) or use greedy decoding (False)
            eos_token_id: Token ID that ends generation

        Returns:
            Generated token IDs [batch, seq_len + new_tokens]

        Note:
            The repetition penalty and EOS check only inspect the first sequence
            in the batch, so generation is effectively intended for batch size 1.
        """
        self.eval()
        past_key_values = None

        for _ in range(max_new_tokens):
            # Forward pass with cache: after the first step, feed only the newest token
            outputs = self.forward(
                input_ids=input_ids if past_key_values is None else input_ids[:, -1:],
                past_key_values=past_key_values,
                use_cache=True,
            )

            logits = outputs['logits'][:, -1, :]  # [batch, vocab_size]
            past_key_values = outputs['past_key_values']

            # Apply repetition penalty (first sequence only)
            if repetition_penalty != 1.0:
                for token_id in set(input_ids[0].tolist()):
                    logits[0, token_id] /= repetition_penalty

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Top-k filtering: drop everything below the k-th largest logit
            if top_k is not None:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold,
                # shifted right so the first token that crosses it is kept
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample or greedy decode
            if do_sample:
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append to the sequence
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Check for EOS (batch size 1)
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

        return input_ids

    def get_num_params(self, non_embedding: bool = False) -> int:
        """
        Get the number of parameters in the model.

        Args:
            non_embedding: If True, exclude embedding parameters

        Returns:
            Number of parameters
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.embed_tokens.weight.numel()
        return n_params
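
# Minimal usage sketch (illustrative only): builds a tiny model and generates a
# few tokens. It assumes ModelConfig accepts the fields referenced above as
# keyword arguments and provides sensible defaults for the rest (rope_theta,
# norm_type, tie_word_embeddings, ...), which may not match the real config API,
# and that this file is run as a module so the relative imports resolve.
if __name__ == "__main__":
    config = ModelConfig(
        vocab_size=1000,
        hidden_size=128,
        num_attention_heads=4,
        num_hidden_layers=2,
        max_position_embeddings=256,
    )
    model = NovaTransformer(config)
    print(f"Parameters: {model.get_num_params():,}")

    # Random prompt of 8 tokens, batch size 1 (matches generate's assumptions)
    prompt = torch.randint(0, config.vocab_size, (1, 8))
    generated = model.generate(prompt, max_new_tokens=16, temperature=0.8, top_k=50)
    print(f"Generated shape: {tuple(generated.shape)}")  # (1, 8 + new tokens)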