# Ruby/model_manager.py

import torch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from memory_buffer import MemoryBuffer
from model import TinyGPT
from tokenizer import simple_tokenizer, detokenizer, load_vocab
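# NOTE: requests is assumed to be available; it is only needed by the
# query_ollama sketch near the bottom of this file.
import requests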


class ModelManager:
    def __init__(self, use_custom_model=True, ollama_url=None):
        self.use_custom_model = use_custom_model
        self.ollama_url = ollama_url
        self.memory = MemoryBuffer(capacity=10)  # Memory for 10 recent interactions
        if self.use_custom_model:
            self._load_custom_model()

    def _load_custom_model(self):
        """Load the custom GPT model."""
        self.vocab = load_vocab()
        self.model = TinyGPT(
            vocab_size=len(self.vocab), embed_size=32, num_heads=2, num_layers=2
        ).cuda()
        self.model.load_state_dict(torch.load("ruby_model.pth", weights_only=True))
        self.model.eval()
        self.optimizer = Adam(self.model.parameters(), lr=0.0001)
        self.criterion = CrossEntropyLoss()

    def query_custom_model(self, input_text):
        """Generate a response using the custom GPT model."""
        tokens = torch.tensor(simple_tokenizer(input_text, self.vocab), dtype=torch.long).cuda()
        with torch.no_grad():
            # The model expects (src, tgt); at inference time the input serves as both.
            output = self.model(tokens.unsqueeze(0), tokens.unsqueeze(0))
        # Take the prediction at the final position, i.e. a single next token.
        predicted_idx = output.argmax(-1).squeeze()[-1].item()
        return detokenizer([predicted_idx], self.vocab)
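
    # query_custom_model above emits only a single next token. A minimal sketch
    # of multi-token generation by repeated single-step prediction follows; it
    # assumes the same (src, tgt) forward signature used above. The method name
    # and the max_new_tokens parameter are illustrative, not part of the
    # original file.
    def generate_sequence(self, input_text, max_new_tokens=20):
        tokens = torch.tensor(simple_tokenizer(input_text, self.vocab), dtype=torch.long).cuda()
        generated = []
        with torch.no_grad():
            for _ in range(max_new_tokens):
                output = self.model(tokens.unsqueeze(0), tokens.unsqueeze(0))
                # Greedily pick the most likely token at the final position.
                next_idx = output.argmax(-1).squeeze(0)[-1].item()
                generated.append(next_idx)
                # Append the new token and feed the extended sequence back in.
                tokens = torch.cat([tokens, torch.tensor([next_idx], dtype=torch.long).cuda()])
        return detokenizer(generated, self.vocab)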

    def train_on_interaction(self, user_input, bot_response):
        """Train the model on a single interaction."""
        self.model.train()
        input_tokens = torch.tensor(simple_tokenizer(user_input, self.vocab), dtype=torch.long).cuda()
        target_tokens = torch.tensor(simple_tokenizer(bot_response, self.vocab), dtype=torch.long).cuda()
        # Zero-pad the shorter sequence so both have equal length.
        # (Assumes token id 0 is safe to use as padding.)
        max_len = max(len(input_tokens), len(target_tokens))
        input_tokens = torch.cat(
            [input_tokens, torch.zeros(max_len - len(input_tokens), dtype=torch.long).cuda()]
        )
        target_tokens = torch.cat(
            [target_tokens, torch.zeros(max_len - len(target_tokens), dtype=torch.long).cuda()]
        )
        # Perform a single training step.
        self.optimizer.zero_grad()
        output = self.model(input_tokens.unsqueeze(0), target_tokens.unsqueeze(0))
        loss = self.criterion(output.view(-1, len(self.vocab)), target_tokens.view(-1))
        loss.backward()
        self.optimizer.step()
        self.model.eval()
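
    # The MemoryBuffer is populated in generate_response but never used for
    # training in this file. A sketch of experience replay over that buffer
    # follows; it assumes a hypothetical MemoryBuffer accessor (here called
    # get_interactions) that yields (user_input, bot_response) pairs.
    def replay_train(self):
        for user_input, bot_response in self.memory.get_interactions():  # hypothetical accessor
            self.train_on_interaction(user_input, bot_response)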

    def generate_response(self, input_text):
        """Generate a response using the selected model."""
        if self.use_custom_model:
            bot_response = self.query_custom_model(input_text)
            self.memory.add_interaction(input_text, bot_response)
            # Online learning: immediately fine-tune on the exchange just produced.
            self.train_on_interaction(input_text, bot_response)
            return bot_response
        elif self.ollama_url:
            return self.query_ollama(input_text)
        else:
            raise ValueError("No valid model selected or configured.")
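
    # generate_response calls query_ollama, which is not defined in this file.
    # A minimal sketch follows, assuming self.ollama_url points at an Ollama
    # server's /api/generate endpoint and that the model named here ("llama3",
    # an illustrative default) has been pulled on that server.
    def query_ollama(self, input_text, model_name="llama3"):
        response = requests.post(
            self.ollama_url,
            json={"model": model_name, "prompt": input_text, "stream": False},
        )
        response.raise_for_status()
        # With "stream": False, Ollama returns a single JSON object whose
        # "response" field holds the generated text.
        return response.json()["response"]


# Illustrative usage: requires ruby_model.pth and a CUDA device.
if __name__ == "__main__":
    manager = ModelManager(use_custom_model=True)
    print(manager.generate_response("hello"))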