import torch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from memory_buffer import MemoryBuffer
from model import TinyGPT
from tokenizer import simple_tokenizer, detokenizer, load_vocab


class ModelManager:
    def __init__(self, use_custom_model=True, ollama_url=None):
        self.use_custom_model = use_custom_model
        self.ollama_url = ollama_url
        self.memory = MemoryBuffer(capacity=10)  # Rolling memory of the 10 most recent interactions
        # Fall back to CPU when CUDA is unavailable instead of crashing on .cuda()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.use_custom_model:
            self._load_custom_model()

    def _load_custom_model(self):
        """Load the custom GPT model, its vocabulary, and the training utilities."""
        self.vocab = load_vocab()
        self.model = TinyGPT(
            vocab_size=len(self.vocab), embed_size=32, num_heads=2, num_layers=2
        ).to(self.device)
        self.model.load_state_dict(torch.load("ruby_model.pth", weights_only=True))
        self.model.eval()
        self.optimizer = Adam(self.model.parameters(), lr=0.0001)
        self.criterion = CrossEntropyLoss()

    def query_custom_model(self, input_text):
        """Generate a response using the custom GPT model."""
        tokens = torch.tensor(
            simple_tokenizer(input_text, self.vocab), dtype=torch.long, device=self.device
        )
        with torch.no_grad():
            # The prompt is fed as both source and target sequence
            output = self.model(tokens.unsqueeze(0), tokens.unsqueeze(0))
        # Take the most likely token at the final position as the response.
        # squeeze(0), rather than squeeze(), keeps a 1-D sequence even when it has length 1.
        predicted_idx = output.argmax(-1).squeeze(0)[-1].item()
        return detokenizer([predicted_idx], self.vocab)

    def train_on_interaction(self, user_input, bot_response):
        """Run a single online training step on one interaction."""
        self.model.train()
        input_tokens = torch.tensor(
            simple_tokenizer(user_input, self.vocab), dtype=torch.long, device=self.device
        )
        target_tokens = torch.tensor(
            simple_tokenizer(bot_response, self.vocab), dtype=torch.long, device=self.device
        )

        # Right-pad the shorter sequence with zeros (index 0 is assumed to be the
        # padding token) so both sequences have equal length
        max_len = max(len(input_tokens), len(target_tokens))
        input_tokens = torch.cat(
            [input_tokens, torch.zeros(max_len - len(input_tokens), dtype=torch.long, device=self.device)]
        )
        target_tokens = torch.cat(
            [target_tokens, torch.zeros(max_len - len(target_tokens), dtype=torch.long, device=self.device)]
        )

        # Perform a single training step
        self.optimizer.zero_grad()
        output = self.model(input_tokens.unsqueeze(0), target_tokens.unsqueeze(0))
        loss = self.criterion(output.view(-1, len(self.vocab)), target_tokens.view(-1))
        loss.backward()
        self.optimizer.step()
        self.model.eval()

    def generate_response(self, input_text):
        """Generate a response using the selected model."""
        if self.use_custom_model:
            bot_response = self.query_custom_model(input_text)
            # Remember the exchange, then take one online training step on it
            self.memory.add_interaction(input_text, bot_response)
            self.train_on_interaction(input_text, bot_response)
            return bot_response
        elif self.ollama_url:
            return self.query_ollama(input_text)
        else:
            raise ValueError("No valid model selected or configured.")
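    def query_ollama(self, input_text):
        """Query an Ollama server for a response.

        This method was referenced above but missing from the original file;
        this is a minimal sketch assuming Ollama's standard /api/generate
        endpoint, with "llama3" used as a placeholder model name.
        """
        import requests  # local import so the custom-model path has no hard dependency

        response = requests.post(
            f"{self.ollama_url}/api/generate",
            json={"model": "llama3", "prompt": input_text, "stream": False},
            timeout=60,
        )
        response.raise_for_status()
        # Non-streaming responses return the full completion under the "response" key
        return response.json()["response"]


# Example usage (a sketch): exercises the custom-model path end to end.
# Assumes ruby_model.pth and the vocabulary file are present alongside this module.
if __name__ == "__main__":
    manager = ModelManager(use_custom_model=True)
    reply = manager.generate_response("hello ruby")
    print(f"bot: {reply}")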