Added another learning source for Nora. Also added the requirements.
model.py (51 additions)
@@ -7,9 +7,22 @@ No pretrained weights—everything is initialized randomly.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def top_k_logits(logits: torch.Tensor, k: int):
    """
    Keep only the top k logits; set all others to -1e10 and return the modified logits.
    logits: (vocab_size,)
    """
    if k == 0:
        return logits
    topk_vals, _ = torch.topk(logits, k)
    min_topk = topk_vals[-1]
    return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 10_000):
        super().__init__()
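For a quick sense of what the new top_k_logits helper does, here is a small sanity-check sketch (the logit values are made up, and the import assumes model.py is importable as `model`):

import torch
from model import top_k_logits  # assumes model.py is on the import path

# Made-up logits over a 5-token vocabulary.
logits = torch.tensor([2.0, -1.0, 0.5, 3.0, 1.0])
filtered = top_k_logits(logits, k=2)
# Everything below the 2nd-largest logit is pushed to -1e10,
# i.e. filtered == [2.0, -1e10, -1e10, 3.0, -1e10],
# so softmax assigns those positions ~zero probability.
probs = torch.softmax(filtered, dim=-1)  # probability mass only on indices 0 and 3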
@@ -98,3 +111,41 @@ class NoraTransformerLM(nn.Module):
        x = x.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        logits = self.fc_out(x)  # (batch_size, seq_length, vocab_size)
        return logits

    def generate(
        self,
        tokenizer,
        device: str,
        prompt: str,
        max_length: int = 128,
        temperature: float = 1.0,
        top_k: int = 50,
    ) -> str:
        """
        Autoregressively generate text from a prompt.
        - tokenizer: CharTokenizer (for encode/decode)
        - device: "cuda" or "cpu"
        - prompt: initial string
        - max_length: total tokens to generate (including prompt)
        - temperature: scales logits before softmax
        - top_k: keep only top_k logits at each step
        """
        self.eval()
        input_ids = tokenizer.encode(prompt)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        generated = input_ids.clone()  # shape (1, seq_len)
        for _ in range(max_length - input_ids.size(1)):
            # 1) trim to the last seq_length tokens if longer than the context window
            if generated.size(1) > self.pos_encoder.pe.size(1):
                generated = generated[:, -self.pos_encoder.pe.size(1) :]

            with torch.no_grad():
                logits = self.forward(generated)  # (1, seq_len, vocab_size)
            next_token_logits = logits[0, -1, :] / temperature
            filtered_logits = top_k_logits(next_token_logits, k=top_k)
            probs = F.softmax(filtered_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # (1,)
            generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)

        output_ids = generated.squeeze(0).tolist()
        return tokenizer.decode(output_ids)
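For reference, a minimal usage sketch for the new generate method. The CharTokenizer import path and the constructor arguments for NoraTransformerLM and CharTokenizer are assumptions (they are not part of this diff), so adjust them to the actual signatures:

import torch
from model import NoraTransformerLM   # assumes model.py is importable as `model`
from tokenizer import CharTokenizer   # hypothetical import path for the tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = CharTokenizer()                                  # constructor args not shown in this diff
model = NoraTransformerLM(vocab_size=tokenizer.vocab_size)   # hypothetical signature
model.to(device)

text = model.generate(
    tokenizer,
    device=device,
    prompt="Hello Nora, ",
    max_length=128,
    temperature=0.8,
    top_k=50,
)
print(text)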