import matplotlib.pyplot as plt
import torch
from torch.nn.functional import softmax

loss_log = []


def track_loss(loss):
    # Store a plain float so the log does not keep autograd graphs alive.
    loss_log.append(float(loss))
    if len(loss_log) % 50 == 0:
        plot_loss()


def plot_loss():
    plt.figure()
    plt.plot(loss_log)
    plt.title("Training Loss Over Time")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.savefig("loss_plot.png")
    plt.close()


def update_model_vocab(model, tokenizer):
    """Grow the embedding table and output head if the tokenizer gained new tokens."""
    new_vocab = tokenizer.vocab_size()
    old_vocab = model.token_emb.num_embeddings
    d_model = model.token_emb.embedding_dim
    if new_vocab > old_vocab:
        # Resize token embedding: copy the old rows, randomly initialize the new ones.
        old_weights = model.token_emb.weight.data
        new_emb = torch.nn.Embedding(new_vocab, d_model).to(old_weights.device)
        new_emb.weight.data[:old_vocab] = old_weights
        torch.nn.init.normal_(new_emb.weight.data[old_vocab:], mean=0.0, std=0.02)
        model.token_emb = new_emb

        # Resize output head the same way; new biases start at zero.
        old_head = model.head
        new_head = torch.nn.Linear(d_model, new_vocab).to(old_weights.device)
        new_head.weight.data[:old_vocab] = old_head.weight.data
        new_head.bias.data[:old_vocab] = old_head.bias.data
        torch.nn.init.normal_(new_head.weight.data[old_vocab:], mean=0.0, std=0.02)
        torch.nn.init.zeros_(new_head.bias.data[old_vocab:])
        model.head = new_head


def sample_reply(model, tokenizer, input_ids, max_len=40):
    """Greedily decode up to max_len tokens and return the text after the prompt."""
    model.eval()
    generated = input_ids.clone()
    with torch.no_grad():
        for _ in range(max_len):
            # Truncate input to fit the positional embedding.
            max_seq = model.pos_emb.size(1)
            if generated.size(1) > max_seq:
                generated = generated[:, -max_seq:]
            # Resize embeddings/head in case the tokenizer grew since the last step.
            update_model_vocab(model, tokenizer)
            try:
                logits = model(generated)
            except RuntimeError:
                print("CUDA crash in sample_reply - possible vocab mismatch")
                print("Generated:", generated)
                raise
            next_token_logits = logits[0, -1, :]
            # Greedy decoding: the argmax of the softmax equals the argmax of the logits.
            probs = softmax(next_token_logits, dim=-1)
            next_token = probs.argmax(dim=-1, keepdim=True)
            next_token = next_token.unsqueeze(0)  # Shape: [1, 1]
            generated = torch.cat((generated, next_token), dim=1)
            # Stop early on sentence-ending tokens.
            decoded = tokenizer.decode([next_token.item()])
            if decoded in ['\n', '.', '!', '?']:
                break
    output = generated[0].tolist()[input_ids.shape[1]:]
    reply = tokenizer.decode(output).strip()
    print(f"[Reply] {repr(reply)}")
    return reply


def sample_thought(model, tokenizer, device, context_text, max_len=60):
    prompt = f"[thinking] {context_text}"
    input_ids = tokenizer.encode(prompt, return_tensors=True).to(device)
    return sample_reply(model, tokenizer, input_ids, max_len=max_len)
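

# --- Usage sketch (illustrative only) ---
# A minimal, self-contained smoke test for the helpers above. StubTokenizer and
# StubModel are hypothetical stand-ins, not part of the project: they only mimic
# the interface these functions assume (a tokenizer with vocab_size(),
# encode(..., return_tensors=True) and decode(), and a model exposing token_emb,
# pos_emb, and head).
if __name__ == "__main__":
    class StubTokenizer:
        def __init__(self):
            self.chars = list(" abcdefghijklmnopqrstuvwxyz.!?[]")

        def vocab_size(self):
            return len(self.chars)

        def encode(self, text, return_tensors=False):
            ids = [self.chars.index(c) for c in text.lower() if c in self.chars]
            return torch.tensor([ids], dtype=torch.long) if return_tensors else ids

        def decode(self, ids):
            return "".join(self.chars[i] for i in ids)

    class StubModel(torch.nn.Module):
        def __init__(self, vocab, d_model=32, max_seq=64):
            super().__init__()
            self.token_emb = torch.nn.Embedding(vocab, d_model)
            self.pos_emb = torch.nn.Parameter(torch.zeros(1, max_seq, d_model))
            self.head = torch.nn.Linear(d_model, vocab)

        def forward(self, ids):
            x = self.token_emb(ids) + self.pos_emb[:, : ids.size(1)]
            return self.head(x)

    tok = StubTokenizer()
    model = StubModel(tok.vocab_size())
    track_loss(1.0)  # would normally be the per-step training loss
    sample_thought(model, tok, torch.device("cpu"), "hello there", max_len=5)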