Got Jade to exactly copy without extra characters - Version Solstice-Horizon
This commit is contained in:
parent
d7f116620a
commit
e0ea105872
145
main.py
145
main.py
@ -1,123 +1,40 @@
|
|||||||
|
# main.py: Discord Bot Code
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import torch
|
import torch
|
||||||
from model import SimpleTokenizer, initialize_model, train_on_conversation, save_model, update_model_vocab
|
from model import JadeModel
|
||||||
import torch.nn.functional as F
|
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
class DiscordBot(discord.Client):
|
|
||||||
def __init__(self, **options):
|
|
||||||
super().__init__(**options)
|
|
||||||
self.tokenizer = SimpleTokenizer()
|
|
||||||
self.tokenizer_vocab_path = 'tokenizer_vocab.json'
|
|
||||||
self.tokenizer.load_vocab(self.tokenizer_vocab_path)
|
|
||||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
||||||
self.model, self.optimizer, self.criterion = initialize_model(self.tokenizer, self.device)
|
|
||||||
self.conversation_history = [] # Keep track of conversations for learning
|
|
||||||
self.previous_reply = None # Store last reply for pattern recognition
|
|
||||||
|
|
||||||
async def on_ready(self):
|
|
||||||
print(f'Logged in as {self.user.name}')
|
|
||||||
|
|
||||||
async def on_message(self, message):
|
|
||||||
if message.author == self.user:
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"Received message from {message.author}: {message.content}")
|
|
||||||
|
|
||||||
# Update tokenizer vocabulary with the new message
|
|
||||||
previous_vocab_size = len(self.tokenizer.token2idx)
|
|
||||||
self.tokenizer.build_vocab([message.content])
|
|
||||||
new_vocab_size = len(self.tokenizer.token2idx)
|
|
||||||
|
|
||||||
# Update model if vocabulary has changed
|
|
||||||
if new_vocab_size != previous_vocab_size:
|
|
||||||
self.model = update_model_vocab(self.model, self.tokenizer, self.device)
|
|
||||||
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
|
|
||||||
print("Model vocabulary updated.")
|
|
||||||
|
|
||||||
# Generate a reply
|
|
||||||
self.model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
reply = self.generate_reply(message.content)
|
|
||||||
print(f"Sending reply: {reply}")
|
|
||||||
await message.channel.send(reply)
|
|
||||||
|
|
||||||
# Append conversation to history for future learning
|
|
||||||
self.conversation_history.append({
|
|
||||||
"user_message": message.content,
|
|
||||||
"bot_reply": reply,
|
|
||||||
"channel": message.channel
|
|
||||||
})
|
|
||||||
|
|
||||||
# Continuous learning: Train on this conversation pair
|
|
||||||
loss = train_on_conversation(
|
|
||||||
self.model,
|
|
||||||
self.optimizer,
|
|
||||||
self.criterion,
|
|
||||||
self.tokenizer,
|
|
||||||
message.content,
|
|
||||||
reply,
|
|
||||||
self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save the model and tokenizer for future sessions
|
|
||||||
save_model(self.model)
|
|
||||||
self.tokenizer.save_vocab(self.tokenizer_vocab_path)
|
|
||||||
|
|
||||||
# Store this reply to help Jade learn from repetition in future responses
|
|
||||||
self.previous_reply = reply
|
|
||||||
|
|
||||||
def generate_reply(self, input_text, max_length=20, temperature=1.0, top_k=10):
|
|
||||||
# Prepare the input sequence with special tokens
|
|
||||||
input_sequence = ['<SOS>'] + input_text.split() + ['<EOS>']
|
|
||||||
input_indices = self.tokenizer.encode(' '.join(input_sequence))
|
|
||||||
input_tensor = torch.tensor([input_indices], dtype=torch.long, device=self.device)
|
|
||||||
|
|
||||||
generated_indices = []
|
|
||||||
for _ in range(max_length):
|
|
||||||
output = self.model(input_tensor)
|
|
||||||
if output.size(0) == 0:
|
|
||||||
print("Model output is empty. Breaking out of generation loop.")
|
|
||||||
break
|
|
||||||
next_token_logits = output[-1, 0, :] / temperature
|
|
||||||
|
|
||||||
# Penalize <UNK>
|
|
||||||
unk_token_idx = self.tokenizer.token2idx.get('<UNK>', None)
|
|
||||||
if unk_token_idx is not None:
|
|
||||||
next_token_logits[unk_token_idx] = -float('inf')
|
|
||||||
|
|
||||||
# Apply Top-K sampling
|
|
||||||
top_k = min(top_k, next_token_logits.size(-1))
|
|
||||||
values, indices = torch.topk(next_token_logits, top_k)
|
|
||||||
probabilities = F.softmax(values, dim=-1)
|
|
||||||
predicted_index = indices[torch.multinomial(probabilities, 1)].item()
|
|
||||||
|
|
||||||
# Stop if <EOS> token is generated
|
|
||||||
if predicted_index == self.tokenizer.token2idx.get('<EOS>'):
|
|
||||||
break
|
|
||||||
|
|
||||||
generated_indices.append(predicted_index)
|
|
||||||
input_indices.append(predicted_index)
|
|
||||||
input_tensor = torch.tensor([input_indices], dtype=torch.long, device=self.device)
|
|
||||||
|
|
||||||
# Filter out special tokens from generated indices
|
|
||||||
special_token_indices = set(self.tokenizer.token2idx[token] for token in ['<PAD>', '<UNK>', '<SOS>', '<EOS>'])
|
|
||||||
filtered_indices = [idx for idx in generated_indices if idx not in special_token_indices]
|
|
||||||
|
|
||||||
# Decode the filtered indices
|
|
||||||
reply = self.tokenizer.decode(filtered_indices)
|
|
||||||
return reply
|
|
||||||
|
|
||||||
|
|
||||||
DISCORD_TOKEN = os.getenv('DISCORD_TOKEN')
|
|
||||||
|
|
||||||
# Initialize and run the Discord bot
|
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
|
intents.messages = True
|
||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
bot = DiscordBot(intents=intents)
|
client = discord.Client(intents=intents)
|
||||||
bot.run(DISCORD_TOKEN)
|
|
||||||
|
# Initialize the model
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model = JadeModel().to(device)
|
||||||
|
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_ready():
|
||||||
|
print(f'We have logged in as {client.user}')
|
||||||
|
|
||||||
|
|
||||||
|
@client.event
|
||||||
|
async def on_message(message):
|
||||||
|
if message.author == client.user:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Train Jade with the new message
|
||||||
|
model.train_on_message(message.content)
|
||||||
|
|
||||||
|
# Generate a response using Jade
|
||||||
|
response = model.generate_response(message.content)
|
||||||
|
await message.channel.send(response)
|
||||||
|
|
||||||
|
# Start the bot with your token
|
||||||
|
client.run(os.getenv('DISCORD_TOKEN'))
|
||||||
|
277
model.py
277
model.py
@ -1,183 +1,158 @@
|
|||||||
|
# Suggested Refinements for Jade (Model.py)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import threading
|
import torch.optim as optim
|
||||||
import os
|
import random
|
||||||
import json
|
import string
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class JadeModel(nn.Module):
|
||||||
# Simple Tokenizer
|
|
||||||
class SimpleTokenizer:
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.token2idx = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
|
super(JadeModel, self).__init__()
|
||||||
self.idx2token = {idx: token for token, idx in self.token2idx.items()}
|
# GPT-like Transformer architecture
|
||||||
self.lock = threading.Lock()
|
self.vocab_size = 256 # Character-level tokenization (ASCII range)
|
||||||
|
self.embedding_dim = 768 # GPT-like embedding dimension
|
||||||
|
self.num_heads = 12 # Number of attention heads
|
||||||
|
self.num_layers = 12 # Number of transformer layers
|
||||||
|
self.max_position_embeddings = 512 # Maximum sequence length
|
||||||
|
|
||||||
def build_vocab(self, texts):
|
# Embedding layers
|
||||||
with self.lock:
|
self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
|
||||||
for text in texts:
|
self.position_embedding = nn.Embedding(self.max_position_embeddings, self.embedding_dim)
|
||||||
tokens = text.split()
|
|
||||||
for token in tokens:
|
|
||||||
if token not in self.token2idx:
|
|
||||||
idx = len(self.token2idx)
|
|
||||||
self.token2idx[token] = idx
|
|
||||||
self.idx2token[idx] = token
|
|
||||||
|
|
||||||
def encode(self, text):
|
# Transformer layers
|
||||||
with self.lock:
|
self.transformer_layers = nn.ModuleList([
|
||||||
return [self.token2idx.get(token, self.token2idx['<UNK>']) for token in text.split()]
|
nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.num_heads)
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
def decode(self, indices):
|
# Output layer
|
||||||
with self.lock:
|
self.fc = nn.Linear(self.embedding_dim, self.vocab_size)
|
||||||
return ' '.join([self.idx2token.get(idx, '<UNK>') for idx in indices])
|
self.softmax = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
def save_vocab(self, path):
|
# Optimizer and loss function
|
||||||
with open(path, 'w') as f:
|
self.optimizer = optim.Adam(self.parameters(), lr=0.001)
|
||||||
json.dump({'token2idx': self.token2idx, 'idx2token': self.idx2token}, f)
|
self.criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
def load_vocab(self, path):
|
# Device setup
|
||||||
if os.path.exists(path):
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
with open(path, 'r') as f:
|
self.to(self.device)
|
||||||
vocab = json.load(f)
|
|
||||||
self.token2idx = vocab['token2idx']
|
|
||||||
self.idx2token = {int(k): v for k, v in vocab['idx2token'].items()}
|
|
||||||
print('Tokenizer vocabulary loaded from', path)
|
|
||||||
# Ensure special tokens are present
|
|
||||||
special_tokens = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
|
|
||||||
for token, idx in special_tokens.items():
|
|
||||||
if token not in self.token2idx:
|
|
||||||
self.token2idx[token] = idx
|
|
||||||
self.idx2token[idx] = token
|
|
||||||
else:
|
|
||||||
print('No existing tokenizer vocabulary found. Starting fresh.')
|
|
||||||
self.token2idx = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
|
|
||||||
self.idx2token = {idx: token for token, idx in self.token2idx.items()}
|
|
||||||
|
|
||||||
|
# Debug message to verify changes (updated unique message for each change)
|
||||||
|
self.debug_message = "[DEBUG] Model initialized with version: Jade-Solstice-Horizon"
|
||||||
|
print(self.debug_message)
|
||||||
|
|
||||||
# Positional Encoding
|
def forward(self, input_ids):
|
||||||
class PositionalEncoding(nn.Module):
|
# Create position ids for input sequence
|
||||||
def __init__(self, d_model, max_len=5000):
|
seq_length = input_ids.size(1)
|
||||||
super(PositionalEncoding, self).__init__()
|
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=self.device)
|
||||||
pe = torch.zeros(max_len, d_model)
|
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||||
position = torch.arange(0, max_len).unsqueeze(1).float()
|
|
||||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
|
|
||||||
pe[:, 0::2] = torch.sin(position * div_term)
|
|
||||||
if d_model % 2 == 1:
|
|
||||||
pe[:, -1] = torch.cos(position.squeeze() * div_term[-1])
|
|
||||||
else:
|
|
||||||
pe[:, 1::2] = torch.cos(position * div_term)
|
|
||||||
pe = pe.unsqueeze(1) # Shape: [max_len, 1, d_model]
|
|
||||||
self.register_buffer('pe', pe)
|
|
||||||
|
|
||||||
def forward(self, x):
|
# Embedding lookup
|
||||||
# x: [seq_len, batch_size, d_model]
|
x = self.embedding(input_ids) + self.position_embedding(position_ids)
|
||||||
x = x + self.pe[:x.size(0)]
|
|
||||||
|
# Pass through transformer layers
|
||||||
|
for layer in self.transformer_layers:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
# Output layer
|
||||||
|
x = self.fc(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def generate_response(self, input_text, initial_temperature=0.85, top_p=0.8, repetition_penalty=1.4, max_token_frequency=2):
|
||||||
|
# Convert input_text to token ids
|
||||||
|
input_ids = self.tokenize(input_text)
|
||||||
|
input_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)
|
||||||
|
generated_tokens = input_ids.copy()
|
||||||
|
recent_tokens = list(input_ids[-10:]) # Expanded recent tokens window to 10
|
||||||
|
temperature = initial_temperature
|
||||||
|
|
||||||
# GPT Model
|
with torch.no_grad():
|
||||||
class GPTModel(nn.Module):
|
for i in range(50): # Generate up to 50 more tokens
|
||||||
def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2):
|
output = self.forward(input_tensor)
|
||||||
super(GPTModel, self).__init__()
|
logits = output[:, -1, :] # Consider only the last token's logits
|
||||||
self.model_type = 'Transformer'
|
logits = logits / (temperature + 1e-2) # Apply temperature for sampling diversity
|
||||||
self.d_model = d_model
|
|
||||||
self.embedding = nn.Embedding(vocab_size, d_model)
|
|
||||||
self.pos_encoder = PositionalEncoding(d_model)
|
|
||||||
encoder_layers = nn.TransformerEncoderLayer(d_model, nhead)
|
|
||||||
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
|
|
||||||
self.fc_out = nn.Linear(d_model, vocab_size)
|
|
||||||
self.src_mask = None
|
|
||||||
|
|
||||||
def _generate_square_subsequent_mask(self, sz):
|
# Apply repetition penalty
|
||||||
mask = torch.triu(torch.ones(sz, sz) == 1).transpose(0, 1)
|
for token in set(generated_tokens):
|
||||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
if generated_tokens.count(token) > 1:
|
||||||
return mask
|
logits[0, token] /= (repetition_penalty + generated_tokens.count(token) * 0.02) # Frequency-based scaling for penalty
|
||||||
|
|
||||||
def forward(self, src):
|
# Apply slight logits smoothing to avoid overly confident peaks
|
||||||
# src: [seq_len, batch_size]
|
logits = logits - torch.mean(logits) * 0.01
|
||||||
src = src.transpose(0, 1) # Shape: [seq_len, batch_size]
|
|
||||||
src = self.embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
|
|
||||||
src = self.pos_encoder(src)
|
|
||||||
if self.src_mask is None or self.src_mask.size(0) != src.size(0):
|
|
||||||
device = src.device
|
|
||||||
self.src_mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
|
|
||||||
output = self.transformer_encoder(src, self.src_mask)
|
|
||||||
logits = self.fc_out(output)
|
|
||||||
return logits # Shape: [seq_len, batch_size, vocab_size]
|
|
||||||
|
|
||||||
|
# Dynamic Nucleus (top-p) sampling with adjusted threshold
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
cumulative_probs = torch.cumsum(self.softmax(sorted_logits), dim=-1)
|
||||||
|
top_p_mask = cumulative_probs < top_p
|
||||||
|
top_p_logits = sorted_logits[top_p_mask]
|
||||||
|
top_p_indices = sorted_indices[top_p_mask]
|
||||||
|
|
||||||
# Training function
|
if len(top_p_logits) > 1:
|
||||||
def train_step(model, optimizer, criterion, input_tensor, target_tensor):
|
top_p_probs = self.softmax(top_p_logits)
|
||||||
model.train()
|
sampled_token = top_p_indices[torch.multinomial(top_p_probs, num_samples=1).item()].item()
|
||||||
optimizer.zero_grad()
|
else:
|
||||||
output = model(input_tensor) # [seq_len, batch_size, vocab_size]
|
sampled_token = sorted_indices[0, 0].item() # Fallback to the most probable token if none pass the top-p threshold
|
||||||
output = output.view(-1, output.size(-1)) # [seq_len * batch_size, vocab_size]
|
|
||||||
target = target_tensor.transpose(0,1).contiguous().view(-1) # [seq_len * batch_size]
|
|
||||||
loss = criterion(output, target)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
print(f'Training loss: {loss.item():.4f}')
|
|
||||||
return loss.item()
|
|
||||||
|
|
||||||
|
# Enforce diversity constraint by limiting token frequency
|
||||||
|
if generated_tokens.count(sampled_token) >= max_token_frequency:
|
||||||
|
logits[0, sampled_token] -= 1.5 # Adjusted penalty to limit token frequency
|
||||||
|
continue # Skip adding this token if it has reached the max frequency
|
||||||
|
|
||||||
def initialize_model(tokenizer, device):
|
# Stop repetition if the sampled token was recently repeated
|
||||||
vocab_size = len(tokenizer.token2idx)
|
if len(generated_tokens) > 1 and generated_tokens[-1] == sampled_token:
|
||||||
model = GPTModel(vocab_size=vocab_size).to(device)
|
continue
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
|
||||||
criterion = nn.CrossEntropyLoss(ignore_index=0)
|
|
||||||
model_path = 'gpt_model.pth'
|
|
||||||
|
|
||||||
# Load existing model if available
|
# Add token and update state
|
||||||
if os.path.exists(model_path):
|
generated_tokens.append(sampled_token)
|
||||||
model.load_state_dict(torch.load(model_path, map_location=device))
|
recent_tokens.append(sampled_token)
|
||||||
print('Model loaded from', model_path)
|
if len(recent_tokens) > 10:
|
||||||
else:
|
recent_tokens.pop(0) # Maintain a window of recent tokens to suppress
|
||||||
print('No existing model found. Starting fresh.')
|
|
||||||
return model, optimizer, criterion
|
|
||||||
|
|
||||||
|
# Update input tensor to include the generated token
|
||||||
|
input_tensor = torch.tensor(generated_tokens).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
def save_model(model):
|
# Gradually decrease temperature to reduce randomness more smoothly
|
||||||
model_path = 'gpt_model.pth'
|
temperature = max(0.75, temperature * 0.98)
|
||||||
torch.save(model.state_dict(), model_path)
|
|
||||||
|
|
||||||
|
response = self.detokenize(generated_tokens)
|
||||||
|
print("[DEBUG] Generated response:", response) # Debug statement to verify changes
|
||||||
|
print(f"[DEBUG] Generation loss rate (approximated): {temperature}") # Approximate loss rate
|
||||||
|
return response
|
||||||
|
|
||||||
def update_model_vocab(model, tokenizer, device):
|
def tokenize(self, text):
|
||||||
vocab_size = len(tokenizer.token2idx)
|
# Character-level tokenizer: converts text to ASCII values
|
||||||
|
token_ids = [ord(char) for char in text if ord(char) < self.vocab_size]
|
||||||
|
return token_ids
|
||||||
|
|
||||||
old_embedding_weight = model.embedding.weight.data
|
def detokenize(self, token_ids):
|
||||||
old_vocab_size, embedding_dim = old_embedding_weight.shape
|
# Detokenizer to convert ASCII values back to characters
|
||||||
new_embedding = nn.Embedding(vocab_size, model.d_model).to(device)
|
return "".join([chr(id) for id in token_ids])
|
||||||
new_embedding.weight.data[:old_vocab_size] = old_embedding_weight
|
|
||||||
model.embedding = new_embedding
|
|
||||||
|
|
||||||
old_fc_out_weight = model.fc_out.weight.data
|
def train_on_message(self, message):
|
||||||
old_fc_out_bias = model.fc_out.bias.data
|
# Tokenize the message
|
||||||
new_fc_out = nn.Linear(model.d_model, vocab_size).to(device)
|
input_ids = self.tokenize(message)
|
||||||
new_fc_out.weight.data[:old_vocab_size] = old_fc_out_weight
|
input_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)
|
||||||
new_fc_out.bias.data[:old_vocab_size] = old_fc_out_bias
|
|
||||||
model.fc_out = new_fc_out
|
|
||||||
|
|
||||||
return model
|
# Create target labels (next character prediction task)
|
||||||
|
labels = input_ids[1:] + [input_ids[-1]] # Shift tokens for training
|
||||||
|
labels_tensor = torch.tensor(labels).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
# Training step
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
outputs = self.forward(input_tensor)
|
||||||
|
loss = self.criterion(outputs.view(-1, outputs.size(-1)), labels_tensor.view(-1))
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
print(f"Training loss: {loss.item()}")
|
||||||
|
|
||||||
def train_on_conversation(model, optimizer, criterion, tokenizer, input_text, target_text, device):
|
# Changes made:
|
||||||
tokenizer.build_vocab([input_text, target_text])
|
# Version: Jade-Solstice-Horizon
|
||||||
input_indices = tokenizer.encode(input_text)
|
# - Reverted temperature, top_p, and repetition penalty settings to be closer to Solstice.
|
||||||
target_indices = tokenizer.encode(target_text)
|
# - Introduced explicit stop criteria to prevent repeating tokens consecutively.
|
||||||
|
# - Applied slight smoothing to logits to prevent high peaks and excessive repetition.
|
||||||
|
# - Updated debug message to reflect the new version.
|
||||||
|
|
||||||
# Concatenate input and target indices to create a single sequence
|
# Observations:
|
||||||
full_indices = input_indices + target_indices
|
# - Aimed to retain the strengths of Solstice while reducing remaining issues with repetitive tokens by adding specific repetition stop criteria.
|
||||||
|
|
||||||
# Create input and target sequences for training
|
|
||||||
input_sequence = full_indices[:-1] # All tokens except the last
|
|
||||||
target_sequence = full_indices[1:] # All tokens except the first
|
|
||||||
|
|
||||||
# Update model if vocabulary has changed
|
|
||||||
if len(tokenizer.token2idx) != model.embedding.num_embeddings:
|
|
||||||
model = update_model_vocab(model, tokenizer, device)
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
|
||||||
|
|
||||||
input_tensor = torch.tensor([input_sequence], dtype=torch.long, device=device)
|
|
||||||
target_tensor = torch.tensor([target_sequence], dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
loss = train_step(model, optimizer, criterion, input_tensor, target_tensor)
|
|
||||||
return loss
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user