Fix: Added code to allow for other data sources to be added

Author: Dan
Date: 2024-06-08 09:21:19 -04:00
Parent: 1fe54ed1ff
Commit: e3e4b7abe6
3 changed files with 138 additions and 5 deletions


@@ -171,7 +171,13 @@ def check_input_chars(s, string_to_int):
     return unknown_chars


+# Maintain conversation history
+conversation_history = []
+
+
 def process_message(message):
+    global conversation_history
+
     print(f"Processing message: '{message}'")
     if not message.strip():
         print("Message is empty or invalid.")
@@ -182,27 +188,40 @@ def process_message(message):
         print(f"Message contains unknown characters: {unknown_chars}")
         return f"Message contains unknown characters: {unknown_chars}"

+    # Add the new message to the conversation history
+    conversation_history.append(message)
+
+    # Limit the conversation history to the last 5 messages to avoid excessive length
+    if len(conversation_history) > 5:
+        conversation_history = conversation_history[-5:]
+
+    # Concatenate the conversation history to form the input prompt
+    context = " ".join(conversation_history)
+
     encoded_text = torch.tensor(
-        [encode(message, string_to_int)], dtype=torch.long
+        [encode(context, string_to_int)], dtype=torch.long
     ).to(device)

     print(f"Encoded text shape: {encoded_text.shape}")

     if encoded_text.size(1) == 0:
         print("Message could not be processed.")
         return "Message could not be processed."

     with torch.no_grad():
         generated_tokens = model.generate(
-            encoded_text, max_new_tokens=50, temperature=0.7
+            encoded_text, max_new_tokens=100, temperature=1.0
         )

     generated_tokens = generated_tokens[0, len(encoded_text[0]) :]
     decoded_response = decode(generated_tokens.tolist(), int_to_string)
     print(f"Generated response: '{decoded_response}'")

-    if decoded_response.startswith(message):
-        decoded_response = decoded_response[len(message) :].strip()
+    if decoded_response.startswith(context):
+        decoded_response = decoded_response[len(context) :].strip()
     print(f"Final response: '{decoded_response}'")

+    # Add the response to the conversation history
+    conversation_history.append(decoded_response)
+
     return decoded_response
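
For reference, below is a minimal standalone sketch of the conversation-history behaviour this diff introduces: each message and response is appended to a module-level list, the list is trimmed to its last 5 entries, and the joined history is used as the generation prompt and then stripped back off the front of the reply. The real encode, decode, model, string_to_int, int_to_string, and device objects live elsewhere in this file; the echo_generate stub here is a hypothetical stand-in for the encode -> model.generate -> decode pipeline, not the project's actual model call.

# Sketch only: echo_generate stands in for encode -> model.generate -> decode.
conversation_history = []

def echo_generate(prompt: str) -> str:
    # Hypothetical stand-in for the model; echoes the prompt plus a canned reply.
    return prompt + " sure, tell me more."

def process_message(message: str) -> str:
    global conversation_history
    if not message.strip():
        return "Message is empty or invalid."

    # Append the new message and keep only the last 5 entries,
    # as the diff does, to bound the prompt length.
    conversation_history.append(message)
    if len(conversation_history) > 5:
        conversation_history = conversation_history[-5:]

    # The joined history becomes the prompt fed to generation.
    context = " ".join(conversation_history)
    response = echo_generate(context)

    # Strip the echoed prompt, mirroring the startswith(context) check above.
    if response.startswith(context):
        response = response[len(context):].strip()

    conversation_history.append(response)
    return response

if __name__ == "__main__":
    for msg in ["hello", "how are you", "tell me more"]:
        print(process_message(msg))

Because responses are appended as well, the 5-entry window holds a mix of user and model turns, so only the last two or three user messages actually survive trimming.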