Initial commit: Clean slate for Mai project
This commit is contained in:
386
src/mai/conversation/state.py
Normal file
386
src/mai/conversation/state.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Conversation State Management for Mai
|
||||
|
||||
Provides turn-by-turn conversation history with proper session isolation,
|
||||
interruption handling, and context window management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
# Import existing conversation models for consistency
|
||||
try:
|
||||
from ..models.conversation import Message, Conversation
|
||||
except ImportError:
|
||||
# Fallback if models not available yet
|
||||
Message = None
|
||||
Conversation = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationTurn:
|
||||
"""Single conversation turn with comprehensive metadata."""
|
||||
|
||||
conversation_id: str
|
||||
user_message: str
|
||||
ai_response: str
|
||||
timestamp: float
|
||||
model_used: str
|
||||
tokens_used: int
|
||||
response_time: float
|
||||
memory_context_applied: bool
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"conversation_id": self.conversation_id,
|
||||
"user_message": self.user_message,
|
||||
"ai_response": self.ai_response,
|
||||
"timestamp": self.timestamp,
|
||||
"model_used": self.model_used,
|
||||
"tokens_used": self.tokens_used,
|
||||
"response_time": self.response_time,
|
||||
"memory_context_applied": self.memory_context_applied,
|
||||
}
|
||||
|
||||
|
||||
class ConversationState:
|
||||
"""
|
||||
Manages conversation state across multiple sessions with proper isolation.
|
||||
|
||||
Provides turn-by-turn history tracking, automatic cleanup,
|
||||
thread-safe operations, and Ollama-compatible formatting.
|
||||
"""
|
||||
|
||||
def __init__(self, max_turns_per_conversation: int = 10):
|
||||
"""
|
||||
Initialize conversation state manager.
|
||||
|
||||
Args:
|
||||
max_turns_per_conversation: Maximum turns to keep per conversation
|
||||
"""
|
||||
self.conversations: Dict[str, List[ConversationTurn]] = {}
|
||||
self.max_turns = max_turns_per_conversation
|
||||
self._lock = threading.RLock() # Reentrant lock for nested calls
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
self.logger.info(
|
||||
f"ConversationState initialized with max {max_turns_per_conversation} turns per conversation"
|
||||
)
|
||||
|
||||
def add_turn(self, turn: ConversationTurn) -> None:
|
||||
"""
|
||||
Add a conversation turn with automatic timestamp and cleanup.
|
||||
|
||||
Args:
|
||||
turn: ConversationTurn to add
|
||||
"""
|
||||
with self._lock:
|
||||
conversation_id = turn.conversation_id
|
||||
|
||||
# Initialize conversation if doesn't exist
|
||||
if conversation_id not in self.conversations:
|
||||
self.conversations[conversation_id] = []
|
||||
self.logger.debug(f"Created new conversation: {conversation_id}")
|
||||
|
||||
# Add the turn
|
||||
self.conversations[conversation_id].append(turn)
|
||||
self.logger.debug(
|
||||
f"Added turn to conversation {conversation_id}: {turn.tokens_used} tokens, {turn.response_time:.2f}s"
|
||||
)
|
||||
|
||||
# Automatic cleanup: maintain last N turns
|
||||
if len(self.conversations[conversation_id]) > self.max_turns:
|
||||
# Remove oldest turns to maintain limit
|
||||
excess_count = len(self.conversations[conversation_id]) - self.max_turns
|
||||
removed_turns = self.conversations[conversation_id][:excess_count]
|
||||
self.conversations[conversation_id] = self.conversations[conversation_id][
|
||||
excess_count:
|
||||
]
|
||||
|
||||
self.logger.debug(
|
||||
f"Cleaned up {excess_count} old turns from conversation {conversation_id}"
|
||||
)
|
||||
|
||||
# Log removed turns for debugging
|
||||
for removed_turn in removed_turns:
|
||||
self.logger.debug(
|
||||
f"Removed turn: {removed_turn.timestamp} - {removed_turn.user_message[:50]}..."
|
||||
)
|
||||
|
||||
def get_history(self, conversation_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Get conversation history in Ollama-compatible format.
|
||||
|
||||
Args:
|
||||
conversation_id: ID of conversation to retrieve
|
||||
|
||||
Returns:
|
||||
List of message dictionaries formatted for Ollama API
|
||||
"""
|
||||
with self._lock:
|
||||
turns = self.conversations.get(conversation_id, [])
|
||||
|
||||
# Convert to Ollama format: alternating user/assistant roles
|
||||
history = []
|
||||
for turn in turns:
|
||||
history.append({"role": "user", "content": turn.user_message})
|
||||
history.append({"role": "assistant", "content": turn.ai_response})
|
||||
|
||||
self.logger.debug(
|
||||
f"Retrieved {len(history)} messages from conversation {conversation_id}"
|
||||
)
|
||||
return history
|
||||
|
||||
def set_conversation_history(
|
||||
self, messages: List[Dict[str, str]], conversation_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Restore conversation history from session storage.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries in Ollama format [{"role": "user/assistant", "content": "..."}]
|
||||
conversation_id: Optional conversation ID to restore to (creates new if None)
|
||||
"""
|
||||
with self._lock:
|
||||
if conversation_id is None:
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
# Clear existing conversation for this ID
|
||||
self.conversations[conversation_id] = []
|
||||
|
||||
# Convert messages back to ConversationTurn objects
|
||||
# Messages should be in pairs: user, assistant, user, assistant, ...
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
# Expect user message first
|
||||
if i >= len(messages) or messages[i].get("role") != "user":
|
||||
self.logger.warning(f"Expected user message at index {i}, skipping")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
user_message = messages[i].get("content", "")
|
||||
i += 1
|
||||
|
||||
# Expect assistant message next
|
||||
if i >= len(messages) or messages[i].get("role") != "assistant":
|
||||
self.logger.warning(f"Expected assistant message at index {i}, skipping")
|
||||
continue
|
||||
|
||||
ai_response = messages[i].get("content", "")
|
||||
i += 1
|
||||
|
||||
# Create ConversationTurn with estimated metadata
|
||||
turn = ConversationTurn(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
ai_response=ai_response,
|
||||
timestamp=time.time(), # Use current time as approximation
|
||||
model_used="restored", # Indicate this is from restoration
|
||||
tokens_used=0, # Token count not available from session
|
||||
response_time=0.0, # Response time not available from session
|
||||
memory_context_applied=False, # Memory context not tracked in session
|
||||
)
|
||||
|
||||
self.conversations[conversation_id].append(turn)
|
||||
|
||||
self.logger.info(
|
||||
f"Restored {len(self.conversations[conversation_id])} turns to conversation {conversation_id}"
|
||||
)
|
||||
|
||||
def get_last_n_turns(self, conversation_id: str, n: int = 5) -> List[ConversationTurn]:
|
||||
"""
|
||||
Get the last N turns from a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: ID of conversation
|
||||
n: Number of recent turns to retrieve
|
||||
|
||||
Returns:
|
||||
List of last N ConversationTurn objects
|
||||
"""
|
||||
with self._lock:
|
||||
turns = self.conversations.get(conversation_id, [])
|
||||
return turns[-n:] if n > 0 else []
|
||||
|
||||
def clear_pending_response(self, conversation_id: str) -> None:
|
||||
"""
|
||||
Clear any pending response for interruption handling.
|
||||
|
||||
Args:
|
||||
conversation_id: ID of conversation to clear
|
||||
"""
|
||||
with self._lock:
|
||||
if conversation_id in self.conversations:
|
||||
# Find and remove incomplete turns (those without AI response)
|
||||
original_count = len(self.conversations[conversation_id])
|
||||
self.conversations[conversation_id] = [
|
||||
turn
|
||||
for turn in self.conversations[conversation_id]
|
||||
if turn.ai_response.strip() # Must have AI response
|
||||
]
|
||||
|
||||
removed_count = original_count - len(self.conversations[conversation_id])
|
||||
if removed_count > 0:
|
||||
self.logger.info(
|
||||
f"Cleared {removed_count} incomplete turns from conversation {conversation_id}"
|
||||
)
|
||||
|
||||
def start_conversation(self, conversation_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Start a new conversation or return existing ID.
|
||||
|
||||
Args:
|
||||
conversation_id: Optional existing conversation ID
|
||||
|
||||
Returns:
|
||||
Conversation ID (new or existing)
|
||||
"""
|
||||
with self._lock:
|
||||
if conversation_id is None:
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
if conversation_id not in self.conversations:
|
||||
self.conversations[conversation_id] = []
|
||||
self.logger.debug(f"Started new conversation: {conversation_id}")
|
||||
|
||||
return conversation_id
|
||||
|
||||
def is_processing(self, conversation_id: str) -> bool:
|
||||
"""
|
||||
Check if conversation is currently being processed.
|
||||
|
||||
Args:
|
||||
conversation_id: ID of conversation
|
||||
|
||||
Returns:
|
||||
True if currently processing, False otherwise
|
||||
"""
|
||||
with self._lock:
|
||||
return hasattr(self, "_processing_locks") and conversation_id in getattr(
|
||||
self, "_processing_locks", {}
|
||||
)
|
||||
|
||||
def set_processing(self, conversation_id: str, processing: bool) -> None:
|
||||
"""
|
||||
Set processing lock for conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: ID of conversation
|
||||
processing: Processing state
|
||||
"""
|
||||
with self._lock:
|
||||
if not hasattr(self, "_processing_locks"):
|
||||
self._processing_locks = {}
|
||||
self._processing_locks[conversation_id] = processing
|
||||
|
||||
def get_conversation_turns(self, conversation_id: str) -> List[ConversationTurn]:
|
||||
"""
|
||||
Get all turns for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: ID of conversation
|
||||
|
||||
Returns:
|
||||
List of ConversationTurn objects
|
||||
"""
|
||||
with self._lock:
|
||||
return self.conversations.get(conversation_id, [])
|
||||
|
||||
def delete_conversation(self, conversation_id: str) -> bool:
|
||||
"""
|
||||
Delete a conversation completely.
|
||||
|
||||
Args:
|
||||
conversation_id: ID of conversation to delete
|
||||
|
||||
Returns:
|
||||
True if conversation was deleted, False if not found
|
||||
"""
|
||||
with self._lock:
|
||||
if conversation_id in self.conversations:
|
||||
del self.conversations[conversation_id]
|
||||
self.logger.info(f"Deleted conversation: {conversation_id}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_conversations(self) -> List[str]:
|
||||
"""
|
||||
List all active conversation IDs.
|
||||
|
||||
Returns:
|
||||
List of conversation IDs
|
||||
"""
|
||||
with self._lock:
|
||||
return list(self.conversations.keys())
|
||||
|
||||
def get_conversation_stats(self, conversation_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get statistics for a specific conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: ID of conversation
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation statistics
|
||||
"""
|
||||
with self._lock:
|
||||
turns = self.conversations.get(conversation_id, [])
|
||||
|
||||
if not turns:
|
||||
return {
|
||||
"turn_count": 0,
|
||||
"total_tokens": 0,
|
||||
"total_response_time": 0.0,
|
||||
"average_response_time": 0.0,
|
||||
"average_tokens": 0.0,
|
||||
}
|
||||
|
||||
total_tokens = sum(turn.tokens_used for turn in turns)
|
||||
total_response_time = sum(turn.response_time for turn in turns)
|
||||
avg_response_time = total_response_time / len(turns)
|
||||
avg_tokens = total_tokens / len(turns)
|
||||
|
||||
return {
|
||||
"turn_count": len(turns),
|
||||
"total_tokens": total_tokens,
|
||||
"total_response_time": total_response_time,
|
||||
"average_response_time": avg_response_time,
|
||||
"average_tokens": avg_tokens,
|
||||
"oldest_timestamp": min(turn.timestamp for turn in turns),
|
||||
"newest_timestamp": max(turn.timestamp for turn in turns),
|
||||
}
|
||||
|
||||
def cleanup_old_conversations(self, max_age_hours: float = 24.0) -> int:
|
||||
"""
|
||||
Clean up conversations older than specified age.
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age in hours before cleanup
|
||||
|
||||
Returns:
|
||||
Number of conversations cleaned up
|
||||
"""
|
||||
with self._lock:
|
||||
current_time = time.time()
|
||||
cutoff_time = current_time - (max_age_hours * 3600)
|
||||
|
||||
conversations_to_remove = []
|
||||
for conv_id, turns in self.conversations.items():
|
||||
if turns and turns[-1].timestamp < cutoff_time:
|
||||
conversations_to_remove.append(conv_id)
|
||||
|
||||
for conv_id in conversations_to_remove:
|
||||
del self.conversations[conv_id]
|
||||
|
||||
if conversations_to_remove:
|
||||
self.logger.info(f"Cleaned up {len(conversations_to_remove)} old conversations")
|
||||
|
||||
return len(conversations_to_remove)
|
||||
Reference in New Issue
Block a user