# RubyOld/utils.py

import matplotlib.pyplot as plt
import torch
from torch.nn.functional import softmax
loss_log = []


def track_loss(loss):
    """Record a loss value and refresh the loss plot every 50 steps."""
    loss_log.append(float(loss))  # accepts Python floats or scalar tensors
    if len(loss_log) % 50 == 0:
        plot_loss()


def plot_loss():
    """Save the running loss curve to loss_plot.png."""
    plt.figure()
    plt.plot(loss_log)
    plt.title("Training Loss Over Time")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.savefig("loss_plot.png")
    plt.close()


def update_model_vocab(model, tokenizer):
    """Grow the model's token embedding and output head when the tokenizer vocabulary expands."""
    new_vocab = tokenizer.vocab_size()
    old_vocab = model.token_emb.num_embeddings
    d_model = model.token_emb.embedding_dim
    if new_vocab > old_vocab:
        # Resize token embedding, copying the existing rows and giving
        # the new rows a small normal initialization
        old_weights = model.token_emb.weight.data
        new_emb = torch.nn.Embedding(new_vocab, d_model).to(old_weights.device)
        new_emb.weight.data[:old_vocab] = old_weights
        torch.nn.init.normal_(new_emb.weight.data[old_vocab:], mean=0.0, std=0.02)
        model.token_emb = new_emb

        # Resize output head the same way; new biases start at zero
        old_head = model.head
        new_head = torch.nn.Linear(d_model, new_vocab).to(old_weights.device)
        new_head.weight.data[:old_vocab] = old_head.weight.data
        new_head.bias.data[:old_vocab] = old_head.bias.data
        torch.nn.init.normal_(new_head.weight.data[old_vocab:], mean=0.0, std=0.02)
        torch.nn.init.zeros_(new_head.bias.data[old_vocab:])
        model.head = new_head


def sample_reply(model, tokenizer, input_ids, max_len=40):
    """Greedily decode up to max_len tokens and return the decoded reply text."""
    model.eval()
    generated = input_ids.clone()
    # Make sure the embedding and head cover the current tokenizer vocabulary
    update_model_vocab(model, tokenizer)
    with torch.no_grad():
        for _ in range(max_len):
            # Truncate input to fit the positional embedding
            max_seq = model.pos_emb.size(1)
            if generated.size(1) > max_seq:
                generated = generated[:, -max_seq:]
            try:
                logits = model(generated)
            except RuntimeError:
                print("CUDA crash in sample_reply — possible vocab mismatch")
                print("Generated:", generated)
                raise
            next_token_logits = logits[0, -1, :]
            probs = softmax(next_token_logits, dim=-1)
            # Greedy decoding: pick the most likely next token, shaped [1, 1]
            next_token = probs.argmax(dim=-1, keepdim=True).unsqueeze(0)
            generated = torch.cat((generated, next_token), dim=1)
            decoded = tokenizer.decode([next_token.item()])
            # Stop at simple sentence boundaries
            if decoded in ['\n', '.', '!', '?']:
                break
    # Return only the newly generated tokens, with the prompt stripped off
    output = generated[0].tolist()[input_ids.shape[1]:]
    reply = tokenizer.decode(output).strip()
    print(f"[Reply] {repr(reply)}")
    return reply


def sample_thought(model, tokenizer, device, context_text, max_len=60):
    """Generate an internal 'thought' by prefixing the context with a [thinking] tag."""
    prompt = f"[thinking] {context_text}"
    input_ids = tokenizer.encode(prompt, return_tensors=True).to(device)
    return sample_reply(model, tokenizer, input_ids, max_len=max_len)
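

# Minimal usage sketch (an assumption, not part of RubyOld): ToyModel and
# ToyTokenizer below are hypothetical stand-ins that expose only the pieces
# these helpers rely on (token_emb, pos_emb, head, vocab_size(), encode(),
# decode()), so update_model_vocab and sample_reply can be exercised end to
# end without the real project.
if __name__ == "__main__":
    class ToyModel(torch.nn.Module):
        def __init__(self, vocab_size=10, d_model=16, max_seq=32):
            super().__init__()
            self.token_emb = torch.nn.Embedding(vocab_size, d_model)
            self.pos_emb = torch.nn.Parameter(torch.zeros(1, max_seq, d_model))
            self.head = torch.nn.Linear(d_model, vocab_size)

        def forward(self, ids):
            x = self.token_emb(ids) + self.pos_emb[:, :ids.size(1)]
            return self.head(x)

    class ToyTokenizer:
        def __init__(self, size=12):
            self.size = size

        def vocab_size(self):
            return self.size

        def encode(self, text, return_tensors=True):
            ids = [ord(c) % self.size for c in text][:8]
            return torch.tensor([ids])

        def decode(self, ids):
            return " ".join(str(i) for i in ids)

    model = ToyModel()
    tok = ToyTokenizer(size=12)  # vocabulary larger than the model's initial 10
    update_model_vocab(model, tok)  # grows token_emb and head to 12 rows
    prompt_ids = tok.encode("hello")
    print(sample_reply(model, tok, prompt_ids, max_len=5))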