# Emerald/model.py

import torch
import torch.nn as nn
import threading
import os
import json

# Simple whitespace tokenizer with a thread-safe, growable vocabulary
class SimpleTokenizer:
    def __init__(self):
        self.token2idx = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
        self.idx2token = {idx: token for token, idx in self.token2idx.items()}
        self.lock = threading.Lock()

    def build_vocab(self, texts):
        # Add any unseen whitespace-delimited tokens to the vocabulary
        with self.lock:
            for text in texts:
                tokens = text.split()
                for token in tokens:
                    if token not in self.token2idx:
                        idx = len(self.token2idx)
                        self.token2idx[token] = idx
                        self.idx2token[idx] = token

    def encode(self, text):
        with self.lock:
            return [self.token2idx.get(token, self.token2idx['<UNK>']) for token in text.split()]

    def decode(self, indices):
        with self.lock:
            return ' '.join([self.idx2token.get(idx, '<UNK>') for idx in indices])

    def save_vocab(self, path):
        with self.lock:
            with open(path, 'w') as f:
                json.dump({'token2idx': self.token2idx, 'idx2token': self.idx2token}, f)

    def load_vocab(self, path):
        if os.path.exists(path):
            with open(path, 'r') as f:
                vocab = json.load(f)
            with self.lock:
                self.token2idx = vocab['token2idx']
                # JSON keys are always strings; convert indices back to ints
                self.idx2token = {int(k): v for k, v in vocab['idx2token'].items()}
                print('Tokenizer vocabulary loaded from', path)
                # Ensure special tokens are present
                special_tokens = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
                for token, idx in special_tokens.items():
                    if token not in self.token2idx:
                        self.token2idx[token] = idx
                        self.idx2token[idx] = token
        else:
            print('No existing tokenizer vocabulary found. Starting fresh.')
            self.token2idx = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
            self.idx2token = {idx: token for token, idx in self.token2idx.items()}
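
# Minimal usage sketch for SimpleTokenizer (illustrative, with made-up example
# sentences): build a vocabulary, then round-trip a sentence. Unseen words
# encode to <UNK>, and new tokens get indices after the four special tokens.
def _demo_tokenizer():
    tokenizer = SimpleTokenizer()
    tokenizer.build_vocab(['hello world', 'hello there'])
    ids = tokenizer.encode('hello there friend')  # 'friend' was never seen
    print(ids)                                    # [4, 6, 1]
    print(tokenizer.decode(ids))                  # hello there <UNK>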

# Sinusoidal positional encoding added to the token embeddings
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            # Odd d_model: there is one fewer cosine column than sine column
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # Shape: [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [seq_len, batch_size, d_model]
        return x + self.pe[:x.size(0)]
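
# Shape sketch for PositionalEncoding (illustrative sizes): inputs are expected
# as [seq_len, batch_size, d_model], and the same pe[:seq_len] slice is added
# to every element of the batch.
def _demo_positional_encoding():
    pos_enc = PositionalEncoding(d_model=128)
    x = torch.zeros(10, 2, 128)   # [seq_len=10, batch_size=2, d_model=128]
    out = pos_enc(x)
    assert out.shape == (10, 2, 128)
    # With a zero input, both batch elements receive identical encodings
    assert torch.allclose(out[:, 0], out[:, 1])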

# GPT-style (decoder-only) language model built from a TransformerEncoder
# driven with a causal self-attention mask
class GPTModel(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2):
        super(GPTModel, self).__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.src_mask = None

    def _generate_square_subsequent_mask(self, sz):
        # Causal mask: position i may only attend to positions <= i
        mask = torch.triu(torch.ones(sz, sz) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src):
        # src: [batch_size, seq_len]
        src = src.transpose(0, 1)  # Shape: [seq_len, batch_size]
        src = self.embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        src = self.pos_encoder(src)
        # Rebuild the causal mask whenever the sequence length changes
        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
            device = src.device
            self.src_mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
        output = self.transformer_encoder(src, self.src_mask)
        logits = self.fc_out(output)
        return logits  # Shape: [seq_len, batch_size, vocab_size]
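
# Forward-pass sketch for GPTModel (illustrative sizes): the model takes token
# ids shaped [batch_size, seq_len] and returns logits shaped
# [seq_len, batch_size, vocab_size], since the TransformerEncoder is not batch_first.
def _demo_gpt_forward():
    model = GPTModel(vocab_size=100)
    tokens = torch.randint(0, 100, (2, 16))  # [batch_size=2, seq_len=16]
    logits = model(tokens)
    assert logits.shape == (16, 2, 100)      # [seq_len, batch_size, vocab_size]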

# Training function: one optimization step over a single batch
def train_step(model, optimizer, criterion, input_tensor, target_tensor):
    model.train()
    optimizer.zero_grad()
    output = model(input_tensor)  # [seq_len, batch_size, vocab_size]
    output = output.view(-1, output.size(-1))  # [seq_len * batch_size, vocab_size]
    target = target_tensor.transpose(0, 1).contiguous().view(-1)  # [seq_len * batch_size]
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(f'Training loss: {loss.item():.4f}')
    return loss.item()

def initialize_model(tokenizer, device):
    vocab_size = len(tokenizer.token2idx)
    model = GPTModel(vocab_size=vocab_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore <PAD> positions
    model_path = 'gpt_model.pth'
    # Load existing model weights if available
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('Model loaded from', model_path)
    else:
        print('No existing model found. Starting fresh.')
    return model, optimizer, criterion

def save_model(model):
    model_path = 'gpt_model.pth'
    torch.save(model.state_dict(), model_path)

def update_model_vocab(model, tokenizer, device):
    # Grow the embedding and output layers to the new vocabulary size,
    # preserving the weights already learned for the old tokens
    vocab_size = len(tokenizer.token2idx)
    old_embedding_weight = model.embedding.weight.data
    old_vocab_size, embedding_dim = old_embedding_weight.shape
    new_embedding = nn.Embedding(vocab_size, model.d_model).to(device)
    new_embedding.weight.data[:old_vocab_size] = old_embedding_weight
    model.embedding = new_embedding
    old_fc_out_weight = model.fc_out.weight.data
    old_fc_out_bias = model.fc_out.bias.data
    new_fc_out = nn.Linear(model.d_model, vocab_size).to(device)
    new_fc_out.weight.data[:old_vocab_size] = old_fc_out_weight
    new_fc_out.bias.data[:old_vocab_size] = old_fc_out_bias
    model.fc_out = new_fc_out
    return model
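
# Sketch of the resize behaviour (illustrative vocabulary): after new tokens are
# added, the first old_vocab_size embedding rows (and the matching output
# rows/biases) are preserved, while the new rows start from fresh initialization.
def _demo_update_model_vocab(device=torch.device('cpu')):
    tokenizer = SimpleTokenizer()
    tokenizer.build_vocab(['a b c'])
    model = GPTModel(vocab_size=len(tokenizer.token2idx)).to(device)
    old_rows = model.embedding.weight.data.clone()
    tokenizer.build_vocab(['d e'])  # vocabulary grows by two tokens
    model = update_model_vocab(model, tokenizer, device)
    assert model.embedding.num_embeddings == len(tokenizer.token2idx)
    assert torch.equal(model.embedding.weight.data[:old_rows.size(0)], old_rows)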

def train_on_conversation(model, optimizer, criterion, tokenizer, input_text, target_text, device):
    tokenizer.build_vocab([input_text, target_text])
    input_indices = tokenizer.encode(input_text)
    target_indices = tokenizer.encode(target_text)
    # Concatenate input and target indices to create a single training sequence
    full_indices = input_indices + target_indices
    # Next-token prediction: shift the sequence by one position
    input_sequence = full_indices[:-1]   # all tokens except the last
    target_sequence = full_indices[1:]   # all tokens except the first
    # Resize the model if the vocabulary has grown, and rebuild the optimizer so
    # it tracks the newly created parameters. Note the rebuilt optimizer is local
    # to this call; callers reusing their own optimizer should account for this.
    if len(tokenizer.token2idx) != model.embedding.num_embeddings:
        model = update_model_vocab(model, tokenizer, device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    input_tensor = torch.tensor([input_sequence], dtype=torch.long, device=device)
    target_tensor = torch.tensor([target_sequence], dtype=torch.long, device=device)
    loss = train_step(model, optimizer, criterion, input_tensor, target_tensor)
    return loss
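
# End-to-end sketch, assuming the module is run directly. The vocabulary path
# 'tokenizer_vocab.json' and the example texts are placeholders, not values
# used elsewhere; the checkpoint path matches initialize_model/save_model.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = SimpleTokenizer()
    tokenizer.load_vocab('tokenizer_vocab.json')
    model, optimizer, criterion = initialize_model(tokenizer, device)
    loss = train_on_conversation(
        model, optimizer, criterion, tokenizer,
        input_text='hello how are you',
        target_text='i am fine thank you',
        device=device,
    )
    save_model(model)
    tokenizer.save_vocab('tokenizer_vocab.json')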