Added another learning source for Nora. Also added the requirements.

2025-06-09 14:25:11 -04:00
parent da23742671
commit 5d53ba7cb8
14 changed files with 1070 additions and 78 deletions


@@ -7,9 +7,22 @@ No pretrained weights—everything is initialized randomly.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def top_k_logits(logits: torch.Tensor, k: int):
    """
    Keep only the top-k logits; set the rest to -1e10 so softmax assigns them
    effectively zero probability. Returns the modified logits.
    logits: (vocab_size,)
    """
    if k == 0:
        return logits
    k = min(k, logits.size(-1))  # guard against k larger than the vocabulary
    topk_vals, _ = torch.topk(logits, k)
    min_topk = topk_vals[-1]
    return torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 10_000):
        super().__init__()
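
For reference, a minimal sketch of what the new top_k_logits helper does on a toy logit vector (the import path below is illustrative, not part of this commit):

import torch
from model import top_k_logits  # hypothetical module path; adjust to the actual file

logits = torch.tensor([2.0, 0.5, 1.0, 3.0, -1.0])
filtered = top_k_logits(logits, k=2)      # everything below the 2nd-largest logit becomes -1e10
probs = torch.softmax(filtered, dim=-1)
print(probs)                              # ~[0.27, 0.00, 0.00, 0.73, 0.00]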
@@ -98,3 +111,41 @@ class NoraTransformerLM(nn.Module):
        x = x.permute(1, 0, 2)  # (batch_size, seq_length, d_model)
        logits = self.fc_out(x)  # (batch_size, seq_length, vocab_size)
        return logits
    def generate(
        self,
        tokenizer,
        device: str,
        prompt: str,
        max_length: int = 128,
        temperature: float = 1.0,
        top_k: int = 50,
    ) -> str:
        """
        Autoregressively generate text from a prompt.
        - tokenizer: CharTokenizer (for encode/decode)
        - device: "cuda" or "cpu"
        - prompt: initial string
        - max_length: total tokens to generate (including the prompt)
        - temperature: scales logits before softmax
        - top_k: keep only the top_k logits at each step
        """
        self.eval()
        input_ids = tokenizer.encode(prompt)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        generated = input_ids.clone()  # shape (1, seq_len)
        for _ in range(max_length - input_ids.size(1)):
            # 1) trim to the last seq_length tokens if longer than the context window
            if generated.size(1) > self.pos_encoder.pe.size(1):
                generated = generated[:, -self.pos_encoder.pe.size(1) :]
            # 2) forward pass; only the logits at the last position are needed
            with torch.no_grad():
                logits = self.forward(generated)  # (1, seq_len, vocab_size)
            next_token_logits = logits[0, -1, :] / temperature
            # 3) top-k filter, then sample the next token id from the softmax distribution
            filtered_logits = top_k_logits(next_token_logits, k=top_k)
            probs = F.softmax(filtered_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # (1,)
            generated = torch.cat([generated, next_id.unsqueeze(0)], dim=1)
        output_ids = generated.squeeze(0).tolist()
        return tokenizer.decode(output_ids)
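
A minimal usage sketch of the new generate method, assuming a CharTokenizer that exposes encode/decode and vocab_size, and a NoraTransformerLM built with default arguments (the module paths and constructor call below are assumptions, not shown in this diff):

import torch
from tokenizer import CharTokenizer      # assumed module path
from model import NoraTransformerLM      # assumed module path

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = CharTokenizer()              # assumed to provide encode()/decode()/vocab_size
model = NoraTransformerLM(vocab_size=tokenizer.vocab_size).to(device)  # other args left at defaults

sample = model.generate(
    tokenizer=tokenizer,
    device=device,
    prompt="Hello, Nora",
    max_length=128,
    temperature=0.8,
    top_k=50,
)
print(sample)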