import matplotlib.pyplot as plt
import torch
from torch.nn.functional import softmax

loss_log = []


def track_loss(loss):
    # Store the loss as a plain float so scalar tensors do not accumulate in memory
    loss_log.append(float(loss))
    # Refresh the saved plot every 50 recorded steps
    if len(loss_log) % 50 == 0:
        plot_loss()


def plot_loss():
    plt.figure()
    plt.plot(loss_log)
    plt.title("Training Loss Over Time")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.savefig("loss_plot.png")
    plt.close()


def update_model_vocab(model, tokenizer):
    new_vocab = tokenizer.vocab_size()
    old_vocab = model.token_emb.num_embeddings
    d_model = model.token_emb.embedding_dim

    if new_vocab > old_vocab:
        # Resize token embedding: copy the old rows, randomly initialise the new ones
        old_weights = model.token_emb.weight.data
        new_emb = torch.nn.Embedding(new_vocab, d_model).to(old_weights.device)
        new_emb.weight.data[:old_vocab] = old_weights
        torch.nn.init.normal_(new_emb.weight.data[old_vocab:], mean=0.0, std=0.02)
        model.token_emb = new_emb

        # Resize output head the same way
        old_head = model.head
        new_head = torch.nn.Linear(d_model, new_vocab).to(old_weights.device)
        new_head.weight.data[:old_vocab] = old_head.weight.data
        new_head.bias.data[:old_vocab] = old_head.bias.data
        torch.nn.init.normal_(new_head.weight.data[old_vocab:], mean=0.0, std=0.02)
        torch.nn.init.zeros_(new_head.bias.data[old_vocab:])
        model.head = new_head

        # Note: an optimizer created before this call still references the old
        # parameters and will not update the replacements unless re-created.


def sample_reply(model, tokenizer, input_ids, max_len=40):
    model.eval()
    generated = input_ids.clone()

    for _ in range(max_len):
        # Truncate input to fit the positional embedding
        max_seq = model.pos_emb.size(1)
        if generated.size(1) > max_seq:
            generated = generated[:, -max_seq:]

        # Keep the embedding table and output head in sync with the tokenizer
        update_model_vocab(model, tokenizer)

        try:
            with torch.no_grad():
                logits = model(generated)
        except RuntimeError:
            print("CUDA crash in sample_reply - possible vocab mismatch")
            print("Generated:", generated)
            raise

        # Greedy decoding: always take the highest-probability next token
        next_token_logits = logits[0, -1, :]
        probs = softmax(next_token_logits, dim=-1)
        next_token = probs.argmax(dim=-1, keepdim=True)
        next_token = next_token.unsqueeze(0)  # Shape: [1, 1]
        generated = torch.cat((generated, next_token), dim=1)

        # Stop at a newline or sentence-ending punctuation
        decoded = tokenizer.decode([next_token.item()])
        if decoded in ['\n', '.', '!', '?']:
            break

    # Decode only the newly generated tokens, not the prompt
    output = generated[0].tolist()[input_ids.shape[1]:]
    reply = tokenizer.decode(output).strip()
    print(f"[Reply] {repr(reply)}")
    return reply


def sample_thought(model, tokenizer, device, context_text, max_len=60):
    # Prefix the context with a [thinking] tag, then sample as a normal reply
    prompt = f"[thinking] {context_text}"
    input_ids = tokenizer.encode(prompt, return_tensors=True).to(device)
    return sample_reply(model, tokenizer, input_ids, max_len=max_len)
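

# ---------------------------------------------------------------------------
# Minimal self-contained sketch of the interfaces the helpers above assume:
# a decoder-style model exposing `token_emb`, `pos_emb`, and `head`, and a
# tokenizer exposing `vocab_size()`, `encode(...)`, and `decode(...)`.
# `_ToyTokenizer` and `_ToyLM` are illustrative stand-ins, not the project's
# real model or tokenizer; the block only exercises the API shape.
# ---------------------------------------------------------------------------
class _ToyTokenizer:
    """Character-level stand-in for the assumed tokenizer interface."""

    def __init__(self):
        self.chars = [chr(i) for i in range(32, 127)] + ['\n']

    def vocab_size(self):
        return len(self.chars)

    def encode(self, text, return_tensors=False):
        ids = [self.chars.index(c) for c in text if c in self.chars]
        if return_tensors:
            return torch.tensor([ids], dtype=torch.long)
        return ids

    def decode(self, ids):
        return ''.join(self.chars[i] for i in ids)


class _ToyLM(torch.nn.Module):
    """Tiny untrained LM with the attribute names sample_reply expects."""

    def __init__(self, vocab_size, d_model=32, max_seq=64):
        super().__init__()
        self.token_emb = torch.nn.Embedding(vocab_size, d_model)
        self.pos_emb = torch.nn.Parameter(torch.zeros(1, max_seq, d_model))
        self.head = torch.nn.Linear(d_model, vocab_size)

    def forward(self, ids):
        x = self.token_emb(ids) + self.pos_emb[:, :ids.size(1), :]
        return self.head(x)


if __name__ == "__main__":
    tokenizer = _ToyTokenizer()
    model = _ToyLM(tokenizer.vocab_size())
    prompt_ids = tokenizer.encode("hello", return_tensors=True)
    # Untrained weights, so the output is noise; this only checks the wiring.
    sample_reply(model, tokenizer, prompt_ids, max_len=10)
    sample_thought(model, tokenizer, prompt_ids.device, "hello", max_len=10)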