"""
|
||
model.py
|
||
|
||
Defines a Transformer‐based language model from scratch, using PyTorch’s nn.Transformer.
|
||
No pretrained weights—everything is initialized randomly.
|
||
"""
|
||
|
||
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def top_k_logits(logits: torch.Tensor, k: int) -> torch.Tensor:
    """
    Keep only the k largest logits and push all others to a large negative
    value, so they receive (near-)zero probability after softmax.

    logits: (vocab_size,)
    """
    if k <= 0:
        return logits
    k = min(k, logits.size(-1))  # clamp so topk never exceeds the vocabulary size
    topk_vals, _ = torch.topk(logits, k)
    min_topk = topk_vals[-1]
    return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)
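

# Illustrative example (not part of the original module): with
# logits = torch.tensor([1.0, 3.0, 2.0]) and k = 2, the top-2 values are 3.0 and
# 2.0, so the result is tensor([-1e10, 3.0, 2.0]) and softmax sampling
# effectively never picks index 0.

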
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 10_000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)
        self.register_buffer("pe", pe)
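        # The buffer above implements the fixed sinusoidal encoding from
        # "Attention Is All You Need":
        #   pe[pos, 2i]     = sin(pos / 10000^(2i / d_model))
        #   pe[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))
        # div_term holds 10000^(-2i / d_model), computed in log space for stability.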

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch_size, seq_length, d_model)
        Returns x + positional encodings for the first seq_length positions.
        """
        x = x + self.pe[:, : x.size(1), :]
        return x


class NoraTransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        max_seq_len: int = 512,
    ):
        super().__init__()
        self.model_type = "TransformerLM"
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Token embedding + positional encoding
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)

        # Transformer encoder layers
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=num_layers
        )
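        # Note: batch_first is left at its default (False), so forward() permutes
        # activations to (seq_length, batch_size, d_model) before the encoder call.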

        # Final linear layer to project to vocabulary
        self.fc_out = nn.Linear(d_model, vocab_size)

        # Initialization
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embed.weight, mean=0, std=self.d_model ** -0.5)
        nn.init.zeros_(self.fc_out.bias)
        nn.init.normal_(self.fc_out.weight, mean=0, std=self.d_model ** -0.5)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        src: (batch_size, seq_length), token IDs
        returns: logits (batch_size, seq_length, vocab_size)
        """
        # Embed tokens and add positional encoding
        x = self.token_embed(src) * math.sqrt(self.d_model)  # (B, S, D)
        x = self.pos_encoder(x)  # (B, S, D)
        # PyTorch Transformer expects (S, B, D)
        x = x.permute(1, 0, 2)  # (seq_length, batch_size, d_model)

        # Create a causal mask so each position can only attend to previous positions
        seq_len = x.size(0)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
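        # In PyTorch, True entries in a boolean attention mask mark positions that
        # may NOT be attended to, so the upper triangle blocks attention to the future.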

        # Pass through Transformer encoder
        x = self.transformer_encoder(x, mask=mask)  # (seq_length, batch_size, d_model)

        # Back to (B, S, D)
        x = x.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        logits = self.fc_out(x)  # (batch_size, seq_length, vocab_size)
        return logits

    def generate(
        self,
        tokenizer,
        device: str,
        prompt: str,
        max_length: int = 128,
        temperature: float = 1.0,
        top_k: int = 50,
    ) -> str:
        """
        Autoregressively generate text from a prompt.

        - tokenizer: CharTokenizer (provides encode/decode)
        - device: "cuda" or "cpu"
        - prompt: initial string
        - max_length: total length of the returned sequence, prompt included
        - temperature: scales logits before softmax (higher = more random)
        - top_k: keep only the top_k logits at each step (0 disables filtering)
        """
        self.eval()
        input_ids = tokenizer.encode(prompt)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        generated = input_ids.clone()  # shape (1, seq_len)
        for _ in range(max_length - input_ids.size(1)):
            # Trim to the positional-encoding window if the context has grown too long
            if generated.size(1) > self.pos_encoder.pe.size(1):
                generated = generated[:, -self.pos_encoder.pe.size(1) :]

            with torch.no_grad():
                logits = self(generated)  # (1, seq_len, vocab_size)
            next_token_logits = logits[0, -1, :] / temperature
            filtered_logits = top_k_logits(next_token_logits, k=top_k)
            probs = F.softmax(filtered_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # (1,)
            generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)

        output_ids = generated.squeeze(0).tolist()
        return tokenizer.decode(output_ids)
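

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It skips the
# tokenizer and feeds random token IDs with a made-up vocab_size purely to
# check tensor shapes; the hyperparameters below are illustrative, not
# prescribed by this project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 100  # hypothetical vocabulary size for the shape check
    model = NoraTransformerLM(
        vocab_size=vocab_size,
        d_model=128,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=64,
    )
    dummy_ids = torch.randint(0, vocab_size, (2, 16))  # (batch_size=2, seq_length=16)
    with torch.no_grad():
        logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 100])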