# Suggested Refinements for Jade (Model.py)
# (Original listing: 159 lines, 7.2 KiB, Python.)

import random
import string
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class JadeModel(nn.Module):
    """Character-level, GPT-style Transformer language model ("Jade").

    Tokens are Unicode code points below ``vocab_size`` (0-255 by default,
    i.e. the Latin-1 range); ``tokenize`` drops any other character. The
    model owns its optimizer and loss so that ``train_on_message`` can take a
    single gradient step per call.

    Every hyperparameter defaults to the previously hard-coded value, so
    ``JadeModel()`` builds exactly the same architecture as before (and the
    same ``state_dict`` keys).
    """

    def __init__(
        self,
        vocab_size: int = 256,
        embedding_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
        max_position_embeddings: int = 512,
        lr: float = 0.001,
    ) -> None:
        super().__init__()

        # GPT-like Transformer architecture
        self.vocab_size = vocab_size  # Character-level tokenization (code points < vocab_size)
        self.embedding_dim = embedding_dim  # GPT-like embedding dimension
        self.num_heads = num_heads  # Number of attention heads
        self.num_layers = num_layers  # Number of transformer layers
        self.max_position_embeddings = max_position_embeddings  # Maximum sequence length

        # Embedding layers
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.position_embedding = nn.Embedding(self.max_position_embeddings, self.embedding_dim)

        # Transformer layers. batch_first=True is required: forward() passes
        # (batch, seq, dim) tensors, and the default (seq, batch, dim) layout
        # would make attention run across the batch axis instead of the
        # sequence axis (with batch=1 every character would attend only to
        # itself).
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.embedding_dim,
                nhead=self.num_heads,
                batch_first=True,
            )
            for _ in range(self.num_layers)
        ])

        # Output layer
        self.fc = nn.Linear(self.embedding_dim, self.vocab_size)
        self.softmax = nn.Softmax(dim=-1)

        # Device setup. Parameters are moved before the optimizer is built so
        # the optimizer is guaranteed to reference the on-device tensors.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

        # Optimizer and loss function
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()

        # Debug message to verify changes (updated unique message for each change)
        self.debug_message = "[DEBUG] Model initialized with version: Jade-Solstice-Horizon"
        print(self.debug_message)

    def forward(self, input_ids):
        """Return next-token logits for every position of ``input_ids``.

        Args:
            input_ids: Long tensor of shape ``(batch, seq_len)`` with
                ``seq_len <= max_position_embeddings``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``. A causal mask
            lets position ``t`` attend only to positions ``<= t``, so its
            logits are a prediction for token ``t + 1``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_position_embeddings``.
        """
        seq_length = input_ids.size(1)
        if seq_length > self.max_position_embeddings:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_position_embeddings "
                f"({self.max_position_embeddings})"
            )
        device = input_ids.device

        # Create position ids for input sequence
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Embedding lookup
        x = self.embedding(input_ids) + self.position_embedding(position_ids)

        # Causal mask: True marks (query, key) pairs that may NOT attend,
        # i.e. every key that lies in the query's future.
        causal_mask = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=device),
            diagonal=1,
        )

        # Pass through transformer layers
        for layer in self.transformer_layers:
            x = layer(x, src_mask=causal_mask)

        # Output layer
        return self.fc(x)

    def generate_response(self, input_text, initial_temperature=0.85, top_p=0.8,
                          repetition_penalty=1.4, max_token_frequency=2,
                          max_new_tokens=50):
        """Continue ``input_text`` by sampling up to ``max_new_tokens`` characters.

        Each step scales the logits by the temperature and applies a
        frequency-scaled repetition penalty. It then bans tokens that already
        occur ``max_token_frequency`` times in the sequence (prompt included),
        and also bans the immediately preceding token. Finally it samples from
        the top-p nucleus. The temperature decays by 2% per step with a floor
        of 0.75.

        Args:
            input_text: Prompt; characters outside the vocabulary are dropped.
            initial_temperature: Starting sampling temperature.
            top_p: Nucleus mass. Sampling uses the smallest set of most likely
                tokens whose cumulative probability reaches ``top_p``; the
                single most likely token is always included.
            repetition_penalty: Base penalty for tokens seen more than once.
            max_token_frequency: A token already present this many times is
                never generated again.
            max_new_tokens: Maximum number of generated characters (default 50,
                the previously hard-coded value).

        Returns:
            The tokenized prompt followed by the generated characters, or
            ``""`` when the prompt has no in-vocabulary characters. Generation
            ends early only if every token in the vocabulary is banned.
        """
        generated_tokens = self.tokenize(input_text)
        if not generated_tokens:
            # forward() needs at least one position to predict from.
            return ""

        token_counts = Counter(generated_tokens)
        temperature = initial_temperature

        # Sample with dropout disabled, then restore the caller's mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Only the most recent tokens fit the positional embeddings.
                    context = generated_tokens[-self.max_position_embeddings:]
                    input_tensor = torch.tensor([context], dtype=torch.long, device=self.device)
                    logits = self.forward(input_tensor)[0, -1, :]  # last position only
                    logits = logits / (temperature + 1e-2)  # +1e-2 guards against a zero temperature

                    self._apply_repetition_penalty(logits, token_counts, repetition_penalty)

                    # Hard constraints are applied as masks *before* sampling,
                    # so no iteration is wasted on a token that would be rejected.
                    banned = [tok for tok, cnt in token_counts.items() if cnt >= max_token_frequency]
                    banned.append(generated_tokens[-1])  # never repeat the previous token
                    logits[banned] = float("-inf")
                    if torch.isinf(logits).all():
                        break  # every token is banned; nothing left to sample

                    sampled_token = self._sample_top_p(logits, top_p)
                    generated_tokens.append(sampled_token)
                    token_counts[sampled_token] += 1

                    # Gradually decrease temperature to reduce randomness more smoothly
                    temperature = max(0.75, temperature * 0.98)
        finally:
            self.train(was_training)

        response = self.detokenize(generated_tokens)
        print("[DEBUG] Generated response:", response)  # Debug statement to verify changes
        print(f"[DEBUG] Final sampling temperature: {temperature}")
        return response

    @staticmethod
    def _apply_repetition_penalty(logits, token_counts, repetition_penalty):
        """Penalize, in place, every token seen more than once, scaled by its count.

        Positive logits are divided by the penalty and negative logits are
        multiplied by it. For a penalty > 1 this lowers the token's
        probability in both cases; dividing a negative logit would instead
        raise it.

        Args:
            logits: 1-D tensor of next-token logits (modified in place).
            token_counts: Mapping token id -> occurrences so far.
            repetition_penalty: Base penalty; each occurrence adds 0.02.
        """
        repeated = [tok for tok, cnt in token_counts.items() if cnt > 1]
        if not repeated:
            return
        index = torch.tensor(repeated, dtype=torch.long, device=logits.device)
        counts = torch.tensor(
            [token_counts[tok] for tok in repeated], dtype=logits.dtype, device=logits.device
        )
        penalty = repetition_penalty + counts * 0.02  # Frequency-based scaling for penalty
        selected = logits[index]
        logits[index] = torch.where(selected > 0, selected / penalty, selected * penalty)

    def _sample_top_p(self, logits, top_p):
        """Draw one token id from 1-D ``logits`` using nucleus (top-p) sampling.

        The nucleus is the smallest prefix of tokens, sorted by probability,
        whose cumulative mass reaches ``top_p``. The most likely token is
        always kept, so a nucleus is never empty.
        """
        probs = self.softmax(logits)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        # Mass strictly before each token; keeping "< top_p" includes the
        # token that crosses the threshold.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        keep = mass_before < top_p
        keep[0] = True
        nucleus_probs = sorted_probs[keep]
        choice = torch.multinomial(nucleus_probs, num_samples=1).item()
        return sorted_indices[keep][choice].item()

    def tokenize(self, text):
        """Convert ``text`` to a list of code points, dropping those >= ``vocab_size``."""
        return [code for code in map(ord, text) if code < self.vocab_size]

    def detokenize(self, token_ids):
        """Convert token ids (code points) back into a string."""
        return "".join(map(chr, token_ids))

    def train_on_message(self, message):
        """Take one optimizer step of next-character prediction on ``message``.

        Inputs are ``ids[:-1]`` and targets are ``ids[1:]``, so every position
        is trained to predict the character that actually follows it. Only the
        first ``max_position_embeddings + 1`` characters are used.

        Returns:
            The training loss as a float, or ``None`` if the message has fewer
            than two in-vocabulary characters (nothing to predict).
        """
        token_ids = self.tokenize(message)
        # One extra token because inputs and targets are offset by one.
        token_ids = token_ids[: self.max_position_embeddings + 1]
        if len(token_ids) < 2:
            return None

        input_tensor = torch.tensor([token_ids[:-1]], dtype=torch.long, device=self.device)
        labels_tensor = torch.tensor([token_ids[1:]], dtype=torch.long, device=self.device)

        # Training step
        self.optimizer.zero_grad()
        outputs = self.forward(input_tensor)
        loss = self.criterion(outputs.reshape(-1, outputs.size(-1)), labels_tensor.reshape(-1))
        loss.backward()
        self.optimizer.step()

        loss_value = loss.item()
        print(f"Training loss: {loss_value}")
        return loss_value


# Changes made (version Jade-Solstice-Horizon):
# - Reverted the temperature, top_p, and repetition-penalty settings to values closer to the Solstice version.
# - Added an explicit stop criterion that prevents the same token from being generated twice in a row.
# - Applied slight smoothing to the logits to dampen sharp peaks and excessive repetition.
#   (NOTE(review): subtracting the same constant from every logit leaves the softmax unchanged,
#   so this step does not actually alter sampling.)
# - Updated the debug message to reflect the new version.
#
# Observations:
# - The goal is to keep the strengths of Solstice while reducing the remaining problems with
#   repetitive tokens, by adding specific repetition stop criteria.