Files
Mai/src/mai/conversation/state.py
2026-01-26 22:40:49 -05:00

387 lines
14 KiB
Python

"""
Conversation State Management for Mai
Provides turn-by-turn conversation history with proper session isolation,
interruption handling, and context window management.
"""
import logging
import time
import threading
import uuid
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime
# Import existing conversation models for consistency
try:
from ..models.conversation import Message, Conversation
except ImportError:
# Fallback if models not available yet
Message = None
Conversation = None
logger = logging.getLogger(__name__)
@dataclass
class ConversationTurn:
"""Single conversation turn with comprehensive metadata."""
conversation_id: str
user_message: str
ai_response: str
timestamp: float
model_used: str
tokens_used: int
response_time: float
memory_context_applied: bool
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"conversation_id": self.conversation_id,
"user_message": self.user_message,
"ai_response": self.ai_response,
"timestamp": self.timestamp,
"model_used": self.model_used,
"tokens_used": self.tokens_used,
"response_time": self.response_time,
"memory_context_applied": self.memory_context_applied,
}
class ConversationState:
"""
Manages conversation state across multiple sessions with proper isolation.
Provides turn-by-turn history tracking, automatic cleanup,
thread-safe operations, and Ollama-compatible formatting.
"""
def __init__(self, max_turns_per_conversation: int = 10):
"""
Initialize conversation state manager.
Args:
max_turns_per_conversation: Maximum turns to keep per conversation
"""
self.conversations: Dict[str, List[ConversationTurn]] = {}
self.max_turns = max_turns_per_conversation
self._lock = threading.RLock() # Reentrant lock for nested calls
self.logger = logging.getLogger(__name__)
self.logger.info(
f"ConversationState initialized with max {max_turns_per_conversation} turns per conversation"
)
def add_turn(self, turn: ConversationTurn) -> None:
"""
Add a conversation turn with automatic timestamp and cleanup.
Args:
turn: ConversationTurn to add
"""
with self._lock:
conversation_id = turn.conversation_id
# Initialize conversation if doesn't exist
if conversation_id not in self.conversations:
self.conversations[conversation_id] = []
self.logger.debug(f"Created new conversation: {conversation_id}")
# Add the turn
self.conversations[conversation_id].append(turn)
self.logger.debug(
f"Added turn to conversation {conversation_id}: {turn.tokens_used} tokens, {turn.response_time:.2f}s"
)
# Automatic cleanup: maintain last N turns
if len(self.conversations[conversation_id]) > self.max_turns:
# Remove oldest turns to maintain limit
excess_count = len(self.conversations[conversation_id]) - self.max_turns
removed_turns = self.conversations[conversation_id][:excess_count]
self.conversations[conversation_id] = self.conversations[conversation_id][
excess_count:
]
self.logger.debug(
f"Cleaned up {excess_count} old turns from conversation {conversation_id}"
)
# Log removed turns for debugging
for removed_turn in removed_turns:
self.logger.debug(
f"Removed turn: {removed_turn.timestamp} - {removed_turn.user_message[:50]}..."
)
def get_history(self, conversation_id: str) -> List[Dict[str, str]]:
"""
Get conversation history in Ollama-compatible format.
Args:
conversation_id: ID of conversation to retrieve
Returns:
List of message dictionaries formatted for Ollama API
"""
with self._lock:
turns = self.conversations.get(conversation_id, [])
# Convert to Ollama format: alternating user/assistant roles
history = []
for turn in turns:
history.append({"role": "user", "content": turn.user_message})
history.append({"role": "assistant", "content": turn.ai_response})
self.logger.debug(
f"Retrieved {len(history)} messages from conversation {conversation_id}"
)
return history
def set_conversation_history(
self, messages: List[Dict[str, str]], conversation_id: Optional[str] = None
) -> None:
"""
Restore conversation history from session storage.
Args:
messages: List of message dictionaries in Ollama format [{"role": "user/assistant", "content": "..."}]
conversation_id: Optional conversation ID to restore to (creates new if None)
"""
with self._lock:
if conversation_id is None:
conversation_id = str(uuid.uuid4())
# Clear existing conversation for this ID
self.conversations[conversation_id] = []
# Convert messages back to ConversationTurn objects
# Messages should be in pairs: user, assistant, user, assistant, ...
i = 0
while i < len(messages):
# Expect user message first
if i >= len(messages) or messages[i].get("role") != "user":
self.logger.warning(f"Expected user message at index {i}, skipping")
i += 1
continue
user_message = messages[i].get("content", "")
i += 1
# Expect assistant message next
if i >= len(messages) or messages[i].get("role") != "assistant":
self.logger.warning(f"Expected assistant message at index {i}, skipping")
continue
ai_response = messages[i].get("content", "")
i += 1
# Create ConversationTurn with estimated metadata
turn = ConversationTurn(
conversation_id=conversation_id,
user_message=user_message,
ai_response=ai_response,
timestamp=time.time(), # Use current time as approximation
model_used="restored", # Indicate this is from restoration
tokens_used=0, # Token count not available from session
response_time=0.0, # Response time not available from session
memory_context_applied=False, # Memory context not tracked in session
)
self.conversations[conversation_id].append(turn)
self.logger.info(
f"Restored {len(self.conversations[conversation_id])} turns to conversation {conversation_id}"
)
def get_last_n_turns(self, conversation_id: str, n: int = 5) -> List[ConversationTurn]:
"""
Get the last N turns from a conversation.
Args:
conversation_id: ID of conversation
n: Number of recent turns to retrieve
Returns:
List of last N ConversationTurn objects
"""
with self._lock:
turns = self.conversations.get(conversation_id, [])
return turns[-n:] if n > 0 else []
def clear_pending_response(self, conversation_id: str) -> None:
"""
Clear any pending response for interruption handling.
Args:
conversation_id: ID of conversation to clear
"""
with self._lock:
if conversation_id in self.conversations:
# Find and remove incomplete turns (those without AI response)
original_count = len(self.conversations[conversation_id])
self.conversations[conversation_id] = [
turn
for turn in self.conversations[conversation_id]
if turn.ai_response.strip() # Must have AI response
]
removed_count = original_count - len(self.conversations[conversation_id])
if removed_count > 0:
self.logger.info(
f"Cleared {removed_count} incomplete turns from conversation {conversation_id}"
)
def start_conversation(self, conversation_id: Optional[str] = None) -> str:
"""
Start a new conversation or return existing ID.
Args:
conversation_id: Optional existing conversation ID
Returns:
Conversation ID (new or existing)
"""
with self._lock:
if conversation_id is None:
conversation_id = str(uuid.uuid4())
if conversation_id not in self.conversations:
self.conversations[conversation_id] = []
self.logger.debug(f"Started new conversation: {conversation_id}")
return conversation_id
def is_processing(self, conversation_id: str) -> bool:
"""
Check if conversation is currently being processed.
Args:
conversation_id: ID of conversation
Returns:
True if currently processing, False otherwise
"""
with self._lock:
return hasattr(self, "_processing_locks") and conversation_id in getattr(
self, "_processing_locks", {}
)
def set_processing(self, conversation_id: str, processing: bool) -> None:
"""
Set processing lock for conversation.
Args:
conversation_id: ID of conversation
processing: Processing state
"""
with self._lock:
if not hasattr(self, "_processing_locks"):
self._processing_locks = {}
self._processing_locks[conversation_id] = processing
def get_conversation_turns(self, conversation_id: str) -> List[ConversationTurn]:
"""
Get all turns for a conversation.
Args:
conversation_id: ID of conversation
Returns:
List of ConversationTurn objects
"""
with self._lock:
return self.conversations.get(conversation_id, [])
def delete_conversation(self, conversation_id: str) -> bool:
"""
Delete a conversation completely.
Args:
conversation_id: ID of conversation to delete
Returns:
True if conversation was deleted, False if not found
"""
with self._lock:
if conversation_id in self.conversations:
del self.conversations[conversation_id]
self.logger.info(f"Deleted conversation: {conversation_id}")
return True
return False
def list_conversations(self) -> List[str]:
"""
List all active conversation IDs.
Returns:
List of conversation IDs
"""
with self._lock:
return list(self.conversations.keys())
def get_conversation_stats(self, conversation_id: str) -> Dict[str, Any]:
"""
Get statistics for a specific conversation.
Args:
conversation_id: ID of conversation
Returns:
Dictionary with conversation statistics
"""
with self._lock:
turns = self.conversations.get(conversation_id, [])
if not turns:
return {
"turn_count": 0,
"total_tokens": 0,
"total_response_time": 0.0,
"average_response_time": 0.0,
"average_tokens": 0.0,
}
total_tokens = sum(turn.tokens_used for turn in turns)
total_response_time = sum(turn.response_time for turn in turns)
avg_response_time = total_response_time / len(turns)
avg_tokens = total_tokens / len(turns)
return {
"turn_count": len(turns),
"total_tokens": total_tokens,
"total_response_time": total_response_time,
"average_response_time": avg_response_time,
"average_tokens": avg_tokens,
"oldest_timestamp": min(turn.timestamp for turn in turns),
"newest_timestamp": max(turn.timestamp for turn in turns),
}
def cleanup_old_conversations(self, max_age_hours: float = 24.0) -> int:
"""
Clean up conversations older than specified age.
Args:
max_age_hours: Maximum age in hours before cleanup
Returns:
Number of conversations cleaned up
"""
with self._lock:
current_time = time.time()
cutoff_time = current_time - (max_age_hours * 3600)
conversations_to_remove = []
for conv_id, turns in self.conversations.items():
if turns and turns[-1].timestamp < cutoff_time:
conversations_to_remove.append(conv_id)
for conv_id in conversations_to_remove:
del self.conversations[conv_id]
if conversations_to_remove:
self.logger.info(f"Cleaned up {len(conversations_to_remove)} old conversations")
return len(conversations_to_remove)