Nora/model.py

"""
model.py
Defines a Transformerbased language model from scratch, using PyTorchs nn.Transformer.
No pretrained weights—everything is initialized randomly.
"""
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 10_000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch_size, seq_length, d_model)
        returns x + positional encodings for the first seq_length positions.
        """
        x = x + self.pe[:, : x.size(1), :]
        return x
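
# Illustrative note (shapes only, arbitrary example sizes, not executed on import):
# PositionalEncoding implements the standard sinusoidal scheme,
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)),
# and leaves the input shape unchanged:
#
#   pos_enc = PositionalEncoding(d_model=512, max_len=512)
#   x = torch.zeros(2, 16, 512)   # (batch_size, seq_length, d_model)
#   y = pos_enc(x)                # still (2, 16, 512), with the sinusoids added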


class NoraTransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        max_seq_len: int = 512,
    ):
        super().__init__()
        self.model_type = "TransformerLM"
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Token embedding + positional encoding
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        # Transformer encoder layers
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=num_layers
        )
        # Final linear layer to project to vocabulary
        self.fc_out = nn.Linear(d_model, vocab_size)
        # Initialization
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embed.weight, mean=0, std=self.d_model ** -0.5)
        nn.init.zeros_(self.fc_out.bias)
        nn.init.normal_(self.fc_out.weight, mean=0, std=self.d_model ** -0.5)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        src: (batch_size, seq_length), token IDs
        returns: logits (batch_size, seq_length, vocab_size)
        """
        # Embed tokens and add positional encoding
        x = self.token_embed(src) * math.sqrt(self.d_model)  # (B, S, D)
        x = self.pos_encoder(x)  # (B, S, D)
        # PyTorch Transformer expects (S, B, D)
        x = x.permute(1, 0, 2)  # (seq_length, batch_size, d_model)
        # Causal mask: True entries are blocked, so each position can attend only
        # to itself and earlier positions
        seq_len = x.size(0)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        # Pass through Transformer encoder
        x = self.transformer_encoder(x, mask=mask)  # (seq_length, batch_size, d_model)
        # Back to (B, S, D)
        x = x.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        logits = self.fc_out(x)  # (batch_size, seq_length, vocab_size)
        return logits
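

if __name__ == "__main__":
    # Quick smoke-test sketch: the vocabulary size, batch size, and sequence length
    # below are arbitrary illustrative values; real usage would plug in a tokenizer's
    # vocab_size and actual token IDs in the range [0, vocab_size).
    model = NoraTransformerLM(vocab_size=1000, d_model=128, nhead=4, num_layers=2)
    dummy_ids = torch.randint(0, 1000, (2, 16))  # (batch_size=2, seq_length=16)
    logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])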