387 lines
14 KiB
Python
387 lines
14 KiB
Python
"""
|
|
Conversation State Management for Mai
|
|
|
|
Provides turn-by-turn conversation history with proper session isolation,
|
|
interruption handling, and context window management.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
import threading
|
|
import uuid
|
|
from typing import Dict, List, Optional, Any
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
|
|
# Reuse the shared conversation models when this module is imported as part of
# its package. NOTE(review): Message / Conversation are not referenced by the
# code in this file as shown — confirm whether the import is still needed.
try:
    from ..models.conversation import Message, Conversation
except ImportError:
    # The relative import fails when the models package is missing (or when
    # this file is loaded outside its package); expose None placeholders.
    Message = Conversation = None

logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ConversationTurn:
    """One user/assistant exchange together with the metadata recorded for it."""

    # Conversation this turn belongs to; ConversationState groups turns by it.
    conversation_id: str
    # Text the user sent.
    user_message: str
    # Text the model answered with; ConversationState treats a blank value as
    # an incomplete (interrupted) turn.
    ai_response: str
    # Creation time; ConversationState compares it against time.time(),
    # i.e. seconds since the epoch.
    timestamp: float
    # Model identifier ("restored" for turns rebuilt from session storage).
    model_used: str
    # Token count recorded for the turn (0 when unknown).
    tokens_used: int
    # Seconds taken to produce the response (0.0 when unknown).
    response_time: float
    # Flag supplied by the caller; presumably whether memory context was
    # injected into the prompt — confirm against the producer.
    memory_context_applied: bool

    def to_dict(self) -> Dict[str, Any]:
        """Return the turn as a plain, serializable dict keyed by field name."""
        return dict(
            conversation_id=self.conversation_id,
            user_message=self.user_message,
            ai_response=self.ai_response,
            timestamp=self.timestamp,
            model_used=self.model_used,
            tokens_used=self.tokens_used,
            response_time=self.response_time,
            memory_context_applied=self.memory_context_applied,
        )
|
|
|
|
|
|
class ConversationState:
    """
    Manages conversation state across multiple sessions with proper isolation.

    Each conversation ID maps to its own list of ConversationTurn objects,
    capped at ``max_turns`` (the oldest turns are dropped first). Every public
    method runs under a single reentrant lock, so one instance can be shared
    between threads. History is exposed in the Ollama chat format
    (``{"role": ..., "content": ...}`` dicts).
    """

    def __init__(self, max_turns_per_conversation: int = 10):
        """
        Initialize conversation state manager.

        Args:
            max_turns_per_conversation: Maximum turns to keep per conversation;
                the oldest turns are discarded once it is exceeded.
        """
        self.conversations: Dict[str, List[ConversationTurn]] = {}
        self.max_turns = max_turns_per_conversation
        self._lock = threading.RLock()  # Reentrant lock for nested calls
        # IDs currently flagged as busy via set_processing(..., True). Only busy
        # IDs are stored, so the map does not grow with finished conversations.
        self._processing_locks: Dict[str, bool] = {}
        self.logger = logging.getLogger(__name__)

        self.logger.info(
            "ConversationState initialized with max %s turns per conversation",
            max_turns_per_conversation,
        )

    def _trim_to_limit(self, conversation_id: str) -> int:
        """
        Drop the oldest turns of a conversation beyond ``max_turns``.

        The caller must already hold ``self._lock`` and the conversation must
        exist.

        Returns:
            Number of turns dropped (0 when already within the limit).
        """
        turns = self.conversations[conversation_id]
        excess_count = len(turns) - self.max_turns
        if excess_count <= 0:
            return 0

        removed_turns = turns[:excess_count]
        self.conversations[conversation_id] = turns[excess_count:]

        self.logger.debug(
            "Cleaned up %s old turns from conversation %s",
            excess_count,
            conversation_id,
        )
        for removed_turn in removed_turns:
            self.logger.debug(
                "Removed turn: %s - %s...",
                removed_turn.timestamp,
                removed_turn.user_message[:50],
            )
        return excess_count

    def add_turn(self, turn: ConversationTurn) -> None:
        """
        Append a turn to its conversation and enforce the turn limit.

        The conversation is created on first use. The turn's own ``timestamp``
        is stored unchanged. If the conversation then holds more than
        ``max_turns`` turns, the oldest ones are discarded.

        Args:
            turn: ConversationTurn to add; ``turn.conversation_id`` selects the
                conversation.
        """
        with self._lock:
            conversation_id = turn.conversation_id

            if conversation_id not in self.conversations:
                self.conversations[conversation_id] = []
                self.logger.debug("Created new conversation: %s", conversation_id)

            self.conversations[conversation_id].append(turn)
            self.logger.debug(
                "Added turn to conversation %s: %s tokens, %.2fs",
                conversation_id,
                turn.tokens_used,
                turn.response_time,
            )

            self._trim_to_limit(conversation_id)

    def get_history(self, conversation_id: str) -> List[Dict[str, str]]:
        """
        Get conversation history in Ollama-compatible format.

        Args:
            conversation_id: ID of conversation to retrieve.

        Returns:
            Two message dicts per turn, oldest first: a ``user`` message
            followed by an ``assistant`` message. Empty for an unknown ID.
        """
        with self._lock:
            turns = self.conversations.get(conversation_id, [])

            history: List[Dict[str, str]] = []
            for turn in turns:
                history.append({"role": "user", "content": turn.user_message})
                history.append({"role": "assistant", "content": turn.ai_response})

            self.logger.debug(
                "Retrieved %s messages from conversation %s",
                len(history),
                conversation_id,
            )
            return history

    def set_conversation_history(
        self, messages: List[Dict[str, str]], conversation_id: Optional[str] = None
    ) -> str:
        """
        Restore conversation history from session storage.

        Messages are consumed as (user, assistant) pairs. A message that is not
        a ``user`` message where one is expected is skipped; a ``user`` message
        not immediately followed by an ``assistant`` message is dropped. Any
        existing turns for the conversation are replaced, and the result is
        trimmed to ``max_turns`` exactly as add_turn() does.

        Restored turns carry placeholder metadata (current time, model
        ``"restored"``, zero tokens and response time) because the Ollama
        message format does not store it.

        Args:
            messages: List of message dictionaries in Ollama format
                [{"role": "user/assistant", "content": "..."}].
            conversation_id: Conversation ID to restore into; a new UUID is
                generated when None.

        Returns:
            The conversation ID the history was restored into (needed by the
            caller when a new ID was generated).
        """
        with self._lock:
            if conversation_id is None:
                conversation_id = str(uuid.uuid4())

            restored: List[ConversationTurn] = []
            i = 0
            while i < len(messages):
                if messages[i].get("role") != "user":
                    self.logger.warning("Expected user message at index %s, skipping", i)
                    i += 1
                    continue

                user_message = messages[i].get("content", "")
                i += 1

                if i >= len(messages) or messages[i].get("role") != "assistant":
                    # Drop the orphaned user message; index i is not consumed so
                    # it is re-examined as a potential user message next time.
                    self.logger.warning(
                        "Expected assistant message at index %s, skipping", i
                    )
                    continue

                ai_response = messages[i].get("content", "")
                i += 1

                restored.append(
                    ConversationTurn(
                        conversation_id=conversation_id,
                        user_message=user_message,
                        ai_response=ai_response,
                        timestamp=time.time(),  # Original time is not stored
                        model_used="restored",  # Marks turns rebuilt from storage
                        tokens_used=0,  # Token count not available from session
                        response_time=0.0,  # Response time not available from session
                        memory_context_applied=False,  # Not tracked in session
                    )
                )

            # Replace (rather than merge with) any existing history for this ID,
            # then apply the same turn limit add_turn() enforces.
            self.conversations[conversation_id] = restored
            self._trim_to_limit(conversation_id)

            self.logger.info(
                "Restored %s turns to conversation %s",
                len(self.conversations[conversation_id]),
                conversation_id,
            )
            return conversation_id

    def get_last_n_turns(self, conversation_id: str, n: int = 5) -> List[ConversationTurn]:
        """
        Get the last N turns from a conversation.

        Args:
            conversation_id: ID of conversation.
            n: Number of recent turns to retrieve; values <= 0 yield [].

        Returns:
            A new list with up to ``n`` most recent turns, oldest first.
        """
        with self._lock:
            turns = self.conversations.get(conversation_id, [])
            return turns[-n:] if n > 0 else []

    def clear_pending_response(self, conversation_id: str) -> None:
        """
        Remove incomplete turns after an interruption.

        A turn counts as incomplete when its ``ai_response`` is empty or
        whitespace-only. Unknown conversation IDs are ignored.

        Args:
            conversation_id: ID of conversation to clear.
        """
        with self._lock:
            turns = self.conversations.get(conversation_id)
            if turns is None:
                return

            kept = [turn for turn in turns if turn.ai_response.strip()]
            self.conversations[conversation_id] = kept

            removed_count = len(turns) - len(kept)
            if removed_count > 0:
                self.logger.info(
                    "Cleared %s incomplete turns from conversation %s",
                    removed_count,
                    conversation_id,
                )

    def start_conversation(self, conversation_id: Optional[str] = None) -> str:
        """
        Start a new conversation or return existing ID.

        Args:
            conversation_id: Optional existing conversation ID; a new UUID is
                generated when None.

        Returns:
            Conversation ID (new or existing). Existing turns are left intact.
        """
        with self._lock:
            if conversation_id is None:
                conversation_id = str(uuid.uuid4())

            if conversation_id not in self.conversations:
                self.conversations[conversation_id] = []
                self.logger.debug("Started new conversation: %s", conversation_id)

            return conversation_id

    def is_processing(self, conversation_id: str) -> bool:
        """
        Check if conversation is currently being processed.

        This reads a flag maintained by set_processing(); it is not a mutual
        exclusion lock, and a check followed by set_processing() is not atomic.

        Args:
            conversation_id: ID of conversation.

        Returns:
            True between set_processing(id, True) and set_processing(id, False),
            False otherwise.
        """
        with self._lock:
            return self._processing_locks.get(conversation_id, False)

    def set_processing(self, conversation_id: str, processing: bool) -> None:
        """
        Set or clear the processing flag for a conversation.

        Args:
            conversation_id: ID of conversation.
            processing: True to mark the conversation busy, False to clear it.
        """
        with self._lock:
            if processing:
                self._processing_locks[conversation_id] = True
            else:
                # Remove the entry so is_processing() reports False and finished
                # conversations do not accumulate in the map.
                self._processing_locks.pop(conversation_id, None)

    def get_conversation_turns(self, conversation_id: str) -> List[ConversationTurn]:
        """
        Get all turns for a conversation.

        Args:
            conversation_id: ID of conversation.

        Returns:
            A copy of the turn list (oldest first), so callers cannot mutate
            internal state outside the lock. Empty for an unknown ID.
        """
        with self._lock:
            return list(self.conversations.get(conversation_id, []))

    def delete_conversation(self, conversation_id: str) -> bool:
        """
        Delete a conversation completely.

        Args:
            conversation_id: ID of conversation to delete.

        Returns:
            True if conversation was deleted, False if not found.
        """
        with self._lock:
            if conversation_id in self.conversations:
                del self.conversations[conversation_id]
                self.logger.info("Deleted conversation: %s", conversation_id)
                return True
            return False

    def list_conversations(self) -> List[str]:
        """
        List all active conversation IDs.

        Returns:
            Snapshot list of conversation IDs.
        """
        with self._lock:
            return list(self.conversations.keys())

    def get_conversation_stats(self, conversation_id: str) -> Dict[str, Any]:
        """
        Get statistics for a specific conversation.

        Args:
            conversation_id: ID of conversation.

        Returns:
            Dictionary with ``turn_count``, ``total_tokens``,
            ``total_response_time``, ``average_response_time``,
            ``average_tokens``, ``oldest_timestamp`` and ``newest_timestamp``.
            For an empty or unknown conversation the totals/averages are zero
            and both timestamps are None, so the key set is always the same.
        """
        with self._lock:
            turns = self.conversations.get(conversation_id, [])

            if not turns:
                return {
                    "turn_count": 0,
                    "total_tokens": 0,
                    "total_response_time": 0.0,
                    "average_response_time": 0.0,
                    "average_tokens": 0.0,
                    "oldest_timestamp": None,
                    "newest_timestamp": None,
                }

            turn_count = len(turns)
            total_tokens = sum(turn.tokens_used for turn in turns)
            total_response_time = sum(turn.response_time for turn in turns)

            return {
                "turn_count": turn_count,
                "total_tokens": total_tokens,
                "total_response_time": total_response_time,
                "average_response_time": total_response_time / turn_count,
                "average_tokens": total_tokens / turn_count,
                "oldest_timestamp": min(turn.timestamp for turn in turns),
                "newest_timestamp": max(turn.timestamp for turn in turns),
            }

    def cleanup_old_conversations(self, max_age_hours: float = 24.0) -> int:
        """
        Clean up conversations whose most recent turn is older than the limit.

        A conversation's age is taken from its last-appended turn. Conversations
        with no turns carry no timestamp and are therefore never removed here.

        Args:
            max_age_hours: Maximum age in hours before cleanup.

        Returns:
            Number of conversations cleaned up.
        """
        with self._lock:
            cutoff_time = time.time() - (max_age_hours * 3600)

            stale_ids = [
                conv_id
                for conv_id, turns in self.conversations.items()
                if turns and turns[-1].timestamp < cutoff_time
            ]
            for conv_id in stale_ids:
                del self.conversations[conv_id]

            if stale_ids:
                self.logger.info("Cleaned up %s old conversations", len(stale_ids))

            return len(stale_ids)
|