Fix: Added code to allow for other data sources to be added
This commit is contained in:
@ -171,7 +171,13 @@ def check_input_chars(s, string_to_int):
|
||||
return unknown_chars
|
||||
|
||||
|
||||
# Maintain conversation history
|
||||
conversation_history = []
|
||||
|
||||
|
||||
def process_message(message):
|
||||
global conversation_history
|
||||
|
||||
print(f"Processing message: '{message}'")
|
||||
if not message.strip():
|
||||
print("Message is empty or invalid.")
|
||||
@ -182,27 +188,40 @@ def process_message(message):
|
||||
print(f"Message contains unknown characters: {unknown_chars}")
|
||||
return f"Message contains unknown characters: {unknown_chars}"
|
||||
|
||||
# Add the new message to the conversation history
|
||||
conversation_history.append(message)
|
||||
# Limit the conversation history to the last 5 messages to avoid excessive length
|
||||
if len(conversation_history) > 5:
|
||||
conversation_history = conversation_history[-5:]
|
||||
|
||||
# Concatenate the conversation history to form the input prompt
|
||||
context = " ".join(conversation_history)
|
||||
encoded_text = torch.tensor(
|
||||
[encode(message, string_to_int)], dtype=torch.long
|
||||
[encode(context, string_to_int)], dtype=torch.long
|
||||
).to(device)
|
||||
print(f"Encoded text shape: {encoded_text.shape}")
|
||||
|
||||
if encoded_text.size(1) == 0:
|
||||
print("Message could not be processed.")
|
||||
return "Message could not be processed."
|
||||
|
||||
with torch.no_grad():
|
||||
generated_tokens = model.generate(
|
||||
encoded_text, max_new_tokens=50, temperature=0.7
|
||||
encoded_text, max_new_tokens=100, temperature=1.0
|
||||
)
|
||||
generated_tokens = generated_tokens[0, len(encoded_text[0]) :]
|
||||
|
||||
decoded_response = decode(generated_tokens.tolist(), int_to_string)
|
||||
print(f"Generated response: '{decoded_response}'")
|
||||
|
||||
if decoded_response.startswith(message):
|
||||
decoded_response = decoded_response[len(message) :].strip()
|
||||
if decoded_response.startswith(context):
|
||||
decoded_response = decoded_response[len(context) :].strip()
|
||||
|
||||
print(f"Final response: '{decoded_response}'")
|
||||
|
||||
# Add the response to the conversation history
|
||||
conversation_history.append(decoded_response)
|
||||
|
||||
return decoded_response
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user