124 lines
4.7 KiB
Python
124 lines
4.7 KiB
Python
import discord
|
|
import torch
|
|
from model import SimpleTokenizer, initialize_model, train_on_conversation, save_model, update_model_vocab
|
|
import torch.nn.functional as F
|
|
import os
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
# Load variables from a local .env file (python-dotenv) into the process
# environment so that os.getenv('DISCORD_TOKEN') further down can read them.
load_dotenv()
|
|
|
|
|
|
class DiscordBot(discord.Client):
    """Discord client that answers messages with a generative model and learns online.

    For every incoming message (other than its own) the bot:

    1. extends the tokenizer vocabulary with the message text and, if the
       vocabulary grew, rebuilds the model and optimizer;
    2. samples a reply from the model and sends it to the channel;
    3. trains on the (message, reply) pair;
    4. saves the model and the tokenizer vocabulary.
    """

    # Maximum number of entries kept in ``conversation_history``. Each entry
    # holds a reference to a channel object, so an unbounded list would grow
    # for the whole lifetime of the process.
    MAX_HISTORY = 1000

    def __init__(self, **options):
        """Create the client and load the tokenizer, model, optimizer and loss.

        Args:
            **options: Forwarded unchanged to ``discord.Client.__init__``
                (for example ``intents=...``).
        """
        super().__init__(**options)
        self.tokenizer = SimpleTokenizer()
        self.tokenizer_vocab_path = 'tokenizer_vocab.json'
        self.tokenizer.load_vocab(self.tokenizer_vocab_path)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model, self.optimizer, self.criterion = initialize_model(self.tokenizer, self.device)
        self.conversation_history = []  # Keep track of conversations for learning
        self.previous_reply = None  # Store last reply for pattern recognition

    async def on_ready(self):
        """Print the account name once the client has logged in."""
        print(f'Logged in as {self.user.name}')

    async def on_message(self, message):
        """Reply to *message*, train on the exchange, and persist the state.

        Messages written by the bot itself are ignored. If the model produces
        an empty reply, nothing is sent and no training step is run (sending
        an empty message is rejected by Discord), but the model and the
        vocabulary are still saved because the vocabulary may have grown.

        Args:
            message: The ``discord.Message`` that was received.
        """
        if message.author == self.user:
            return

        print(f"Received message from {message.author}: {message.content}")

        # Update tokenizer vocabulary with the new message.
        previous_vocab_size = len(self.tokenizer.token2idx)
        self.tokenizer.build_vocab([message.content])

        if len(self.tokenizer.token2idx) != previous_vocab_size:
            # The model object is replaced, so the optimizer has to be rebuilt
            # to track the new model's parameters.
            # NOTE(review): lr is hard-coded here; confirm it matches the one
            # used by initialize_model.
            self.model = update_model_vocab(self.model, self.tokenizer, self.device)
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
            print("Model vocabulary updated.")

        # Generate a reply (generate_reply itself runs without gradients).
        self.model.eval()
        reply = self.generate_reply(message.content)

        if reply and str(reply).strip():
            print(f"Sending reply: {reply}")
            await message.channel.send(reply)

            # Append conversation to history for future learning, keeping
            # only the most recent MAX_HISTORY entries.
            self.conversation_history.append({
                "user_message": message.content,
                "bot_reply": reply,
                "channel": message.channel,
            })
            if len(self.conversation_history) > self.MAX_HISTORY:
                del self.conversation_history[:-self.MAX_HISTORY]

            # Continuous learning: train on this conversation pair.
            # NOTE(review): this call (and the saves below) is not awaited, so
            # it appears to run synchronously inside the event handler and
            # blocks the event loop while it executes.
            train_on_conversation(
                self.model,
                self.optimizer,
                self.criterion,
                self.tokenizer,
                message.content,
                reply,
                self.device,
            )

            # Store this reply to help Jade learn from repetition in future responses.
            self.previous_reply = reply
        else:
            print("Generated reply is empty. Skipping send and training.")

        # Save the model and tokenizer for future sessions.
        save_model(self.model)
        self.tokenizer.save_vocab(self.tokenizer_vocab_path)

    @torch.no_grad()
    def generate_reply(self, input_text, max_length=20, temperature=1.0, top_k=10):
        """Sample a reply to *input_text* token by token with top-k sampling.

        The prompt is wrapped in ``<SOS>`` / ``<EOS>``, encoded, and fed to the
        model. At each step the logits selected by ``output[-1, 0, :]`` are
        scaled by *temperature*, ``<UNK>`` is masked out, and the next token is
        drawn from the softmax over the *top_k* highest logits. Generation
        stops at ``<EOS>``, after *max_length* tokens, or if the model returns
        an empty output.

        Args:
            input_text: The user message to respond to.
            max_length: Maximum number of tokens to generate.
            temperature: Divisor applied to the logits; must be positive.
            top_k: Number of highest-scoring tokens to sample from; must be
                at least 1. Clamped to the size of the logits vector.

        Returns:
            The decoded reply with special tokens removed. May be empty.

        Raises:
            ValueError: If *temperature* is not positive or *top_k* is below 1.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        if top_k < 1:
            raise ValueError("top_k must be at least 1")

        # Prepare the input sequence with special tokens.
        input_sequence = ['<SOS>'] + input_text.split() + ['<EOS>']
        input_indices = self.tokenizer.encode(' '.join(input_sequence))
        input_tensor = torch.tensor([input_indices], dtype=torch.long, device=self.device)

        # These lookups do not change during generation, so do them once.
        token2idx = self.tokenizer.token2idx
        unk_token_idx = token2idx.get('<UNK>')
        eos_token_idx = token2idx.get('<EOS>')

        generated_indices = []
        for _ in range(max_length):
            output = self.model(input_tensor)
            if output.size(0) == 0:
                print("Model output is empty. Breaking out of generation loop.")
                break

            # NOTE(review): the input is built as (1, seq_len) while the output
            # is indexed as [-1, 0, :]; confirm this matches the model's
            # expected tensor layout.
            next_token_logits = output[-1, 0, :] / temperature

            # Penalize <UNK> so it is never sampled.
            if unk_token_idx is not None:
                next_token_logits[unk_token_idx] = -float('inf')

            # Apply top-k sampling; k cannot exceed the number of logits.
            k = min(top_k, next_token_logits.size(-1))
            values, indices = torch.topk(next_token_logits, k)
            probabilities = F.softmax(values, dim=-1)
            predicted_index = indices[torch.multinomial(probabilities, 1)].item()

            # Stop if <EOS> token is generated.
            if predicted_index == eos_token_idx:
                break

            generated_indices.append(predicted_index)
            input_indices.append(predicted_index)
            input_tensor = torch.tensor([input_indices], dtype=torch.long, device=self.device)

        # Filter out special tokens; tokens absent from the vocabulary are
        # skipped instead of raising KeyError.
        special_token_indices = {
            token2idx[token]
            for token in ('<PAD>', '<UNK>', '<SOS>', '<EOS>')
            if token in token2idx
        }
        filtered_indices = [idx for idx in generated_indices if idx not in special_token_indices]

        # Decode the filtered indices.
        return self.tokenizer.decode(filtered_indices)
|
|
|
|
|
|
# Bot token, read from the process environment.
DISCORD_TOKEN = os.getenv('DISCORD_TOKEN')

# Fail fast with a clear message instead of constructing the bot (which loads
# the tokenizer and model) and only then failing inside bot.run(None).
if not DISCORD_TOKEN:
    raise SystemExit("DISCORD_TOKEN is not set. Add it to the environment or a .env file.")

# Initialize and run the Discord bot.
intents = discord.Intents.default()
intents.message_content = True  # The bot reads message.content in on_message.
bot = DiscordBot(intents=intents)
bot.run(DISCORD_TOKEN)
|