Nora/model.py

"""
model.py
Defines a Transformer-based language model from scratch, using PyTorch's nn.TransformerEncoder.
No pretrained weights; everything is initialized randomly.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def top_k_logits(logits: torch.Tensor, k: int):
    """
    Mask all but the top k logits in each row (the rest are set to a large
    negative value so they receive ~zero probability after softmax); return
    the modified logits.
    logits: (vocab_size,)
    """
    if k == 0:
        return logits
    k = min(k, logits.size(-1))  # guard against k larger than the vocabulary
    topk_vals, _ = torch.topk(logits, k)
    min_topk = topk_vals[-1]
    return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)
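
# Example (illustrative): for logits = [1.0, 3.0, 2.0] and k = 2, the top-2
# values 3.0 and 2.0 are kept while 1.0 is replaced by -1e10, so sampling
# effectively never picks it.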

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 10_000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)
        self.register_buffer("pe", pe)
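        # The buffer implements the standard sinusoidal encoding:
        #   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        #   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        # since div_term = exp(2i * (-ln 10000) / d_model) = 10000^(-2i / d_model).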

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch_size, seq_length, d_model)
        returns x + positional encodings for the first seq_length positions.
        """
        x = x + self.pe[:, : x.size(1), :]
        return x

class NoraTransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        max_seq_len: int = 512,
    ):
        super().__init__()
        self.model_type = "TransformerLM"
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Token embedding + positional encoding
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        # Transformer encoder layers
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=num_layers
        )
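        # Note: an encoder stack combined with the causal mask built in
        # forward() behaves as a decoder-only (GPT-style) language model,
        # so no nn.TransformerDecoder is needed.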
        # Final linear layer to project to vocabulary
        self.fc_out = nn.Linear(d_model, vocab_size)
        # Initialization
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embed.weight, mean=0, std=self.d_model ** -0.5)
        nn.init.zeros_(self.fc_out.bias)
        nn.init.normal_(self.fc_out.weight, mean=0, std=self.d_model ** -0.5)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        src: (batch_size, seq_length), token IDs
        returns: logits (batch_size, seq_length, vocab_size)
        """
        # Embed tokens and add positional encoding
        x = self.token_embed(src) * math.sqrt(self.d_model)  # (B, S, D)
        x = self.pos_encoder(x)  # (B, S, D)
        # PyTorch Transformer expects (S, B, D)
        x = x.permute(1, 0, 2)  # (seq_length, batch_size, d_model)
        # Create a causal mask so each position can only attend to itself and earlier positions
        seq_len = x.size(0)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
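        # For seq_len = 3, for example, the mask is:
        #   [[False,  True,  True],
        #    [False, False,  True],
        #    [False, False, False]]
        # where True marks positions that attention may NOT use.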
        # Pass through Transformer encoder
        x = self.transformer_encoder(x, mask=mask)  # (seq_length, batch_size, d_model)
        # Back to (B, S, D)
        x = x.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        logits = self.fc_out(x)  # (batch_size, seq_length, vocab_size)
        return logits

    def generate(
        self,
        tokenizer,
        device: str,
        prompt: str,
        max_length: int = 128,
        temperature: float = 1.0,
        top_k: int = 50,
    ) -> str:
        """
        Autoregressively generate text from a prompt.
        - tokenizer: CharTokenizer (for encode/decode)
        - device: "cuda" or "cpu"
        - prompt: initial string
        - max_length: total number of tokens to generate (including the prompt)
        - temperature: scales logits before softmax
        - top_k: keep only the top_k logits at each step
        """
        self.eval()
        input_ids = tokenizer.encode(prompt)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        generated = input_ids.clone()  # shape (1, seq_len)
        for _ in range(max_length - input_ids.size(1)):
            # 1) Trim to the last max_seq_len tokens if longer than the context window
            if generated.size(1) > self.pos_encoder.pe.size(1):
                generated = generated[:, -self.pos_encoder.pe.size(1):]
            with torch.no_grad():
                logits = self.forward(generated)  # (1, seq_len, vocab_size)
            next_token_logits = logits[0, -1, :] / temperature
            filtered_logits = top_k_logits(next_token_logits, k=top_k)
            probs = F.softmax(filtered_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # (1,)
            generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)
        output_ids = generated.squeeze(0).tolist()
        return tokenizer.decode(output_ids)
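

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes a character-level tokenizer,
# referred to as CharTokenizer in generate()'s docstring, with encode/decode
# methods and a vocab_size attribute; that class is defined elsewhere in the
# project, so the snippet is left commented out.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     tokenizer = CharTokenizer(...)  # hypothetical constructor
#     model = NoraTransformerLM(vocab_size=tokenizer.vocab_size).to(device)
#     print(model.generate(tokenizer, device, prompt="Hello", max_length=64))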