import torch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from memory_buffer import MemoryBuffer
from model import TinyGPT
from tokenizer import simple_tokenizer, detokenizer, load_vocab


class ModelManager:
    def __init__(self, use_custom_model=True, ollama_url=None):
        self.use_custom_model = use_custom_model
        self.ollama_url = ollama_url
        self.memory = MemoryBuffer(capacity=10)  # Keep the 10 most recent interactions
        if self.use_custom_model:
            self._load_custom_model()

    def _load_custom_model(self):
        """Load the custom GPT model, its optimizer, and the loss function."""
        self.vocab = load_vocab()
        # Hyperparameters must match those used when ruby_model.pth was trained
        self.model = TinyGPT(vocab_size=len(self.vocab), embed_size=32, num_heads=2, num_layers=2).cuda()
        self.model.load_state_dict(torch.load("ruby_model.pth", weights_only=True))
        self.model.eval()
        self.optimizer = Adam(self.model.parameters(), lr=0.0001)
        self.criterion = CrossEntropyLoss()

    def query_custom_model(self, input_text):
        """Generate a response using the custom GPT model."""
        tokens = torch.tensor(simple_tokenizer(input_text, self.vocab), dtype=torch.long).cuda()
        with torch.no_grad():
            # TinyGPT takes (src, tgt); at inference the prompt serves as both
            output = self.model(tokens.unsqueeze(0), tokens.unsqueeze(0))
        # Most likely token at the final position; indexing [0, -1] also handles
        # single-token prompts, where squeeze()[-1] would raise an IndexError
        predicted_idx = output[0, -1].argmax().item()
        return detokenizer([predicted_idx], self.vocab)

    def train_on_interaction(self, user_input, bot_response):
        """Train the model on a single interaction."""
        self.model.train()
        input_tokens = torch.tensor(simple_tokenizer(user_input, self.vocab), dtype=torch.long).cuda()
        target_tokens = torch.tensor(simple_tokenizer(bot_response, self.vocab), dtype=torch.long).cuda()

        # Pad the shorter sequence with zeros so both have equal length
        # (this assumes index 0 is a padding token in the vocabulary)
        max_len = max(len(input_tokens), len(target_tokens))
        input_tokens = torch.cat([input_tokens, torch.zeros(max_len - len(input_tokens), dtype=torch.long).cuda()])
        target_tokens = torch.cat([target_tokens, torch.zeros(max_len - len(target_tokens), dtype=torch.long).cuda()])

        # Perform a single training step
        self.optimizer.zero_grad()
        output = self.model(input_tokens.unsqueeze(0), target_tokens.unsqueeze(0))
        loss = self.criterion(output.view(-1, len(self.vocab)), target_tokens.view(-1))
        loss.backward()
        self.optimizer.step()
        self.model.eval()  # return to inference mode after the update

    def generate_response(self, input_text):
        """Generate a response using the selected model."""
        if self.use_custom_model:
            bot_response = self.query_custom_model(input_text)
            self.memory.add_interaction(input_text, bot_response)
            # Online learning: immediately fine-tune on the new exchange
            self.train_on_interaction(input_text, bot_response)
            return bot_response
        elif self.ollama_url:
            return self.query_ollama(input_text)
        else:
            raise ValueError("No valid model selected or configured.")
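
    # query_ollama is referenced above but not defined here; the method below
    # is a minimal sketch against Ollama's /api/generate endpoint. The default
    # model name "llama3" is a placeholder assumption, not the project's choice.
    def query_ollama(self, input_text, model_name="llama3"):
        """Query an Ollama server for a response (assumed implementation)."""
        import requests  # local import: only needed for the Ollama path

        payload = {"model": model_name, "prompt": input_text, "stream": False}
        response = requests.post(f"{self.ollama_url}/api/generate", json=payload, timeout=60)
        response.raise_for_status()
        # With "stream": False, Ollama returns the full text under "response"
        return response.json()["response"]


# Minimal usage sketch (hypothetical): assumes ruby_model.pth and the
# vocabulary read by load_vocab() exist and a CUDA device is available.
if __name__ == "__main__":
    manager = ModelManager(use_custom_model=True)
    print(manager.generate_response("hello"))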