Complete fresh slate

This commit is contained in:
Mai Development
2026-01-26 22:43:41 -05:00
parent 7c98aec306
commit c70ee8816e
64 changed files with 0 additions and 28116 deletions

View File

@@ -1,11 +0,0 @@
root = true
[*]
end_of_line = lf
insert_final_newline = true
charset = utf-8
trim_trailing_whitespace = true
[*.py]
indent_style = space
indent_size = 4

View File

@@ -1,171 +0,0 @@
# Memory System Configuration for Mai
# Compression settings
compression:
# Triggers for automatic compression
thresholds:
message_count: 50 # Compress after 50 messages
age_days: 30 # Compress conversations older than 30 days
memory_limit_mb: 500 # Compress when memory usage exceeds 500MB
# AI summarization configuration
summarization:
model: "llama2" # Model to use for summarization
preserve_elements: # Elements to preserve in compression
- "preferences" # User preferences and choices
- "decisions" # Important decisions made
- "patterns" # Interaction and topic patterns
- "key_facts" # Critical information and facts
min_quality_score: 0.7 # Minimum acceptable summary quality
max_summary_length: 1000 # Maximum summary length in characters
context_messages: 30 # Messages to include for context
# Adaptive weighting parameters
adaptive_weighting:
importance_decay_days: 90 # Days for importance decay
pattern_weight: 1.5 # Weight for pattern preservation
technical_weight: 1.2 # Weight for technical conversations
planning_weight: 1.3 # Weight for planning conversations
recency_boost: 1.2 # Boost for recent messages
keyword_boost: 1.5 # Boost for preference keywords
# Compression strategy settings
strategy:
keep_recent_count: 10 # Recent messages to always keep
max_patterns_extracted: 5 # Maximum patterns to extract
topic_extraction_method: "keyword" # Method for topic extraction
pattern_confidence_threshold: 0.6 # Minimum confidence for pattern extraction
# Context retrieval settings
retrieval:
# Search configuration
search:
similarity_threshold: 0.7 # Minimum similarity for semantic search
max_results: 5 # Maximum search results to return
include_content: false # Include full content in results
# Multi-faceted search weights
weights:
semantic_similarity: 0.4 # Weight for semantic similarity
keyword_match: 0.3 # Weight for keyword matching
recency_weight: 0.2 # Weight for recency
user_pattern_weight: 0.1 # Weight for user patterns
# Adaptive search settings
adaptive:
conversation_type_detection: true # Automatically detect conversation type
weight_adjustment: true # Adjust weights based on context
context_window_limit: 2000 # Token limit for context retrieval
# Performance tuning
performance:
cache_search_results: true # Cache frequent searches
cache_ttl_seconds: 300 # Cache time-to-live in seconds
parallel_search: false # Enable parallel search (experimental)
max_search_time_ms: 1000 # Maximum search time in milliseconds
# Pattern extraction settings
patterns:
# Granularity levels
extraction_granularity:
fine: # Detailed extraction for important conversations
message_sample_size: 50
pattern_confidence: 0.8
medium: # Standard extraction
message_sample_size: 30
pattern_confidence: 0.7
coarse: # Broad extraction for old conversations
message_sample_size: 20
pattern_confidence: 0.6
# Pattern types to extract
types:
user_preferences:
enabled: true
keywords:
- "prefer"
- "like"
- "want"
- "should"
- "don't like"
- "avoid"
confidence_threshold: 0.7
interaction_patterns:
enabled: true
metrics:
- "message_length_ratio"
- "response_time_pattern"
- "question_frequency"
- "clarification_requests"
topic_preferences:
enabled: true
max_topics: 10
min_topic_frequency: 3
emotional_patterns:
enabled: false # Future enhancement
sentiment_analysis: false
decision_patterns:
enabled: true
decision_keywords:
- "decided"
- "chose"
- "selected"
- "agreed"
- "rejected"
# Memory management settings
management:
# Storage limits and cleanup
storage:
max_conversation_age_days: 365 # Maximum age before review
auto_cleanup: false # Enable automatic cleanup
backup_before_cleanup: true # Backup before cleanup
# User control settings
user_control:
allow_conversation_deletion: true # Allow users to delete conversations
grace_period_days: 7 # Recovery grace period
bulk_operations: true # Allow bulk operations
# Privacy settings
privacy:
anonymize_patterns: false # Anonymize extracted patterns
pattern_retention_days: 180 # How long to keep patterns
encrypt_sensitive_topics: true # Encrypt sensitive topic patterns
# Performance and monitoring
performance:
# Resource limits
resources:
max_memory_usage_mb: 200 # Maximum memory for compression
max_cpu_usage_percent: 80 # Maximum CPU usage
max_compression_time_seconds: 30 # Maximum time per compression
# Background processing
background:
enable_background_compression: true # Run compression in background
compression_interval_hours: 6 # Check interval for compression
batch_size: 5 # Conversations per batch
# Monitoring and metrics
monitoring:
track_compression_stats: true # Track compression statistics
log_compression_events: true # Log compression operations
performance_metrics_retention_days: 30 # How long to keep metrics
# Development and debugging
debug:
# Debug settings
enabled: false # Enable debug mode
log_compression_details: false # Log detailed compression info
save_intermediate_results: false # Save intermediate compression results
# Testing settings
testing:
mock_summarization: false # Use mock summarization for testing
force_compression_threshold: false # Force compression for testing
disable_pattern_extraction: false # Disable pattern extraction for testing

View File

@@ -1,74 +0,0 @@
# Mai Sandbox Configuration
#
# This file contains all sandbox-related settings for safe code execution
# Resource Limits
resource_limits:
cpu_percent: 70 # Maximum CPU usage percentage
memory_percent: 70 # Maximum memory usage percentage
timeout_seconds: 30 # Maximum execution time in seconds
bandwidth_mbps: 50 # Maximum network bandwidth in MB/s
max_processes: 10 # Maximum number of processes
# Approval Settings
approval:
auto_approve_low_risk: true # Automatically approve low-risk operations
require_approval_high_risk: true # Always require approval for high-risk operations
remember_preferences: true # Remember user preferences for similar operations
batch_approval: true # Allow batch approval for similar operations
session_timeout: 3600 # Session timeout in seconds (1 hour)
# Risk Thresholds
risk_thresholds:
low_threshold: 0.3 # Below this is low risk
medium_threshold: 0.6 # Below this is medium risk
high_threshold: 0.8 # Below this is high risk, above is critical
# Docker Settings
docker:
image_name: "python:3.11-slim" # Docker image for code execution
network_access: false # Allow network access in sandbox
mount_points: [] # Additional mount points (empty = no mounts)
volume_size: "1G" # Maximum volume size
temp_dir: "/tmp/mai_sandbox" # Temporary directory inside container
user: "nobody" # User to run as inside container
# Audit Logging
audit:
log_level: "INFO" # Log level (DEBUG, INFO, WARNING, ERROR)
retention_days: 30 # How many days to keep logs
mask_sensitive_data: true # Mask potentially sensitive data in logs
log_file_path: ".mai/logs/audit.log" # Path to audit log file
max_log_size_mb: 100 # Maximum log file size before rotation
enable_tamper_detection: true # Enable log tamper detection
# Security Settings
security:
blocked_patterns: # Regex patterns for blocked operations
- "rm\\s+-rf\\s+/" # Dangerous delete commands
- "dd\\s+if=" # Disk imaging commands
- "format\\s+" # Disk formatting
- "fdisk" # Disk partitioning
- "mkfs" # Filesystem creation
- "chmod\\s+777" # Dangerous permission changes
quarantine_unknown: true # Quarantine unknown file types
scan_for_malware: false # Scan for malware (requires external tools)
enforce_path_restrictions: true # Restrict file system access
# Performance Settings
performance:
enable_caching: true # Enable execution result caching
cache_size_mb: 100 # Maximum cache size
enable_parallel: false # Enable parallel execution (not recommended)
max_concurrent: 1 # Maximum concurrent executions
# User Preferences (auto-populated)
user_preferences:
# Automatically populated based on user choices
# Format: operation_type: preference
# Trust Patterns (learned)
trust_patterns:
# Automatically populated based on approval history
# Format: operation_type: approval_count

View File

@@ -1,7 +0,0 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
hooks:
- id: ruff
args: ["--fix"]
- id: ruff-format

Binary file not shown.

7123
mai.log

File diff suppressed because it is too large Load Diff

View File

@@ -1,35 +0,0 @@
[project]
name = "Mai"
version = "0.1.0"
description = "GSD-native Python template"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"ollama>=0.6.1",
"psutil>=6.0.0",
"GitPython>=3.1.46",
"tiktoken>=0.8.0",
"docker>=6.0.0",
"sqlite-vec>=0.1.6",
"sentence-transformers>=3.0.0",
"blessed>=1.27.0",
"rich>=13.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0",
"ruff>=0.6",
"pre-commit>=3.0",
]
[tool.ruff]
line-length = 100
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "I", "B", "UP"]
ignore = []
[tool.pytest.ini_options]
testpaths = ["tests"]

View File

@@ -1,9 +0,0 @@
python -m venv .venv
.\.venv\Scripts\Activate.ps1
python -m pip install --upgrade pip
python -m pip install -e ".[dev]"
pre-commit install
Write-Host "✅ Bootstrapped (.venv created, dev deps installed)"

View File

@@ -1,15 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
PY=python
command -v python >/dev/null 2>&1 || PY=python3
$PY -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip
python -m pip install -e ".[dev]"
pre-commit install || true
echo "✅ Bootstrapped (.venv created, dev deps installed)"

View File

@@ -1,7 +0,0 @@
.\.venv\Scripts\Activate.ps1
ruff check .
ruff format --check .
pytest -q
Write-Host "✅ Checks passed"

View File

@@ -1,10 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
source .venv/bin/activate
ruff check .
ruff format --check .
pytest -q
echo "✅ Checks passed"

View File

@@ -1,86 +0,0 @@
#!/usr/bin/env python3
from __future__ import annotations
import subprocess
from pathlib import Path
from datetime import datetime
ROOT = Path(".").resolve()
OUT = ROOT / ".planning" / "CONTEXTPACK.md"
IGNORE_DIRS = {
".git",
".venv",
"venv",
"__pycache__",
".pytest_cache",
".ruff_cache",
"dist",
"build",
"node_modules",
}
KEY_FILES = [
"CLAUDE.md",
"PROJECT.md",
"REQUIREMENTS.md",
"ROADMAP.md",
"STATE.md",
"pyproject.toml",
".pre-commit-config.yaml",
]
def run(cmd: list[str]) -> str:
try:
return subprocess.check_output(cmd, cwd=ROOT, stderr=subprocess.STDOUT, text=True).strip()
except Exception as e:
return f"(failed: {' '.join(cmd)}): {e}"
def tree(max_depth: int = 3) -> str:
lines: list[str] = []
def walk(path: Path, depth: int) -> None:
if depth > max_depth:
return
for p in sorted(path.iterdir(), key=lambda x: (x.is_file(), x.name.lower())):
if p.name in IGNORE_DIRS:
continue
rel = p.relative_to(ROOT)
indent = " " * depth
if p.is_dir():
lines.append(f"{indent}📁 {rel}/")
walk(p, depth + 1)
else:
lines.append(f"{indent}📄 {rel}")
walk(ROOT, 0)
return "\n".join(lines)
def head(path: Path, n: int = 160) -> str:
try:
return "\n".join(path.read_text(encoding="utf-8", errors="replace").splitlines()[:n])
except Exception as e:
return f"(failed reading {path}): {e}"
def main() -> None:
OUT.parent.mkdir(parents=True, exist_ok=True)
parts: list[str] = []
parts.append("# Context Pack")
parts.append(f"_Generated: {datetime.now().isoformat(timespec='seconds')}_\n")
parts.append("## Repo tree\n```text\n" + tree() + "\n```")
parts.append("## Git status\n```text\n" + run(["git", "status"]) + "\n```")
parts.append("## Recent commits\n```text\n" + run(["git", "--no-pager", "log", "-10", "--oneline"]) + "\n```")
parts.append("## Key files (head)")
for f in KEY_FILES:
p = ROOT / f
if p.exists():
parts.append(f"### {f}\n```text\n{head(p)}\n```")
OUT.write_text("\n\n".join(parts) + "\n", encoding="utf-8")
print(f"✅ Wrote {OUT.relative_to(ROOT)}")
if __name__ == "__main__":
main()

View File

@@ -1,17 +0,0 @@
Metadata-Version: 2.4
Name: Mai
Version: 0.1.0
Summary: GSD-native Python template
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: ollama>=0.6.1
Requires-Dist: psutil>=6.0.0
Requires-Dist: GitPython>=3.1.46
Requires-Dist: tiktoken>=0.8.0
Requires-Dist: docker>=6.0.0
Requires-Dist: sqlite-vec>=0.1.6
Requires-Dist: sentence-transformers>=3.0.0
Provides-Extra: dev
Requires-Dist: pytest>=8.0; extra == "dev"
Requires-Dist: ruff>=0.6; extra == "dev"
Requires-Dist: pre-commit>=3.0; extra == "dev"

View File

@@ -1,42 +0,0 @@
pyproject.toml
src/Mai.egg-info/PKG-INFO
src/Mai.egg-info/SOURCES.txt
src/Mai.egg-info/dependency_links.txt
src/Mai.egg-info/requires.txt
src/Mai.egg-info/top_level.txt
src/app/__init__.py
src/app/__main__.py
src/mai/core/__init__.py
src/mai/core/config.py
src/mai/core/exceptions.py
src/mai/core/interface.py
src/mai/git/__init__.py
src/mai/git/committer.py
src/mai/git/health_check.py
src/mai/git/workflow.py
src/mai/memory/__init__.py
src/mai/memory/compression.py
src/mai/memory/manager.py
src/mai/memory/retrieval.py
src/mai/memory/storage.py
src/mai/model/__init__.py
src/mai/model/compression.py
src/mai/model/ollama_client.py
src/mai/model/resource_detector.py
src/mai/model/switcher.py
src/mai/models/__init__.py
src/mai/models/conversation.py
src/mai/models/memory.py
src/mai/sandbox/__init__.py
src/mai/sandbox/approval_system.py
src/mai/sandbox/audit_logger.py
src/mai/sandbox/docker_executor.py
src/mai/sandbox/manager.py
src/mai/sandbox/resource_enforcer.py
src/mai/sandbox/risk_analyzer.py
tests/test_docker_executor.py
tests/test_docker_integration.py
tests/test_integration.py
tests/test_sandbox_approval.py
tests/test_sandbox_docker_integration.py
tests/test_smoke.py

View File

@@ -1 +0,0 @@

View File

@@ -1,12 +0,0 @@
ollama>=0.6.1
psutil>=6.0.0
GitPython>=3.1.46
tiktoken>=0.8.0
docker>=6.0.0
sqlite-vec>=0.1.6
sentence-transformers>=3.0.0
[dev]
pytest>=8.0
ruff>=0.6
pre-commit>=3.0

View File

@@ -1,2 +0,0 @@
app
mai

View File

@@ -1 +0,0 @@
__all__ = []

File diff suppressed because it is too large Load Diff

View File

@@ -1,2 +0,0 @@
19:49:18 - mai.model.ollama_client - INFO - Ollama client initialized for http://localhost:11434
19:49:18 - git.util - DEBUG - sys.platform='linux', git_executable='git'

View File

@@ -1,20 +0,0 @@
"""
Conversation Engine Module for Mai
This module provides a core conversation engine that orchestrates
multi-turn conversations with memory integration and natural timing.
"""
from .engine import ConversationEngine
from .state import ConversationState
from .timing import TimingCalculator
from .reasoning import ReasoningEngine
from .decomposition import RequestDecomposer
__all__ = [
"ConversationEngine",
"ConversationState",
"TimingCalculator",
"ReasoningEngine",
"RequestDecomposer",
]

View File

@@ -1,458 +0,0 @@
"""
Request Decomposition and Clarification Engine for Mai
Analyzes request complexity and generates appropriate clarifying questions
when user requests are ambiguous or overly complex.
"""
import logging
import re
from typing import Dict, List, Optional, Any, Tuple
logger = logging.getLogger(__name__)
class RequestDecomposer:
"""
Analyzes request complexity and generates clarifying questions.
This engine identifies ambiguous requests, assesses complexity,
and generates specific clarifying questions to improve understanding.
"""
def __init__(self):
"""Initialize request decomposer with analysis patterns."""
self.logger = logging.getLogger(__name__)
# Ambiguity patterns to detect
self._ambiguity_patterns = {
"pronouns_without_antecedents": [
r"\b(it|that|this|they|them|these|those)\b",
r"\b(he|she|it)\s+(?:is|was|were|will|would|could|should)",
],
"vague_quantifiers": [
r"\b(some|few|many|several|multiple|various|better|faster|more|less)\b",
r"\b(a bit|a little|quite|very|really|somewhat)\b",
],
"missing_context": [
r"\b(the|that|this|there)\s+(?:here|there)",
r"\b(?:from|about|regarding|concerning)\s+(?:it|that|this)",
],
"undefined_references": [
r"\b(?:fix|improve|update|change|modify)\s+(?:it|that|this)",
r"\b(?:do|make|create|build)\s+(?:it|that|this)",
],
}
# Complexity indicators
self._complexity_indicators = {
"technical_keywords": [
"function",
"algorithm",
"database",
"api",
"class",
"method",
"variable",
"loop",
"conditional",
"recursion",
"optimization",
"debug",
"implement",
"integrate",
"configure",
"deploy",
],
"multiple_tasks": [
r"\band\b",
r"\bthen\b",
r"\bafter\b",
r"\balso\b",
r"\bnext\b",
r"\bfinally\b",
r"\badditionally\b",
],
"question_density": r"[?]",
"length_threshold": 150, # characters
}
self.logger.info("RequestDecomposer initialized")
def analyze_request(self, message: str) -> Dict[str, Any]:
"""
Analyze request for complexity and ambiguity.
Args:
message: User message to analyze
Returns:
Dictionary with analysis results including:
- needs_clarification: boolean
- complexity_score: float (0-1)
- estimated_steps: int
- clarification_questions: list
- ambiguity_indicators: list
"""
message_lower = message.lower().strip()
# Detect ambiguities
ambiguity_indicators = self._detect_ambiguities(message_lower)
needs_clarification = len(ambiguity_indicators) > 0
# Calculate complexity score
complexity_score = self._calculate_complexity(message)
# Estimate steps needed
estimated_steps = self._estimate_steps(message, complexity_score)
# Generate clarification questions
clarification_questions = []
if needs_clarification:
clarification_questions = self._generate_clarifications(message, ambiguity_indicators)
return {
"needs_clarification": needs_clarification,
"complexity_score": complexity_score,
"estimated_steps": estimated_steps,
"clarification_questions": clarification_questions,
"ambiguity_indicators": ambiguity_indicators,
"message_length": len(message),
"word_count": len(message.split()),
}
def _detect_ambiguities(self, message: str) -> List[Dict[str, Any]]:
"""
Detect specific ambiguity indicators in the message.
Args:
message: Lowercase message to analyze
Returns:
List of ambiguity indicators with details
"""
ambiguities = []
for category, patterns in self._ambiguity_patterns.items():
for pattern in patterns:
matches = re.finditer(pattern, message, re.IGNORECASE)
for match in matches:
ambiguities.append(
{
"type": category,
"pattern": pattern,
"match": match.group(),
"position": match.start(),
"context": self._get_context(message, match.start(), match.end()),
}
)
return ambiguities
def _get_context(self, message: str, start: int, end: int, window: int = 20) -> str:
"""Get context around a match."""
context_start = max(0, start - window)
context_end = min(len(message), end + window)
return message[context_start:context_end]
def _calculate_complexity(self, message: str) -> float:
"""
Calculate complexity score based on multiple factors.
Args:
message: Message to analyze
Returns:
Complexity score between 0.0 (simple) and 1.0 (complex)
"""
complexity = 0.0
# Technical content (0.3 weight)
technical_count = sum(
1
for keyword in self._complexity_indicators["technical_keywords"]
if keyword.lower() in message.lower()
)
technical_score = min(technical_count * 0.1, 0.3)
complexity += technical_score
# Multiple tasks (0.25 weight)
task_matches = 0
for pattern in self._complexity_indicators["multiple_tasks"]:
matches = len(re.findall(pattern, message, re.IGNORECASE))
task_matches += matches
task_score = min(task_matches * 0.08, 0.25)
complexity += task_score
# Question density (0.2 weight)
question_count = len(re.findall(self._complexity_indicators["question_density"], message))
question_score = min(question_count * 0.05, 0.2)
complexity += question_score
# Message length (0.15 weight)
length_score = min(len(message) / 500, 0.15)
complexity += length_score
# Sentence complexity (0.1 weight)
sentences = message.split(".")
avg_sentence_length = sum(len(s.strip()) for s in sentences if s.strip()) / max(
len(sentences), 1
)
sentence_score = min(avg_sentence_length / 100, 0.1)
complexity += sentence_score
return min(complexity, 1.0)
def _estimate_steps(self, message: str, complexity_score: float) -> int:
"""
Estimate number of steps needed to fulfill request.
Args:
message: Original message
complexity_score: Calculated complexity score
Returns:
Estimated number of steps
"""
base_steps = 1
# Add steps for multiple tasks
task_count = 0
for pattern in self._complexity_indicators["multiple_tasks"]:
matches = len(re.findall(pattern, message, re.IGNORECASE))
task_count += matches
base_steps += max(0, task_count - 1) # First task is step 1
# Add steps for complexity
if complexity_score > 0.7:
base_steps += 3 # Complex requests need planning
elif complexity_score > 0.5:
base_steps += 2 # Medium complexity needs some breakdown
elif complexity_score > 0.3:
base_steps += 1 # Slightly complex might need clarification
return max(1, base_steps)
def _generate_clarifications(
self, message: str, ambiguity_indicators: List[Dict[str, Any]]
) -> List[str]:
"""
Generate specific clarifying questions for detected ambiguities.
Args:
message: Original message
ambiguity_indicators: List of detected ambiguities
Returns:
List of clarifying questions
"""
questions = []
seen_types = set()
for indicator in ambiguity_indicators:
ambiguity_type = indicator["type"]
match = indicator["match"]
# Avoid duplicate questions for same ambiguity type
if ambiguity_type in seen_types:
continue
seen_types.add(ambiguity_type)
if ambiguity_type == "pronouns_without_antecedents":
if match.lower() in ["it", "that", "this"]:
questions.append(f"Could you clarify what '{match}' refers to specifically?")
elif match.lower() in ["they", "them", "these", "those"]:
questions.append(f"Could you specify who or what '{match}' refers to?")
elif ambiguity_type == "vague_quantifiers":
if match.lower() in ["better", "faster", "more", "less"]:
questions.append(f"Could you quantify what '{match}' means in this context?")
elif match.lower() in ["some", "few", "many", "several"]:
questions.append(
f"Could you provide a specific number or amount instead of '{match}'?"
)
else:
questions.append(f"Could you be more specific about what '{match}' means?")
elif ambiguity_type == "missing_context":
questions.append(f"Could you provide more context about what '{match}' refers to?")
elif ambiguity_type == "undefined_references":
questions.append(f"Could you clarify what you'd like me to {match} specifically?")
return questions
def suggest_breakdown(
self,
message: str,
complexity_score: float,
ollama_client=None,
current_model: str = "default",
) -> Dict[str, Any]:
"""
Suggest logical breakdown for complex requests.
Args:
message: Original user message
complexity_score: Calculated complexity
ollama_client: Optional OllamaClient for semantic analysis
current_model: Current model name
Returns:
Dictionary with breakdown suggestions
"""
estimated_steps = self._estimate_steps(message, complexity_score)
# Extract potential tasks from message
tasks = self._extract_tasks(message)
breakdown = {
"estimated_steps": estimated_steps,
"complexity_level": self._get_complexity_level(complexity_score),
"suggested_approach": [],
"potential_tasks": tasks,
"effort_estimate": self._estimate_effort(complexity_score),
}
# Generate approach suggestions
if complexity_score > 0.6:
breakdown["suggested_approach"].append(
"Start by clarifying requirements and breaking into smaller tasks"
)
breakdown["suggested_approach"].append(
"Consider if this needs to be done in sequence or can be parallelized"
)
elif complexity_score > 0.3:
breakdown["suggested_approach"].append(
"Break down into logical sub-tasks before starting"
)
# Use semantic analysis if available and request is very complex
if ollama_client and complexity_score > 0.7:
try:
semantic_breakdown = self._semantic_breakdown(message, ollama_client, current_model)
breakdown["semantic_analysis"] = semantic_breakdown
except Exception as e:
self.logger.warning(f"Semantic breakdown failed: {e}")
return breakdown
def _extract_tasks(self, message: str) -> List[str]:
"""Extract potential tasks from message."""
# Simple task extraction based on verbs and patterns
task_patterns = [
r"(?:please\s+)?(?:can\s+you\s+)?(\w+)\s+(.+?)(?:\s+(?:and|then|after)\s+|$)",
r"(?:I\s+need|want)\s+(?:you\s+to\s+)?(.+?)(?:\s+(?:and|then|after)\s+|$)",
r"(?:help\s+me\s+)?(\w+)\s+(.+?)(?:\s+(?:and|then|after)\s+|$)",
]
tasks = []
for pattern in task_patterns:
matches = re.findall(pattern, message, re.IGNORECASE)
for match in matches:
if isinstance(match, tuple):
# Take the verb/object combination
task = " ".join(filter(None, match))
else:
task = str(match)
if len(task.strip()) > 3: # Filter out very short matches
tasks.append(task.strip())
return list(set(tasks)) # Remove duplicates
def _get_complexity_level(self, score: float) -> str:
"""Convert complexity score to human-readable level."""
if score >= 0.7:
return "High"
elif score >= 0.4:
return "Medium"
else:
return "Low"
def _estimate_effort(self, complexity_score: float) -> str:
"""Estimate effort based on complexity."""
if complexity_score >= 0.7:
return "Significant - may require multiple iterations"
elif complexity_score >= 0.4:
return "Moderate - should be straightforward with some planning"
else:
return "Minimal - should be quick to implement"
def _semantic_breakdown(self, message: str, ollama_client, current_model: str) -> str:
"""
Use AI to perform semantic breakdown of complex request.
Args:
message: User message to analyze
ollama_client: OllamaClient instance
current_model: Current model name
Returns:
AI-generated breakdown suggestions
"""
semantic_prompt = f"""
Analyze this complex request and suggest a logical breakdown: "{message}"
Provide a structured approach:
1. Identify the main objectives
2. Break down into logical steps
3. Note any dependencies or prerequisites
4. Suggest an order of execution
Keep it concise and actionable.
"""
try:
response = ollama_client.generate_response(semantic_prompt, current_model, [])
return self._clean_semantic_output(response)
except Exception as e:
self.logger.error(f"Semantic breakdown failed: {e}")
return "Unable to generate semantic breakdown"
def _clean_semantic_output(self, output: str) -> str:
"""Clean semantic breakdown output."""
# Remove common AI response prefixes
prefixes_to_remove = [
"Here's a breakdown:",
"Let me break this down:",
"I would approach this by:",
"Here's how I would break this down:",
]
for prefix in prefixes_to_remove:
if output.startswith(prefix):
output = output[len(prefix) :].strip()
break
return output
def get_analysis_summary(self, analysis: Dict[str, Any]) -> str:
"""
Get human-readable summary of request analysis.
Args:
analysis: Result from analyze_request()
Returns:
Formatted summary string
"""
summary_parts = []
if analysis["needs_clarification"]:
summary_parts.append("🤔 **Needs Clarification**")
summary_parts.append(f"- Questions: {len(analysis['clarification_questions'])}")
else:
summary_parts.append("✅ **Clear Request**")
complexity_level = self._get_complexity_level(analysis["complexity_score"])
summary_parts.append(
f"📊 **Complexity**: {complexity_level} ({analysis['complexity_score']:.2f})"
)
summary_parts.append(f"📋 **Estimated Steps**: {analysis['estimated_steps']}")
if analysis["ambiguity_indicators"]:
summary_parts.append(
f"⚠️ **Ambiguities Found**: {len(analysis['ambiguity_indicators'])}"
)
return "\n".join(summary_parts)

View File

@@ -1,648 +0,0 @@
"""
Core Conversation Engine for Mai
This module provides the main conversation engine that orchestrates
multi-turn conversations with memory integration and natural timing.
"""
import logging
import time
import uuid
from typing import Dict, List, Optional, Any, Tuple
from threading import Thread, Event
from dataclasses import dataclass
from ..core.interface import MaiInterface
from ..memory.manager import MemoryManager
from ..models.conversation import Conversation as ModelConversation, Message
from .state import ConversationState, ConversationTurn
from .timing import TimingCalculator
from .reasoning import ReasoningEngine
from .decomposition import RequestDecomposer
from .interruption import InterruptHandler, TurnType
logger = logging.getLogger(__name__)
@dataclass
class ConversationResponse:
"""Response from conversation processing with metadata."""
response: str
model_used: str
tokens_used: int
response_time: float
memory_context_used: int
timing_category: str
conversation_id: str
interruption_handled: bool = False
memory_integrated: bool = False
class ConversationEngine:
"""
Main conversation engine orchestrating multi-turn conversations.
Integrates memory context retrieval, natural timing calculation,
reasoning transparency, request decomposition, interruption handling,
personality consistency, and conversation state management.
"""
def __init__(
self,
mai_interface: Optional[MaiInterface] = None,
memory_manager: Optional[MemoryManager] = None,
timing_profile: str = "default",
debug_mode: bool = False,
enable_metrics: bool = True,
):
"""
Initialize conversation engine with all subsystems.
Args:
mai_interface: MaiInterface for model interaction
memory_manager: MemoryManager for context management
timing_profile: Timing profile ("default", "fast", "slow")
debug_mode: Enable debug logging and verbose output
enable_metrics: Enable performance metrics collection
"""
self.logger = logging.getLogger(__name__)
# Configuration
self.timing_profile = timing_profile
self.debug_mode = debug_mode
self.enable_metrics = enable_metrics
# Initialize components
self.mai_interface = mai_interface or MaiInterface()
self.memory_manager = memory_manager or MemoryManager()
# Conversation state management
self.conversation_state = ConversationState()
# Timing calculator for natural delays
self.timing_calculator = TimingCalculator(profile=timing_profile)
# Reasoning engine for step-by-step explanations
self.reasoning_engine = ReasoningEngine()
# Request decomposer for complex request analysis
self.request_decomposer = RequestDecomposer()
# Interruption handler for graceful recovery
self.interrupt_handler = InterruptHandler()
# Link conversation state with interrupt handler
self.interrupt_handler.set_conversation_state(self.conversation_state)
# Processing state for thread safety
self.processing_threads: Dict[str, Thread] = {}
self.interruption_events: Dict[str, Event] = {}
self.current_processing: Dict[str, bool] = {}
# Performance tracking
self.total_conversations = 0
self.total_interruptions = 0
self.start_time = time.time()
self.logger.info(
f"ConversationEngine initialized with timing_profile='{timing_profile}', debug={debug_mode}"
)
def process_turn(
self, user_message: str, conversation_id: Optional[str] = None
) -> ConversationResponse:
"""
Process a single conversation turn with complete subsystem integration.
Args:
user_message: User's input message
conversation_id: Optional conversation ID for continuation
Returns:
ConversationResponse with generated response and metadata
"""
start_time = time.time()
# Start or get conversation
if conversation_id is None:
conversation_id = self.conversation_state.start_conversation()
else:
conversation_id = self.conversation_state.start_conversation(conversation_id)
# Handle interruption if already processing
if self.conversation_state.is_processing(conversation_id):
return self._handle_interruption(conversation_id, user_message, start_time)
# Set processing lock
self.conversation_state.set_processing(conversation_id, True)
self.current_processing[conversation_id] = True
try:
self.logger.info(f"Processing conversation turn for {conversation_id}")
# Check for reasoning request
is_reasoning_request = self.reasoning_engine.is_reasoning_request(user_message)
# Analyze request complexity and decomposition needs
request_analysis = self.request_decomposer.analyze_request(user_message)
# Handle clarification needs if request is ambiguous
if request_analysis["needs_clarification"] and not is_reasoning_request:
clarification_response = self._generate_clarification_response(request_analysis)
return ConversationResponse(
response=clarification_response,
model_used="clarification",
tokens_used=0,
response_time=time.time() - start_time,
memory_context_used=0,
timing_category="clarification",
conversation_id=conversation_id,
interruption_handled=False,
memory_integrated=False,
)
# Retrieve memory context with 1000 token budget
memory_context = self._retrieve_memory_context(user_message)
# Build conversation history from state (last 10 turns)
conversation_history = self.conversation_state.get_history(conversation_id)
# Build memory-augmented prompt
augmented_prompt = self._build_augmented_prompt(
user_message, memory_context, conversation_history
)
# Calculate natural response delay based on cognitive load
context_complexity = len(str(memory_context)) if memory_context else 0
response_delay = self.timing_calculator.calculate_response_delay(
user_message, context_complexity
)
# Apply natural delay for human-like interaction
if not self.debug_mode:
self.logger.info(f"Applying {response_delay:.2f}s delay for natural timing")
time.sleep(response_delay)
# Generate response with optional reasoning
if is_reasoning_request:
# Use reasoning engine for reasoning requests
current_model = getattr(self.mai_interface, "current_model", "unknown")
if current_model is None:
current_model = "unknown"
reasoning_response = self.reasoning_engine.generate_response_with_reasoning(
user_message,
self.mai_interface.ollama_client,
current_model,
conversation_history,
)
interface_response = {
"response": reasoning_response["response"],
"model_used": reasoning_response["model_used"],
"tokens": reasoning_response.get("tokens_used", 0),
"response_time": response_delay,
}
else:
# Standard response generation
interface_response = self.mai_interface.send_message(
user_message, conversation_history
)
# Extract response details
ai_response = interface_response.get(
"response", "I apologize, but I couldn't generate a response."
)
model_used = interface_response.get("model_used", "unknown")
tokens_used = interface_response.get("tokens", 0)
# Store conversation turn in memory
self._store_conversation_turn(
conversation_id, user_message, ai_response, interface_response
)
# Create conversation turn with all metadata
turn = ConversationTurn(
conversation_id=conversation_id,
user_message=user_message,
ai_response=ai_response,
timestamp=start_time,
model_used=model_used,
tokens_used=tokens_used,
response_time=response_delay,
memory_context_applied=bool(memory_context),
)
# Add turn to conversation state
self.conversation_state.add_turn(turn)
# Calculate response time and timing category
total_response_time = time.time() - start_time
complexity_score = self.timing_calculator.get_complexity_score(
user_message, context_complexity
)
if complexity_score < 0.3:
timing_category = "simple"
elif complexity_score < 0.7:
timing_category = "medium"
else:
timing_category = "complex"
# Create comprehensive response object
response = ConversationResponse(
response=ai_response,
model_used=model_used,
tokens_used=tokens_used,
response_time=total_response_time,
memory_context_used=len(memory_context) if memory_context else 0,
timing_category=timing_category,
conversation_id=conversation_id,
memory_integrated=bool(memory_context),
interruption_handled=False,
)
self.total_conversations += 1
self.logger.info(f"Conversation turn completed for {conversation_id}")
return response
except Exception as e:
return ConversationResponse(
response=f"I understand you want to move on. Let me help you with that.",
model_used="error",
tokens_used=0,
response_time=time.time() - start_time,
memory_context_used=0,
timing_category="interruption",
conversation_id=conversation_id,
interruption_handled=True,
memory_integrated=False,
)
def _generate_clarification_response(self, request_analysis: Dict[str, Any]) -> str:
"""
Generate clarifying response for ambiguous requests.
Args:
request_analysis: Analysis from RequestDecomposer
Returns:
Clarifying response string
"""
questions = request_analysis.get("clarification_questions", [])
if not questions:
return "Could you please provide more details about your request?"
response_parts = ["I need some clarification to help you better:"]
for i, question in enumerate(questions, 1):
response_parts.append(f"{i}. {question}")
response_parts.append("\nPlease provide the missing information and I'll be happy to help!")
return "\n".join(response_parts)
def _retrieve_memory_context(self, user_message: str) -> Optional[Dict[str, Any]]:
"""
Retrieve relevant memory context for user message.
Uses 1000 token budget as specified in requirements.
"""
try:
if not self.memory_manager:
return None
# Get context with 1000 token budget and 3 max results
context = self.memory_manager.get_context(
query=user_message, max_tokens=1000, max_results=3
)
self.logger.debug(
f"Retrieved {len(context.get('relevant_conversations', []))} relevant conversations"
)
return context
except Exception as e:
self.logger.warning(f"Failed to retrieve memory context: {e}")
return None
def _build_augmented_prompt(
self,
user_message: str,
memory_context: Optional[Dict[str, Any]],
conversation_history: List[Dict[str, str]],
) -> str:
"""
Build memory-augmented prompt for model interaction.
Integrates context and history as specified in requirements.
"""
prompt_parts = []
# Add memory context if available
if memory_context and memory_context.get("relevant_conversations"):
context_text = "Context from previous conversations:\n"
for conv in memory_context["relevant_conversations"][:2]: # Limit to 2 most relevant
context_text += f"- {conv['title']}: {conv['excerpt']}\n"
prompt_parts.append(context_text)
# Add conversation history
if conversation_history:
history_text = "\nRecent conversation:\n"
for msg in conversation_history[-10:]: # Last 10 turns
role = msg["role"]
content = msg["content"][:200] # Truncate long messages
history_text += f"{role}: {content}\n"
prompt_parts.append(history_text)
# Add current user message
prompt_parts.append(f"User: {user_message}")
return "\n\n".join(prompt_parts)
def _store_conversation_turn(
self,
conversation_id: str,
user_message: str,
ai_response: str,
interface_response: Dict[str, Any],
) -> None:
"""
Store conversation turn in memory using MemoryManager.
Creates structured conversation data for persistence.
"""
try:
if not self.memory_manager:
return
# Build conversation messages for storage
conversation_messages = []
# Add context and history if available
if interface_response.get("memory_context_used", 0) > 0:
memory_context_msg = {
"role": "system",
"content": "Using memory context from previous conversations",
}
conversation_messages.append(memory_context_msg)
# Add current turn
conversation_messages.extend(
[
{"role": "user", "content": user_message},
{"role": "assistant", "content": ai_response},
]
)
# Store in memory with metadata
turn_metadata = {
"conversation_id": conversation_id,
"model_used": interface_response.get("model_used", "unknown"),
"response_time": interface_response.get("response_time", 0),
"tokens": interface_response.get("tokens", 0),
"memory_context_applied": interface_response.get("memory_context_used", 0) > 0,
"timestamp": time.time(),
"engine_version": "conversation-engine-v1",
}
conv_id = self.memory_manager.store_conversation(
messages=conversation_messages, metadata=turn_metadata
)
self.logger.debug(f"Stored conversation turn in memory: {conv_id}")
except Exception as e:
self.logger.warning(f"Failed to store conversation turn: {e}")
def _handle_interruption(
self, conversation_id: str, new_message: str, start_time: float
) -> ConversationResponse:
"""
Handle user interruption during processing.
Clears pending response and restarts with new context using InterruptHandler.
"""
self.logger.info(f"Handling interruption for conversation {conversation_id}")
self.total_interruptions += 1
# Create interruption context
interrupt_context = self.interrupt_handler.interrupt_and_restart(
new_message=new_message,
conversation_id=conversation_id,
turn_type=TurnType.USER_INPUT,
reason="user_input",
)
# Restart processing with new message (immediate response for interruption)
try:
interface_response = self.mai_interface.send_message(
new_message, self.conversation_state.get_history(conversation_id)
)
return ConversationResponse(
response=interface_response.get(
"response", "I understand you want to move on. How can I help you?"
),
model_used=interface_response.get("model_used", "unknown"),
tokens_used=interface_response.get("tokens", 0),
response_time=time.time() - start_time,
memory_context_used=0,
timing_category="interruption",
conversation_id=conversation_id,
interruption_handled=True,
memory_integrated=False,
)
except Exception as e:
return ConversationResponse(
response=f"I understand you want to move on. Let me help you with that.",
model_used="error",
tokens_used=0,
response_time=time.time() - start_time,
memory_context_used=0,
timing_category="interruption",
conversation_id=conversation_id,
interruption_handled=True,
memory_integrated=False,
)
def get_conversation_history(
self, conversation_id: str, limit: int = 10
) -> List[ConversationTurn]:
"""Get conversation history for a specific conversation."""
return self.conversation_state.get_conversation_turns(conversation_id)[-limit:]
def get_engine_stats(self) -> Dict[str, Any]:
"""Get engine performance statistics."""
uptime = time.time() - self.start_time
return {
"uptime_seconds": uptime,
"total_conversations": self.total_conversations,
"total_interruptions": self.total_interruptions,
"active_conversations": len(self.conversation_state.conversations),
"average_response_time": 0.0, # Would be calculated from actual responses
"memory_integration_rate": 0.0, # Would be calculated from actual responses
}
def calculate_response_delay(
self, user_message: str, context_complexity: Optional[int] = None
) -> float:
"""
Calculate natural response delay using TimingCalculator.
Args:
user_message: User message to analyze
context_complexity: Optional context complexity
Returns:
Response delay in seconds
"""
return self.timing_calculator.calculate_response_delay(user_message, context_complexity)
def is_reasoning_request(self, user_message: str) -> bool:
"""
Check if user is requesting reasoning explanation.
Args:
user_message: User message to analyze
Returns:
True if this appears to be a reasoning request
"""
return self.reasoning_engine.is_reasoning_request(user_message)
def generate_response_with_reasoning(
self, user_message: str, conversation_history: List[Dict[str, str]]
) -> Dict[str, Any]:
"""
Generate response with step-by-step reasoning explanation.
Args:
user_message: Original user message
conversation_history: Conversation context
Returns:
Dictionary with reasoning-enhanced response
"""
current_model = getattr(self.mai_interface, "current_model", "unknown")
if current_model is None:
current_model = "unknown"
return self.reasoning_engine.generate_response_with_reasoning(
user_message, self.mai_interface.ollama_client, current_model, conversation_history
)
def analyze_request_complexity(self, user_message: str) -> Dict[str, Any]:
"""
Analyze request complexity and decomposition needs.
Args:
user_message: User message to analyze
Returns:
Request analysis dictionary
"""
return self.request_decomposer.analyze_request(user_message)
def check_interruption(self, conversation_id: str) -> bool:
"""
Check if interruption has occurred for a conversation.
Args:
conversation_id: ID of conversation to check
Returns:
True if interruption detected
"""
return self.interrupt_handler.check_interruption(conversation_id)
def interrupt_and_restart(
self, new_message: str, conversation_id: str, reason: str = "user_input"
) -> Dict[str, Any]:
"""
Handle interruption and restart conversation.
Args:
new_message: New message that triggered interruption
conversation_id: ID of conversation
reason: Reason for interruption
Returns:
Interruption context dictionary
"""
interrupt_context = self.interrupt_handler.interrupt_and_restart(
new_message=new_message,
conversation_id=conversation_id,
turn_type=TurnType.USER_INPUT,
reason=reason,
)
return interrupt_context.to_dict()
def needs_clarification(self, request_analysis: Dict[str, Any]) -> bool:
"""
Check if request needs clarification.
Args:
request_analysis: Request analysis result
Returns:
True if clarification is needed
"""
return request_analysis.get("needs_clarification", False)
def suggest_breakdown(self, user_message: str, complexity_score: float) -> Dict[str, Any]:
"""
Suggest logical breakdown for complex requests.
Args:
user_message: Original user message
complexity_score: Complexity score from analysis
Returns:
Breakdown suggestions dictionary
"""
return self.request_decomposer.suggest_breakdown(
user_message,
complexity_score,
self.mai_interface.ollama_client,
getattr(self.mai_interface, "current_model", "default"),
)
def adapt_response_with_personality(
self, response: str, user_message: str, context_type: Optional[str] = None
) -> str:
"""
Adapt response based on personality guidelines.
Args:
response: Generated response to adapt
user_message: Original user message for context
context_type: Type of conversation context
Returns:
Personality-adapted response
"""
# For now, return original response
# Personality integration will be implemented in Phase 9
return response
def cleanup(self, max_age_hours: int = 24) -> None:
"""Clean up old conversations and resources."""
self.conversation_state.cleanup_old_conversations(max_age_hours)
self.logger.info(f"Cleaned up conversations older than {max_age_hours} hours")
def shutdown(self) -> None:
"""Shutdown conversation engine gracefully."""
self.logger.info("Shutting down ConversationEngine...")
# Cancel any processing threads
for conv_id, thread in self.processing_threads.items():
if thread.is_alive():
if conv_id in self.interruption_events:
self.interruption_events[conv_id].set()
thread.join(timeout=1.0)
# Cleanup resources
self.cleanup()
self.logger.info("ConversationEngine shutdown complete")

View File

@@ -1,333 +0,0 @@
"""
Interruption Handling for Mai Conversations
Provides graceful interruption handling during conversation processing
with thread-safe operations and conversation restart capabilities.
"""
import logging
import threading
import time
import uuid
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass
from enum import Enum
# Import conversation state for integration
try:
from .state import ConversationState
except ImportError:
# Fallback for standalone usage
ConversationState = None
logger = logging.getLogger(__name__)
class TurnType(Enum):
"""Types of conversation turns for different input sources."""
USER_INPUT = "user_input"
SELF_REFLECTION = "self_reflection"
CODE_EXECUTION = "code_execution"
SYSTEM_NOTIFICATION = "system_notification"
@dataclass
class InterruptionContext:
"""Context for conversation interruption and restart."""
interruption_id: str
original_message: str
new_message: str
conversation_id: str
turn_type: TurnType
timestamp: float
processing_time: float
reason: str = "user_input"
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"interruption_id": self.interruption_id,
"original_message": self.original_message,
"new_message": self.new_message,
"conversation_id": self.conversation_id,
"turn_type": self.turn_type.value,
"timestamp": self.timestamp,
"processing_time": self.processing_time,
"reason": self.reason,
}
class InterruptHandler:
"""
Manages graceful conversation interruptions and restarts.
Provides thread-safe interruption detection, context preservation,
and timeout-based protection for long-running operations.
"""
def __init__(self, timeout_seconds: float = 30.0):
"""
Initialize interruption handler.
Args:
timeout_seconds: Maximum processing time before auto-interruption
"""
self.timeout_seconds = timeout_seconds
self.interrupt_flag = False
self.processing_lock = threading.RLock()
self.state_lock = threading.RLock()
# Track active processing contexts
self.active_contexts: Dict[str, Dict[str, Any]] = {}
# Conversation state integration
self.conversation_state: Optional[ConversationState] = None
# Statistics
self.interruption_count = 0
self.timeout_count = 0
self.logger = logging.getLogger(__name__)
self.logger.info(f"InterruptHandler initialized with {timeout_seconds}s timeout")
def set_conversation_state(self, conversation_state: ConversationState) -> None:
"""
Set conversation state for integration.
Args:
conversation_state: ConversationState instance for context management
"""
with self.state_lock:
self.conversation_state = conversation_state
self.logger.debug("Conversation state integrated")
def start_processing(
self,
message: str,
conversation_id: str,
turn_type: TurnType = TurnType.USER_INPUT,
context: Optional[Dict[str, Any]] = None,
) -> str:
"""
Start processing a conversation turn.
Args:
message: Message being processed
conversation_id: ID of conversation
turn_type: Type of conversation turn
context: Additional processing context
Returns:
Processing context ID for tracking
"""
processing_id = str(uuid.uuid4())
start_time = time.time()
with self.processing_lock:
self.active_contexts[processing_id] = {
"message": message,
"conversation_id": conversation_id,
"turn_type": turn_type,
"context": context or {},
"start_time": start_time,
"timeout_timer": None,
}
# Reset interruption flag for new processing
self.interrupt_flag = False
self.logger.debug(
f"Started processing {processing_id}: {turn_type.value} for conversation {conversation_id}"
)
return processing_id
def check_interruption(self, processing_id: Optional[str] = None) -> bool:
"""
Check if interruption occurred during processing.
Args:
processing_id: Specific processing context to check (optional)
Returns:
True if interruption detected, False otherwise
"""
with self.processing_lock:
# Check global interruption flag
was_interrupted = self.interrupt_flag
# Check timeout for active contexts
if processing_id and processing_id in self.active_contexts:
context = self.active_contexts[processing_id]
elapsed = time.time() - context["start_time"]
if elapsed > self.timeout_seconds:
self.logger.info(f"Processing timeout for {processing_id} after {elapsed:.1f}s")
self.timeout_count += 1
was_interrupted = True
# Reset flag after checking
if was_interrupted:
self.interrupt_flag = False
self.interruption_count += 1
return was_interrupted
def interrupt_and_restart(
self,
new_message: str,
conversation_id: str,
turn_type: TurnType = TurnType.USER_INPUT,
reason: str = "user_input",
) -> InterruptionContext:
"""
Handle interruption and prepare for restart.
Args:
new_message: New message that triggered interruption
conversation_id: ID of conversation
turn_type: Type of new conversation turn
reason: Reason for interruption
Returns:
InterruptionContext with restart information
"""
interruption_id = str(uuid.uuid4())
current_time = time.time()
with self.processing_lock:
# Find the active processing context for this conversation
active_context = None
original_message = ""
processing_time = 0.0
for proc_id, context in self.active_contexts.items():
if context["conversation_id"] == conversation_id:
active_context = context
processing_time = current_time - context["start_time"]
original_message = context["message"]
break
# Set interruption flag
self.interrupt_flag = True
# Clear pending response from conversation state
if self.conversation_state:
self.conversation_state.clear_pending_response(conversation_id)
# Create interruption context
interruption_context = InterruptionContext(
interruption_id=interruption_id,
original_message=original_message,
new_message=new_message,
conversation_id=conversation_id,
turn_type=turn_type,
timestamp=current_time,
processing_time=processing_time,
reason=reason,
)
self.logger.info(
f"Interruption {interruption_id} for conversation {conversation_id}: {reason}"
)
return interruption_context
def finish_processing(self, processing_id: str) -> None:
"""
Mark processing as complete and cleanup context.
Args:
processing_id: Processing context ID to finish
"""
with self.processing_lock:
if processing_id in self.active_contexts:
context = self.active_contexts[processing_id]
elapsed = time.time() - context["start_time"]
del self.active_contexts[processing_id]
self.logger.debug(f"Finished processing {processing_id} in {elapsed:.2f}s")
def get_active_processing(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""
Get currently active processing contexts.
Args:
conversation_id: Filter by specific conversation (optional)
Returns:
List of active processing contexts
"""
with self.processing_lock:
active = []
for proc_id, context in self.active_contexts.items():
if conversation_id is None or context["conversation_id"] == conversation_id:
active_context = context.copy()
active_context["processing_id"] = proc_id
active_context["elapsed"] = time.time() - context["start_time"]
active.append(active_context)
return active
def cleanup_stale_processing(self, max_age_seconds: float = 300.0) -> int:
"""
Clean up stale processing contexts.
Args:
max_age_seconds: Maximum age before cleanup
Returns:
Number of contexts cleaned up
"""
current_time = time.time()
stale_contexts = []
with self.processing_lock:
for proc_id, context in self.active_contexts.items():
elapsed = current_time - context["start_time"]
if elapsed > max_age_seconds:
stale_contexts.append(proc_id)
for proc_id in stale_contexts:
del self.active_contexts[proc_id]
if stale_contexts:
self.logger.info(f"Cleaned up {len(stale_contexts)} stale processing contexts")
return len(stale_contexts)
def get_statistics(self) -> Dict[str, Any]:
"""
Get interruption handler statistics.
Returns:
Dictionary with performance and usage statistics
"""
with self.processing_lock:
return {
"interruption_count": self.interruption_count,
"timeout_count": self.timeout_count,
"active_processing_count": len(self.active_contexts),
"timeout_seconds": self.timeout_seconds,
"last_activity": time.time(),
}
def configure_timeout(self, timeout_seconds: float) -> None:
"""
Update timeout configuration.
Args:
timeout_seconds: New timeout value in seconds
"""
with self.state_lock:
self.timeout_seconds = max(5.0, timeout_seconds) # Minimum 5 seconds
self.logger.info(f"Timeout updated to {self.timeout_seconds}s")
def reset_statistics(self) -> None:
"""Reset interruption handler statistics."""
with self.state_lock:
self.interruption_count = 0
self.timeout_count = 0
self.logger.info("Interruption statistics reset")

View File

@@ -1,284 +0,0 @@
"""
Reasoning Transparency Engine for Mai
Provides step-by-step reasoning explanations when explicitly requested
by users, with caching for performance optimization.
"""
import logging
import hashlib
import time
from typing import Dict, List, Optional, Any, Tuple
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
class ReasoningEngine:
"""
Provides reasoning transparency and step-by-step explanations.
This engine detects when users explicitly ask for reasoning explanations
and generates detailed step-by-step breakdowns of Mai's thought process.
"""
def __init__(self):
"""Initialize reasoning engine with caching."""
self.logger = logging.getLogger(__name__)
# Cache for reasoning explanations to avoid recomputation
self._reasoning_cache: Dict[str, Dict[str, Any]] = {}
self._cache_duration = timedelta(hours=24)
# Keywords that indicate reasoning requests
self._reasoning_keywords = [
"how did you",
"explain your reasoning",
"step by step",
"why",
"process",
"how do you know",
"what makes you think",
"show your work",
"walk through",
"break down",
"explain your logic",
"how did you arrive",
"what's your reasoning",
"explain yourself",
]
self.logger.info("ReasoningEngine initialized")
def is_reasoning_request(self, message: str) -> bool:
"""
Detect when user explicitly asks for reasoning explanation.
Args:
message: User message to analyze
Returns:
True if this appears to be a reasoning request
"""
message_lower = message.lower().strip()
# Check for reasoning keywords
for keyword in self._reasoning_keywords:
if keyword in message_lower:
self.logger.debug(f"Reasoning request detected via keyword: {keyword}")
return True
# Check for question patterns asking about process
reasoning_patterns = [
r"how did you",
r"why.*you.*\?",
r"what.*your.*process",
r"can you.*explain.*your",
r"show.*your.*work",
r"explain.*how.*you",
r"what.*your.*reasoning",
r"walk.*through.*your",
]
import re
for pattern in reasoning_patterns:
if re.search(pattern, message_lower):
self.logger.debug(f"Reasoning request detected via pattern: {pattern}")
return True
return False
def _get_cache_key(self, message: str) -> str:
"""Generate cache key based on message content hash."""
return hashlib.md5(message.encode()).hexdigest()
def _is_cache_valid(self, cache_entry: Optional[Dict[str, Any]]) -> bool:
"""Check if cache entry is still valid."""
if not cache_entry:
return False
cached_time = cache_entry.get("timestamp")
if not cached_time:
return False
return datetime.now() - cached_time < self._cache_duration
def generate_response_with_reasoning(
self,
user_message: str,
ollama_client,
current_model: str,
context: Optional[List[Dict[str, Any]]] = None,
show_reasoning: bool = False,
) -> Dict[str, Any]:
"""
Generate response with optional step-by-step reasoning explanation.
Args:
user_message: Original user message
ollama_client: OllamaClient instance for generating responses
current_model: Current model name
context: Conversation context
show_reasoning: Whether to include reasoning explanation
Returns:
Dictionary with response, reasoning (if requested), and metadata
"""
# Check cache first
cache_key = self._get_cache_key(user_message)
cached_entry = self._reasoning_cache.get(cache_key)
if (
cached_entry
and self._is_cache_valid(cached_entry)
and cached_entry.get("message") == user_message
):
self.logger.debug("Using cached reasoning response")
return cached_entry["response"]
# Detect if this is a reasoning request
is_reasoning = show_reasoning or self.is_reasoning_request(user_message)
try:
# Generate standard response
standard_response = ollama_client.generate_response(
user_message, current_model, context or []
)
response_data = {
"response": standard_response,
"model_used": current_model,
"show_reasoning": is_reasoning,
"reasoning": None,
"format": "standard",
}
# Generate reasoning explanation if requested
if is_reasoning:
reasoning = self._generate_reasoning_explanation(
user_message, standard_response, ollama_client, current_model
)
response_data["reasoning"] = reasoning
response_data["format"] = "with_reasoning"
# Format response with reasoning
formatted_response = self.format_reasoning_response(standard_response, reasoning)
response_data["response"] = formatted_response
# Cache the response
self._reasoning_cache[cache_key] = {
"message": user_message,
"response": response_data,
"timestamp": datetime.now(),
}
self.logger.info(f"Generated response with reasoning={is_reasoning}")
return response_data
except Exception as e:
self.logger.error(f"Failed to generate response with reasoning: {e}")
raise
def _generate_reasoning_explanation(
self, user_message: str, standard_response: str, ollama_client, current_model: str
) -> str:
"""
Generate step-by-step reasoning explanation.
Args:
user_message: Original user question
standard_response: The response that was generated
ollama_client: OllamaClient for generating reasoning
current_model: Current model name
Returns:
Formatted reasoning explanation as numbered steps
"""
reasoning_prompt = f"""
Explain your reasoning step by step for answering: "{user_message}"
Your final answer was: "{standard_response}"
Please explain your reasoning process:
1. Start by understanding what the user is asking
2. Break down the key components of the question
3. Explain your thought process step by step
4. Show how you arrived at your conclusion
5. End with "Final answer:" followed by your actual response
Format as clear numbered steps. Be detailed but concise.
"""
try:
reasoning = ollama_client.generate_response(reasoning_prompt, current_model, [])
return self._clean_reasoning_output(reasoning)
except Exception as e:
self.logger.error(f"Failed to generate reasoning explanation: {e}")
return f"I apologize, but I encountered an error generating my reasoning explanation. My response was: {standard_response}"
def _clean_reasoning_output(self, reasoning: str) -> str:
"""Clean and format reasoning output."""
# Remove any redundant prefixes
reasoning = reasoning.strip()
# Remove common AI response prefixes
prefixes_to_remove = [
"Here's my reasoning:",
"My reasoning is:",
"Let me explain my reasoning:",
"I'll explain my reasoning step by step:",
]
for prefix in prefixes_to_remove:
if reasoning.startswith(prefix):
reasoning = reasoning[len(prefix) :].strip()
break
return reasoning
def format_reasoning_response(self, response: str, reasoning: str) -> str:
"""
Format reasoning with clear separation from main answer.
Args:
response: The actual response
reasoning: The reasoning explanation
Returns:
Formatted response with reasoning section
"""
# Clean up any existing formatting
reasoning = self._clean_reasoning_output(reasoning)
# Format with clear separation
formatted = f"""## 🧠 My Reasoning Process
{reasoning}
---
## 💬 My Response
{response}"""
return formatted
def clear_cache(self) -> None:
"""Clear reasoning cache."""
self._reasoning_cache.clear()
self.logger.info("Reasoning cache cleared")
def get_cache_stats(self) -> Dict[str, Any]:
"""Get reasoning cache statistics."""
total_entries = len(self._reasoning_cache)
valid_entries = sum(
1 for entry in self._reasoning_cache.values() if self._is_cache_valid(entry)
)
return {
"total_entries": total_entries,
"valid_entries": valid_entries,
"cache_duration_hours": self._cache_duration.total_seconds() / 3600,
"last_cleanup": datetime.now().isoformat(),
}
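# Hedged usage sketch (not part of the original module): drives ReasoningEngine with a minimal
# stub standing in for the OllamaClient dependency. The stub only mirrors the
# generate_response(prompt, model, context) calls made above; the real client lives elsewhere.
class _StubOllamaClient:
    def generate_response(self, prompt, model, context):
        return f"stub reply from {model}"

if __name__ == "__main__":
    engine = ReasoningEngine()
    print(engine.is_reasoning_request("Why did you choose that approach?"))  # True ("why")
    result = engine.generate_response_with_reasoning(
        "Explain your reasoning step by step: why is the sky blue?",
        ollama_client=_StubOllamaClient(),
        current_model="llama2",
    )
    print(result["format"])  # "with_reasoning" because the message triggers detection
    print(engine.get_cache_stats()["total_entries"])  # 1 entry cached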

View File

@@ -1,386 +0,0 @@
"""
Conversation State Management for Mai
Provides turn-by-turn conversation history with proper session isolation,
interruption handling, and context window management.
"""
import logging
import time
import threading
import uuid
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime
# Import existing conversation models for consistency
try:
from ..models.conversation import Message, Conversation
except ImportError:
# Fallback if models not available yet
Message = None
Conversation = None
logger = logging.getLogger(__name__)
@dataclass
class ConversationTurn:
"""Single conversation turn with comprehensive metadata."""
conversation_id: str
user_message: str
ai_response: str
timestamp: float
model_used: str
tokens_used: int
response_time: float
memory_context_applied: bool
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"conversation_id": self.conversation_id,
"user_message": self.user_message,
"ai_response": self.ai_response,
"timestamp": self.timestamp,
"model_used": self.model_used,
"tokens_used": self.tokens_used,
"response_time": self.response_time,
"memory_context_applied": self.memory_context_applied,
}
class ConversationState:
"""
Manages conversation state across multiple sessions with proper isolation.
Provides turn-by-turn history tracking, automatic cleanup,
thread-safe operations, and Ollama-compatible formatting.
"""
def __init__(self, max_turns_per_conversation: int = 10):
"""
Initialize conversation state manager.
Args:
max_turns_per_conversation: Maximum turns to keep per conversation
"""
self.conversations: Dict[str, List[ConversationTurn]] = {}
self.max_turns = max_turns_per_conversation
self._lock = threading.RLock() # Reentrant lock for nested calls
self.logger = logging.getLogger(__name__)
self.logger.info(
f"ConversationState initialized with max {max_turns_per_conversation} turns per conversation"
)
def add_turn(self, turn: ConversationTurn) -> None:
"""
Add a conversation turn with automatic timestamp and cleanup.
Args:
turn: ConversationTurn to add
"""
with self._lock:
conversation_id = turn.conversation_id
# Initialize conversation if doesn't exist
if conversation_id not in self.conversations:
self.conversations[conversation_id] = []
self.logger.debug(f"Created new conversation: {conversation_id}")
# Add the turn
self.conversations[conversation_id].append(turn)
self.logger.debug(
f"Added turn to conversation {conversation_id}: {turn.tokens_used} tokens, {turn.response_time:.2f}s"
)
# Automatic cleanup: maintain last N turns
if len(self.conversations[conversation_id]) > self.max_turns:
# Remove oldest turns to maintain limit
excess_count = len(self.conversations[conversation_id]) - self.max_turns
removed_turns = self.conversations[conversation_id][:excess_count]
self.conversations[conversation_id] = self.conversations[conversation_id][
excess_count:
]
self.logger.debug(
f"Cleaned up {excess_count} old turns from conversation {conversation_id}"
)
# Log removed turns for debugging
for removed_turn in removed_turns:
self.logger.debug(
f"Removed turn: {removed_turn.timestamp} - {removed_turn.user_message[:50]}..."
)
def get_history(self, conversation_id: str) -> List[Dict[str, str]]:
"""
Get conversation history in Ollama-compatible format.
Args:
conversation_id: ID of conversation to retrieve
Returns:
List of message dictionaries formatted for Ollama API
"""
with self._lock:
turns = self.conversations.get(conversation_id, [])
# Convert to Ollama format: alternating user/assistant roles
history = []
for turn in turns:
history.append({"role": "user", "content": turn.user_message})
history.append({"role": "assistant", "content": turn.ai_response})
self.logger.debug(
f"Retrieved {len(history)} messages from conversation {conversation_id}"
)
return history
def set_conversation_history(
self, messages: List[Dict[str, str]], conversation_id: Optional[str] = None
) -> None:
"""
Restore conversation history from session storage.
Args:
messages: List of message dictionaries in Ollama format [{"role": "user/assistant", "content": "..."}]
conversation_id: Optional conversation ID to restore to (creates new if None)
"""
with self._lock:
if conversation_id is None:
conversation_id = str(uuid.uuid4())
# Clear existing conversation for this ID
self.conversations[conversation_id] = []
# Convert messages back to ConversationTurn objects
# Messages should be in pairs: user, assistant, user, assistant, ...
i = 0
while i < len(messages):
# Expect user message first
if i >= len(messages) or messages[i].get("role") != "user":
self.logger.warning(f"Expected user message at index {i}, skipping")
i += 1
continue
user_message = messages[i].get("content", "")
i += 1
# Expect assistant message next
if i >= len(messages) or messages[i].get("role") != "assistant":
self.logger.warning(f"Expected assistant message at index {i}, skipping")
continue
ai_response = messages[i].get("content", "")
i += 1
# Create ConversationTurn with estimated metadata
turn = ConversationTurn(
conversation_id=conversation_id,
user_message=user_message,
ai_response=ai_response,
timestamp=time.time(), # Use current time as approximation
model_used="restored", # Indicate this is from restoration
tokens_used=0, # Token count not available from session
response_time=0.0, # Response time not available from session
memory_context_applied=False, # Memory context not tracked in session
)
self.conversations[conversation_id].append(turn)
self.logger.info(
f"Restored {len(self.conversations[conversation_id])} turns to conversation {conversation_id}"
)
def get_last_n_turns(self, conversation_id: str, n: int = 5) -> List[ConversationTurn]:
"""
Get the last N turns from a conversation.
Args:
conversation_id: ID of conversation
n: Number of recent turns to retrieve
Returns:
List of last N ConversationTurn objects
"""
with self._lock:
turns = self.conversations.get(conversation_id, [])
return turns[-n:] if n > 0 else []
def clear_pending_response(self, conversation_id: str) -> None:
"""
Clear any pending response for interruption handling.
Args:
conversation_id: ID of conversation to clear
"""
with self._lock:
if conversation_id in self.conversations:
# Find and remove incomplete turns (those without AI response)
original_count = len(self.conversations[conversation_id])
self.conversations[conversation_id] = [
turn
for turn in self.conversations[conversation_id]
if turn.ai_response.strip() # Must have AI response
]
removed_count = original_count - len(self.conversations[conversation_id])
if removed_count > 0:
self.logger.info(
f"Cleared {removed_count} incomplete turns from conversation {conversation_id}"
)
def start_conversation(self, conversation_id: Optional[str] = None) -> str:
"""
Start a new conversation or return existing ID.
Args:
conversation_id: Optional existing conversation ID
Returns:
Conversation ID (new or existing)
"""
with self._lock:
if conversation_id is None:
conversation_id = str(uuid.uuid4())
if conversation_id not in self.conversations:
self.conversations[conversation_id] = []
self.logger.debug(f"Started new conversation: {conversation_id}")
return conversation_id
def is_processing(self, conversation_id: str) -> bool:
"""
Check if conversation is currently being processed.
Args:
conversation_id: ID of conversation
Returns:
True if currently processing, False otherwise
"""
        with self._lock:
            # Return the stored processing flag rather than mere key presence, so
            # set_processing(conversation_id, False) correctly reports "not processing"
            return getattr(self, "_processing_locks", {}).get(conversation_id, False)
def set_processing(self, conversation_id: str, processing: bool) -> None:
"""
Set processing lock for conversation.
Args:
conversation_id: ID of conversation
processing: Processing state
"""
with self._lock:
if not hasattr(self, "_processing_locks"):
self._processing_locks = {}
self._processing_locks[conversation_id] = processing
def get_conversation_turns(self, conversation_id: str) -> List[ConversationTurn]:
"""
Get all turns for a conversation.
Args:
conversation_id: ID of conversation
Returns:
List of ConversationTurn objects
"""
with self._lock:
return self.conversations.get(conversation_id, [])
def delete_conversation(self, conversation_id: str) -> bool:
"""
Delete a conversation completely.
Args:
conversation_id: ID of conversation to delete
Returns:
True if conversation was deleted, False if not found
"""
with self._lock:
if conversation_id in self.conversations:
del self.conversations[conversation_id]
self.logger.info(f"Deleted conversation: {conversation_id}")
return True
return False
def list_conversations(self) -> List[str]:
"""
List all active conversation IDs.
Returns:
List of conversation IDs
"""
with self._lock:
return list(self.conversations.keys())
def get_conversation_stats(self, conversation_id: str) -> Dict[str, Any]:
"""
Get statistics for a specific conversation.
Args:
conversation_id: ID of conversation
Returns:
Dictionary with conversation statistics
"""
with self._lock:
turns = self.conversations.get(conversation_id, [])
if not turns:
return {
"turn_count": 0,
"total_tokens": 0,
"total_response_time": 0.0,
"average_response_time": 0.0,
"average_tokens": 0.0,
}
total_tokens = sum(turn.tokens_used for turn in turns)
total_response_time = sum(turn.response_time for turn in turns)
avg_response_time = total_response_time / len(turns)
avg_tokens = total_tokens / len(turns)
return {
"turn_count": len(turns),
"total_tokens": total_tokens,
"total_response_time": total_response_time,
"average_response_time": avg_response_time,
"average_tokens": avg_tokens,
"oldest_timestamp": min(turn.timestamp for turn in turns),
"newest_timestamp": max(turn.timestamp for turn in turns),
}
def cleanup_old_conversations(self, max_age_hours: float = 24.0) -> int:
"""
Clean up conversations older than specified age.
Args:
max_age_hours: Maximum age in hours before cleanup
Returns:
Number of conversations cleaned up
"""
with self._lock:
current_time = time.time()
cutoff_time = current_time - (max_age_hours * 3600)
conversations_to_remove = []
for conv_id, turns in self.conversations.items():
if turns and turns[-1].timestamp < cutoff_time:
conversations_to_remove.append(conv_id)
for conv_id in conversations_to_remove:
del self.conversations[conv_id]
if conversations_to_remove:
self.logger.info(f"Cleaned up {len(conversations_to_remove)} old conversations")
return len(conversations_to_remove)
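# Hedged usage sketch (not part of the original module): records one turn and reads it back in
# the Ollama-style message format produced by get_history. All values are illustrative.
if __name__ == "__main__":
    state = ConversationState(max_turns_per_conversation=5)
    conv_id = state.start_conversation()
    state.add_turn(
        ConversationTurn(
            conversation_id=conv_id,
            user_message="Hello, Mai",
            ai_response="Hi! How can I help?",
            timestamp=time.time(),
            model_used="llama2",
            tokens_used=42,
            response_time=1.3,
            memory_context_applied=False,
        )
    )
    print(state.get_history(conv_id))  # [{'role': 'user', ...}, {'role': 'assistant', ...}]
    print(state.get_conversation_stats(conv_id)["turn_count"])  # 1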

View File

@@ -1,281 +0,0 @@
"""
Natural Timing Calculation for Mai
Provides human-like response delays based on cognitive load analysis
with natural variation to avoid robotic consistency.
"""
import time
import random
import logging
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
class TimingCalculator:
"""
Calculates natural response delays based on cognitive load analysis.
Generates human-like timing variation considering message complexity,
question count, technical content, and context depth.
"""
def __init__(self, profile: str = "default"):
"""
Initialize timing calculator with specified profile.
Args:
profile: Timing profile - "default", "fast", or "slow"
"""
self.profile = profile
self.logger = logging.getLogger(__name__)
# Profile-specific multipliers
self.profiles = {
"default": {"base": 1.0, "variation": 0.3},
"fast": {"base": 0.6, "variation": 0.2},
"slow": {"base": 1.4, "variation": 0.4},
}
if profile not in self.profiles:
self.logger.warning(f"Unknown profile '{profile}', using 'default'")
self.profile = "default"
self.profile_config = self.profiles[self.profile]
self.logger.info(f"TimingCalculator initialized with '{self.profile}' profile")
def calculate_response_delay(
self, message: str, context_complexity: Optional[int] = None
) -> float:
"""
Calculate natural response delay based on cognitive load.
Args:
message: User message to analyze
context_complexity: Optional context complexity score
Returns:
            Response delay in seconds (clamped to the 0.5-10.0 second range)
"""
# Analyze message complexity
complexity_score = self.get_complexity_score(message, context_complexity)
# Determine base delay based on complexity category
if complexity_score < 0.3:
# Simple (low complexity)
base_delay = random.uniform(1.5, 2.5)
category = "simple"
elif complexity_score < 0.7:
# Medium (moderate complexity)
base_delay = random.uniform(2.0, 4.0)
category = "medium"
else:
# Complex (high complexity)
base_delay = random.uniform(3.0, 8.0)
category = "complex"
# Apply profile multiplier
adjusted_delay = base_delay * self.profile_config["base"]
        # Add natural variation/jitter: treat the profile's variation factor as a symmetric
        # range around the adjusted delay, plus a small fixed jitter
        variation_amount = adjusted_delay * random.uniform(
            -self.profile_config["variation"], self.profile_config["variation"]
        )
        jitter = random.uniform(-0.2, 0.2)  # +/-0.2 seconds
        final_delay = max(0.5, adjusted_delay + variation_amount + jitter)  # Minimum 0.5s
self.logger.debug(
f"Delay calculation: {category} complexity ({complexity_score:.2f}) -> {final_delay:.2f}s"
)
# Ensure within reasonable bounds
return min(max(final_delay, 0.5), 10.0) # 0.5s to 10s range
def get_complexity_score(self, message: str, context_complexity: Optional[int] = None) -> float:
"""
Analyze message content for complexity indicators.
Args:
message: Message to analyze
context_complexity: Optional context complexity from conversation history
Returns:
Complexity score from 0.0 (simple) to 1.0 (complex)
"""
score = 0.0
# 1. Message length factor (0-0.3)
word_count = len(message.split())
if word_count > 50:
score += 0.3
elif word_count > 25:
score += 0.2
elif word_count > 10:
score += 0.1
# 2. Question count factor (0-0.3)
question_count = message.count("?")
if question_count >= 3:
score += 0.3
elif question_count >= 2:
score += 0.2
elif question_count >= 1:
score += 0.1
# 3. Technical content indicators (0-0.3)
technical_keywords = [
"function",
"class",
"algorithm",
"debug",
"implement",
"fix",
"error",
"optimization",
"performance",
"database",
"api",
"endpoint",
"method",
"parameter",
"variable",
"constant",
"import",
"export",
"async",
"await",
"promise",
"callback",
"recursive",
"iterative",
"hash",
"encryption",
"authentication",
"authorization",
"token",
"session",
]
technical_count = sum(
1 for keyword in technical_keywords if keyword.lower() in message.lower()
)
if technical_count >= 5:
score += 0.3
elif technical_count >= 3:
score += 0.2
elif technical_count >= 1:
score += 0.1
# 4. Code pattern indicators (0-0.2)
code_indicators = 0
if "```" in message:
code_indicators += 1
if "`" in message and message.count("`") >= 2:
code_indicators += 1
if any(
word in message.lower() for word in ["def", "function", "class", "var", "let", "const"]
):
code_indicators += 1
        if any(char in message for char in "{}()[];"):
code_indicators += 1
if code_indicators >= 1:
score += 0.1
if code_indicators >= 2:
score += 0.1
# 5. Context complexity integration (0-0.2)
if context_complexity is not None:
if context_complexity > 1000: # High token context
score += 0.2
elif context_complexity > 500: # Medium token context
score += 0.1
# Normalize to 0-1 range
normalized_score = min(score, 1.0)
self.logger.debug(
f"Complexity analysis: score={normalized_score:.2f}, words={word_count}, questions={question_count}, technical={technical_count}"
)
return normalized_score
def set_profile(self, profile: str) -> None:
"""
Change timing profile.
Args:
profile: New profile name ("default", "fast", "slow")
"""
if profile in self.profiles:
self.profile = profile
self.profile_config = self.profiles[profile]
self.logger.info(f"Timing profile changed to '{profile}'")
else:
self.logger.warning(
f"Unknown profile '{profile}', keeping current profile '{self.profile}'"
)
def get_timing_stats(self, messages: list) -> Dict[str, Any]:
"""
Calculate timing statistics for a list of messages.
Args:
messages: List of message dictionaries with timing info
Returns:
Dictionary with timing statistics
"""
if not messages:
return {
"message_count": 0,
"average_delay": 0.0,
"min_delay": 0.0,
"max_delay": 0.0,
"total_delay": 0.0,
}
delays = []
total_delay = 0.0
for msg in messages:
if "response_time" in msg:
delays.append(msg["response_time"])
total_delay += msg["response_time"]
if delays:
return {
"message_count": len(messages),
"average_delay": total_delay / len(delays),
"min_delay": min(delays),
"max_delay": max(delays),
"total_delay": total_delay,
"profile": self.profile,
}
else:
return {
"message_count": len(messages),
"average_delay": 0.0,
"min_delay": 0.0,
"max_delay": 0.0,
"total_delay": 0.0,
"profile": self.profile,
}
def get_profile_info(self) -> Dict[str, Any]:
"""
Get information about current timing profile.
Returns:
Dictionary with profile configuration
"""
return {
"current_profile": self.profile,
"base_multiplier": self.profile_config["base"],
"variation_range": self.profile_config["variation"],
"available_profiles": list(self.profiles.keys()),
"description": {
"default": "Natural human-like timing with moderate variation",
"fast": "Reduced delays for quick interactions and testing",
"slow": "Extended delays for thoughtful, deliberate responses",
}.get(self.profile, "Unknown profile"),
}
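# Hedged usage sketch (not part of the original module): compares the delay computed for a
# short greeting against a longer technical request. Exact values vary on every run because
# the calculation is intentionally randomized.
if __name__ == "__main__":
    calc = TimingCalculator(profile="fast")
    print(calc.get_complexity_score("hi"))  # 0.0 for a trivial message
    print(calc.calculate_response_delay("hi"))
    print(
        calc.calculate_response_delay(
            "Can you debug this async function and explain the error in the API endpoint?"
        )
    )
    calc.set_profile("slow")
    print(calc.get_profile_info()["current_profile"])  # "slow"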

View File

@@ -1,13 +0,0 @@
"""
Mai Core Module
This module provides core functionality and utilities for Mai,
including configuration management, exception handling, and shared
utilities used across the application.
"""
# Import the real implementations instead of defining placeholders
from .exceptions import MaiError, ConfigurationError, ModelError
from .config import get_config
__all__ = ["MaiError", "ConfigurationError", "ModelError", "get_config"]

View File

@@ -1,738 +0,0 @@
"""
Configuration management system for Mai.
Handles loading, validation, and management of all Mai settings
with proper defaults and runtime updates.
"""
import os
import json
import yaml
from typing import Dict, Any, Optional, Union
from dataclasses import dataclass, field, asdict
from pathlib import Path
import copy
import threading
# Import exceptions
try:
from .exceptions import ConfigFileError, ConfigValidationError, ConfigMissingError
except ImportError:
# Define placeholder exceptions if module not available
class ConfigFileError(Exception):
pass
class ConfigValidationError(Exception):
pass
class ConfigMissingError(Exception):
pass
@dataclass
class ModelConfig:
"""Model-specific configuration."""
preferred_models: list = field(
default_factory=lambda: ["llama2", "mistral", "codellama", "vicuna"]
)
fallback_models: list = field(default_factory=lambda: ["llama2:7b", "mistral:7b", "phi"])
resource_thresholds: Dict[str, float] = field(
default_factory=lambda: {
"cpu_warning": 0.8,
"cpu_critical": 0.95,
"ram_warning": 0.8,
"ram_critical": 0.95,
"gpu_warning": 0.9,
"gpu_critical": 0.98,
}
)
context_windows: Dict[str, int] = field(
default_factory=lambda: {
"llama2": 4096,
"mistral": 8192,
"codellama": 16384,
"vicuna": 4096,
"phi": 2048,
}
)
auto_switch: bool = True
switch_threshold: float = 0.7 # Performance degradation threshold
@dataclass
class ResourceConfig:
"""Resource monitoring configuration."""
monitoring_enabled: bool = True
check_interval: float = 5.0 # seconds
trend_window: int = 60 # seconds for trend analysis
performance_history_size: int = 100
gpu_detection: bool = True
fallback_detection: bool = True
resource_warnings: bool = True
conservative_estimates: bool = True
memory_buffer: float = 0.5 # 50% buffer for context overhead
@dataclass
class ContextConfig:
"""Context management configuration."""
compression_enabled: bool = True
warning_threshold: float = 0.75 # Warn at 75% of context
critical_threshold: float = 0.90 # Critical at 90%
budget_ratio: float = 0.9 # Budget at 90% of context
max_conversation_length: int = 100
preserve_key_elements: bool = True
compression_cache_ttl: int = 3600 # 1 hour
min_quality_score: float = 0.7
@dataclass
class GitConfig:
"""Git workflow configuration."""
auto_commit: bool = True
commit_grouping: bool = True
natural_language_messages: bool = True
staging_branch: str = "mai-staging"
auto_merge: bool = True
health_checks: bool = True
stability_test_duration: int = 300 # 5 minutes
auto_revert: bool = True
commit_delay: float = 10.0 # seconds between commits
@dataclass
class LoggingConfig:
"""Logging and debugging configuration."""
level: str = "INFO"
file_logging: bool = True
console_logging: bool = True
log_file: str = "logs/mai.log"
max_file_size: int = 10 * 1024 * 1024 # 10MB
backup_count: int = 5
debug_mode: bool = False
performance_logging: bool = True
error_tracking: bool = True
@dataclass
class MemoryConfig:
"""Memory system and compression configuration."""
# Compression thresholds
message_count: int = 50
age_days: int = 30
memory_limit_mb: int = 500
# Summarization settings
summarization_model: str = "llama2"
preserve_elements: list = field(
default_factory=lambda: ["preferences", "decisions", "patterns", "key_facts"]
)
min_quality_score: float = 0.7
max_summary_length: int = 1000
context_messages: int = 30
# Adaptive weighting
importance_decay_days: int = 90
pattern_weight: float = 1.5
technical_weight: float = 1.2
planning_weight: float = 1.3
recency_boost: float = 1.2
keyword_boost: float = 1.5
# Strategy settings
keep_recent_count: int = 10
max_patterns_extracted: int = 5
topic_extraction_method: str = "keyword"
pattern_confidence_threshold: float = 0.6
# Retrieval settings
similarity_threshold: float = 0.7
max_results: int = 5
include_content: bool = False
semantic_weight: float = 0.4
keyword_weight: float = 0.3
recency_weight: float = 0.2
user_pattern_weight: float = 0.1
# Performance settings
max_memory_usage_mb: int = 200
max_cpu_usage_percent: int = 80
max_compression_time_seconds: int = 30
enable_background_compression: bool = True
compression_interval_hours: int = 6
batch_size: int = 5
@dataclass
class Config:
"""Main configuration class for Mai."""
models: ModelConfig = field(default_factory=ModelConfig)
resources: ResourceConfig = field(default_factory=ResourceConfig)
context: ContextConfig = field(default_factory=ContextConfig)
git: GitConfig = field(default_factory=GitConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
memory: MemoryConfig = field(default_factory=MemoryConfig)
    # Runtime state
    config_file: Optional[str] = None
    last_modified: Optional[float] = None
    def __post_init__(self):
        """Initialize configuration after dataclass creation."""
        # Keep the lock off the dataclass fields so asdict()/deepcopy can handle Config
        self._lock = threading.RLock()
        # Ensure log directory exists
        if self.logging.file_logging:
            log_path = Path(self.logging.log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)
class ConfigManager:
"""
Configuration manager with loading, validation, and hot-reload capabilities.
"""
def __init__(self, config_path: Optional[str] = None):
"""
Initialize configuration manager.
Args:
config_path: Path to configuration file (YAML or JSON)
"""
self.config_path = config_path
self.config = Config()
self._observers = []
self._lock = threading.RLock()
# Load configuration if path provided
if config_path:
self.load_config(config_path)
# Apply environment variable overrides
self._apply_env_overrides()
def load_config(self, config_path: Optional[str] = None) -> Config:
"""
Load configuration from file.
Args:
config_path: Path to configuration file
Returns:
Loaded Config object
Raises:
ConfigFileError: If file cannot be loaded
ConfigValidationError: If configuration is invalid
"""
if config_path:
self.config_path = config_path
if not self.config_path or not os.path.exists(self.config_path):
# Use default configuration
self.config = Config()
return self.config
try:
with open(self.config_path, "r", encoding="utf-8") as f:
if self.config_path.endswith(".yaml") or self.config_path.endswith(".yml"):
                    data = yaml.safe_load(f) or {}  # empty YAML files parse to None
elif self.config_path.endswith(".json"):
data = json.load(f)
else:
raise ConfigFileError(f"Unsupported config format: {self.config_path}")
# Merge with defaults
self.config = self._merge_with_defaults(data)
self.config.config_file = self.config_path
self.config.last_modified = os.path.getmtime(self.config_path)
# Validate configuration
self._validate_config()
# Apply environment overrides
self._apply_env_overrides()
# Notify observers
self._notify_observers("config_loaded", self.config)
return self.config
except (yaml.YAMLError, json.JSONDecodeError) as e:
raise ConfigFileError(f"Invalid configuration file format: {e}")
except Exception as e:
raise ConfigFileError(f"Error loading configuration: {e}")
def save_config(self, config_path: Optional[str] = None) -> bool:
"""
Save current configuration to file.
Args:
config_path: Path to save configuration (uses current if None)
Returns:
True if saved successfully
Raises:
ConfigFileError: If file cannot be saved
"""
if config_path:
self.config_path = config_path
if not self.config_path:
raise ConfigFileError("No configuration path specified")
try:
# Ensure directory exists
config_dir = os.path.dirname(self.config_path)
if config_dir:
os.makedirs(config_dir, exist_ok=True)
# Convert to dictionary
config_dict = asdict(self.config)
# Remove runtime state
config_dict.pop("config_file", None)
config_dict.pop("last_modified", None)
config_dict.pop("_lock", None)
# Save with comments (YAML format preferred)
with open(self.config_path, "w", encoding="utf-8") as f:
if self.config_path.endswith(".yaml") or self.config_path.endswith(".yml"):
# Add comments for documentation
yaml.dump(config_dict, f, default_flow_style=False, indent=2)
else:
json.dump(config_dict, f, indent=2)
self.config.last_modified = os.path.getmtime(self.config_path)
# Notify observers
self._notify_observers("config_saved", self.config)
return True
except Exception as e:
raise ConfigFileError(f"Error saving configuration: {e}")
def get_model_config(self) -> ModelConfig:
"""Get model-specific configuration."""
return self.config.models
def get_resource_config(self) -> ResourceConfig:
"""Get resource monitoring configuration."""
return self.config.resources
def get_context_config(self) -> ContextConfig:
"""Get context management configuration."""
return self.config.context
def get_git_config(self) -> GitConfig:
"""Get git workflow configuration."""
return self.config.git
def get_logging_config(self) -> LoggingConfig:
"""Get logging configuration."""
return self.config.logging
def get_memory_config(self) -> MemoryConfig:
"""Get memory configuration."""
return self.config.memory
def update_config(self, updates: Dict[str, Any], section: Optional[str] = None) -> bool:
"""
Update configuration with new values.
Args:
updates: Dictionary of updates to apply
section: Configuration section to update (optional)
Returns:
True if updated successfully
Raises:
ConfigValidationError: If updates are invalid
"""
with self._lock:
# Store old values for rollback
old_values = {}
try:
# Apply updates
if section:
if hasattr(self.config, section):
section_config = getattr(self.config, section)
for key, value in updates.items():
if hasattr(section_config, key):
old_values[f"{section}.{key}"] = getattr(section_config, key)
setattr(section_config, key, value)
else:
raise ConfigValidationError(f"Invalid config key: {section}.{key}")
else:
raise ConfigValidationError(f"Invalid config section: {section}")
else:
# Apply to root config
for key, value in updates.items():
if hasattr(self.config, key):
old_values[key] = getattr(self.config, key)
setattr(self.config, key, value)
else:
raise ConfigValidationError(f"Invalid config key: {key}")
# Validate updated configuration
self._validate_config()
# Save if file path available
if self.config_path:
self.save_config()
# Notify observers
self._notify_observers("config_updated", self.config, old_values)
return True
except Exception as e:
# Rollback changes on error
for path, value in old_values.items():
if "." in path:
section, key = path.split(".", 1)
if hasattr(self.config, section):
setattr(getattr(self.config, section), key, value)
else:
setattr(self.config, path, value)
raise ConfigValidationError(f"Invalid configuration update: {e}")
def reload_config(self) -> bool:
"""
Reload configuration from file.
Returns:
True if reloaded successfully
"""
if not self.config_path:
return False
try:
return self.load_config(self.config_path) is not None
except Exception:
return False
def add_observer(self, callback):
"""
Add observer for configuration changes.
Args:
callback: Function to call on config changes
"""
with self._lock:
self._observers.append(callback)
def remove_observer(self, callback):
"""
Remove observer for configuration changes.
Args:
callback: Function to remove
"""
with self._lock:
if callback in self._observers:
self._observers.remove(callback)
    def _merge_with_defaults(self, data: Dict[str, Any]) -> Config:
        """
        Merge loaded data with default configuration.
        Args:
            data: Loaded configuration data
        Returns:
            Merged Config object
        """
        # Start with defaults
        default_dict = asdict(Config())
        # Recursively merge
        merged = self._deep_merge(default_dict, data)
        # Rebuild nested dataclasses so attribute access (e.g. config.models.preferred_models)
        # keeps working after a file load
        return Config(
            models=ModelConfig(**merged.get("models", {})),
            resources=ResourceConfig(**merged.get("resources", {})),
            context=ContextConfig(**merged.get("context", {})),
            git=GitConfig(**merged.get("git", {})),
            logging=LoggingConfig(**merged.get("logging", {})),
            memory=MemoryConfig(**merged.get("memory", {})),
        )
def _deep_merge(self, default: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
"""
Deep merge two dictionaries.
Args:
default: Default values
override: Override values
Returns:
Merged dictionary
"""
result = copy.deepcopy(default)
for key, value in override.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = self._deep_merge(result[key], value)
else:
result[key] = value
return result
def _validate_config(self):
"""
Validate configuration values.
Raises:
ConfigValidationError: If configuration is invalid
"""
# Validate model config
if not self.config.models.preferred_models:
raise ConfigValidationError("No preferred models configured")
if not 0 <= self.config.models.switch_threshold <= 1:
raise ConfigValidationError("Model switch threshold must be between 0 and 1")
# Validate resource config
if not 0 < self.config.resources.check_interval <= 60:
raise ConfigValidationError("Resource check interval must be between 0 and 60 seconds")
# Validate context config
if not 0 < self.config.context.budget_ratio <= 1:
raise ConfigValidationError("Context budget ratio must be between 0 and 1")
if (
not 0
< self.config.context.warning_threshold
< self.config.context.critical_threshold
<= 1
):
raise ConfigValidationError("Invalid context thresholds: warning < critical <= 1")
# Validate logging config
valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
if self.config.logging.level not in valid_levels:
raise ConfigValidationError(f"Invalid log level: {self.config.logging.level}")
def _apply_env_overrides(self):
"""Apply environment variable overrides."""
# Model overrides
if "MAI_PREFERRED_MODELS" in os.environ:
models = [m.strip() for m in os.environ["MAI_PREFERRED_MODELS"].split(",")]
self.config.models.preferred_models = models
if "MAI_AUTO_SWITCH" in os.environ:
self.config.models.auto_switch = os.environ["MAI_AUTO_SWITCH"].lower() == "true"
# Resource overrides
if "MAI_RESOURCE_MONITORING" in os.environ:
self.config.resources.monitoring_enabled = (
os.environ["MAI_RESOURCE_MONITORING"].lower() == "true"
)
# Context overrides
if "MAI_CONTEXT_BUDGET_RATIO" in os.environ:
try:
ratio = float(os.environ["MAI_CONTEXT_BUDGET_RATIO"])
if 0 < ratio <= 1:
self.config.context.budget_ratio = ratio
except ValueError:
pass
# Logging overrides
if "MAI_DEBUG_MODE" in os.environ:
self.config.logging.debug_mode = os.environ["MAI_DEBUG_MODE"].lower() == "true"
# Memory overrides
if "MAI_MEMORY_LIMIT_MB" in os.environ:
try:
limit = int(os.environ["MAI_MEMORY_LIMIT_MB"])
if limit > 0:
self.config.memory.memory_limit_mb = limit
except ValueError:
pass
if "MAI_COMPRESSION_MODEL" in os.environ:
self.config.memory.summarization_model = os.environ["MAI_COMPRESSION_MODEL"]
if "MAI_ENABLE_BACKGROUND_COMPRESSION" in os.environ:
self.config.memory.enable_background_compression = (
os.environ["MAI_ENABLE_BACKGROUND_COMPRESSION"].lower() == "true"
)
def _notify_observers(self, event: str, *args):
"""Notify observers of configuration changes."""
for observer in self._observers:
try:
observer(event, *args)
except Exception:
# Don't let observer errors break config management
pass
def get_config_summary(self) -> Dict[str, Any]:
"""
Get summary of current configuration.
Returns:
Dictionary with configuration summary
"""
return {
"config_file": self.config.config_file,
"last_modified": self.config.last_modified,
"models": {
"preferred_count": len(self.config.models.preferred_models),
"auto_switch": self.config.models.auto_switch,
"switch_threshold": self.config.models.switch_threshold,
},
"resources": {
"monitoring_enabled": self.config.resources.monitoring_enabled,
"check_interval": self.config.resources.check_interval,
"gpu_detection": self.config.resources.gpu_detection,
},
"context": {
"compression_enabled": self.config.context.compression_enabled,
"budget_ratio": self.config.context.budget_ratio,
"warning_threshold": self.config.context.warning_threshold,
},
"git": {
"auto_commit": self.config.git.auto_commit,
"staging_branch": self.config.git.staging_branch,
"auto_merge": self.config.git.auto_merge,
},
"logging": {
"level": self.config.logging.level,
"file_logging": self.config.logging.file_logging,
"debug_mode": self.config.logging.debug_mode,
},
"memory": {
"message_count": self.config.memory.message_count,
"age_days": self.config.memory.age_days,
"memory_limit_mb": self.config.memory.memory_limit_mb,
"summarization_model": self.config.memory.summarization_model,
"enable_background_compression": self.config.memory.enable_background_compression,
},
}
# Global configuration manager instance
_config_manager = None
def get_config_manager(config_path: Optional[str] = None) -> ConfigManager:
"""
Get global configuration manager instance.
Args:
config_path: Path to configuration file (only used on first call)
Returns:
ConfigManager instance
"""
global _config_manager
if _config_manager is None:
_config_manager = ConfigManager(config_path)
return _config_manager
def get_config(config_path: Optional[str] = None) -> Config:
"""
Get current configuration.
Args:
config_path: Optional path to configuration file (only used on first call)
Returns:
Current Config object
"""
return get_config_manager(config_path).config
def load_memory_config(config_path: Optional[str] = None) -> Dict[str, Any]:
"""
Load memory-specific configuration from YAML file.
Args:
config_path: Path to memory configuration file
Returns:
Dictionary with memory configuration settings
"""
# Default memory config path
if config_path is None:
config_path = os.path.join(".mai", "config", "memory.yaml")
# If file doesn't exist, return default settings
if not os.path.exists(config_path):
return {
"compression": {
"thresholds": {"message_count": 50, "age_days": 30, "memory_limit_mb": 500}
}
}
try:
with open(config_path, "r", encoding="utf-8") as f:
if config_path.endswith((".yaml", ".yml")):
config_data = yaml.safe_load(f)
else:
config_data = json.load(f)
# Validate and merge with defaults
default_config = {
"compression": {
"thresholds": {"message_count": 50, "age_days": 30, "memory_limit_mb": 500},
"summarization": {
"model": "llama2",
"preserve_elements": ["preferences", "decisions", "patterns", "key_facts"],
"min_quality_score": 0.7,
"max_summary_length": 1000,
"context_messages": 30,
},
}
}
        # Deep merge with defaults (reuse the manager's recursive merge helper;
        # there is no module-level _deep_merge)
        merged_config = get_config_manager()._deep_merge(default_config, config_data)
# Validate key memory settings
compression_config = merged_config.get("compression", {})
thresholds = compression_config.get("thresholds", {})
if thresholds.get("message_count", 0) < 10:
raise ConfigValidationError(
field_name="message_count",
field_value=thresholds.get("message_count"),
validation_error="must be at least 10",
)
if thresholds.get("age_days", 0) < 1:
raise ConfigValidationError(
field_name="age_days",
field_value=thresholds.get("age_days"),
validation_error="must be at least 1 day",
)
if thresholds.get("memory_limit_mb", 0) < 100:
raise ConfigValidationError(
field_name="memory_limit_mb",
field_value=thresholds.get("memory_limit_mb"),
validation_error="must be at least 100MB",
)
return merged_config
except (yaml.YAMLError, json.JSONDecodeError) as e:
raise ConfigFileError(
file_path=config_path,
operation="load_memory_config",
error_details=f"Invalid format: {e}",
)
except Exception as e:
raise ConfigFileError(
file_path=config_path,
operation="load_memory_config",
error_details=f"Error loading: {e}",
)
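# Hedged usage sketch (not part of the original module): builds the default configuration,
# tweaks one logging field through the manager, and reads the memory defaults. No file is
# written because no config path is supplied; printed values assume no overriding env vars.
if __name__ == "__main__":
    manager = get_config_manager()
    print(manager.get_logging_config().level)  # "INFO" by default
    manager.update_config({"debug_mode": True}, section="logging")
    print(manager.get_config_summary()["logging"])  # {'level': 'INFO', ..., 'debug_mode': True}
    print(load_memory_config()["compression"]["thresholds"]["message_count"])  # 50 by default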

View File

@@ -1,834 +0,0 @@
"""
Custom exception hierarchy for Mai error handling.
Provides clear, actionable error information for all Mai components
with context data and resolution suggestions.
"""
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
import traceback
import time
@dataclass
class ErrorContext:
"""Context information for errors."""
component: str # Component where error occurred
operation: str # Operation being performed
data: Dict[str, Any] # Relevant context data
timestamp: float = field(default_factory=time.time) # When error occurred
user_friendly: bool = True # Whether to show to users
class MaiError(Exception):
"""
Base exception for all Mai-specific errors.
All Mai exceptions should inherit from this class to provide
consistent error handling and context.
"""
def __init__(
self,
message: str,
error_code: Optional[str] = None,
context: Optional[ErrorContext] = None,
suggestions: Optional[List[str]] = None,
cause: Optional[Exception] = None,
):
"""
Initialize Mai error.
Args:
message: Error message
error_code: Unique error code for programmatic handling
context: Error context information
suggestions: Suggestions for resolution
cause: Original exception that caused this error
"""
super().__init__(message)
self.message = message
self.error_code = error_code or self.__class__.__name__
self.context = context or ErrorContext(
component="unknown", operation="unknown", data={}, timestamp=time.time()
)
self.suggestions = suggestions or []
self.cause = cause
self.severity = self._determine_severity()
    def _determine_severity(self) -> str:
        """Determine error severity based on type and context."""
        if "Critical" in self.__class__.__name__ or (
            self.error_code and "CRITICAL" in self.error_code
        ):
            return "critical"
        elif "Warning" in self.__class__.__name__ or (
            self.error_code and "WARNING" in self.error_code
        ):
            return "warning"
        else:
            return "error"
def to_dict(self) -> Dict[str, Any]:
"""Convert error to dictionary for serialization."""
return {
"error_type": self.__class__.__name__,
"message": self.message,
"error_code": self.error_code,
"severity": self.severity,
"context": {
"component": self.context.component,
"operation": self.context.operation,
"data": self.context.data,
"timestamp": self.context.timestamp,
"user_friendly": self.context.user_friendly,
},
"suggestions": self.suggestions,
"cause": str(self.cause) if self.cause else None,
"traceback": traceback.format_exc() if self.severity == "critical" else None,
}
def __str__(self) -> str:
"""String representation of error."""
return self.message
class ModelError(MaiError):
"""Base class for model-related errors."""
def __init__(self, message: str, model_name: Optional[str] = None, **kwargs):
kwargs.setdefault(
"context",
ErrorContext(
component="model_interface",
operation="model_operation",
data={"model_name": model_name} if model_name else {},
),
)
super().__init__(message, **kwargs)
self.model_name = model_name
class ModelNotFoundError(ModelError):
"""Raised when requested model is not available."""
def __init__(self, model_name: str, available_models: Optional[List[str]] = None):
suggestions = [
f"Check if '{model_name}' is installed in Ollama",
"Run 'ollama list' to see available models",
"Try downloading the model with 'ollama pull'",
]
if available_models:
suggestions.append(f"Available models: {', '.join(available_models[:5])}")
super().__init__(
f"Model '{model_name}' not found",
model_name=model_name,
error_code="MODEL_NOT_FOUND",
suggestions=suggestions,
)
self.available_models = available_models or []
class ModelSwitchError(ModelError):
"""Raised when model switching fails."""
def __init__(self, from_model: str, to_model: str, reason: Optional[str] = None):
message = f"Failed to switch from '{from_model}' to '{to_model}'"
if reason:
message += f": {reason}"
suggestions = [
"Check if target model is available",
"Verify sufficient system resources for target model",
"Try switching to a smaller model first",
]
super().__init__(
message,
model_name=to_model,
error_code="MODEL_SWITCH_FAILED",
context=ErrorContext(
component="model_switcher",
operation="switch_model",
data={"from_model": from_model, "to_model": to_model, "reason": reason},
),
suggestions=suggestions,
)
self.from_model = from_model
self.to_model = to_model
class ModelConnectionError(ModelError):
"""Raised when cannot connect to Ollama or model service."""
def __init__(self, service_url: str, timeout: Optional[float] = None):
message = f"Cannot connect to model service at {service_url}"
if timeout:
message += f" (timeout: {timeout}s)"
suggestions = [
"Check if Ollama is running",
f"Verify service URL: {service_url}",
"Check network connectivity",
"Try restarting Ollama service",
]
super().__init__(
message,
error_code="MODEL_CONNECTION_FAILED",
context=ErrorContext(
component="ollama_client",
operation="connect",
data={"service_url": service_url, "timeout": timeout},
),
suggestions=suggestions,
)
self.service_url = service_url
self.timeout = timeout
class ModelInferenceError(ModelError):
"""Raised when model inference request fails."""
def __init__(self, model_name: str, prompt_length: int, error_details: Optional[str] = None):
message = f"Inference failed for model '{model_name}'"
if error_details:
message += f": {error_details}"
suggestions = [
"Check if model is loaded properly",
"Try with a shorter prompt",
"Verify model context window limits",
"Check available system memory",
]
super().__init__(
message,
model_name=model_name,
error_code="MODEL_INFERENCE_FAILED",
context=ErrorContext(
component="model_interface",
operation="inference",
data={
"model_name": model_name,
"prompt_length": prompt_length,
"error_details": error_details,
},
),
suggestions=suggestions,
)
self.prompt_length = prompt_length
self.error_details = error_details
class ResourceError(MaiError):
"""Base class for resource-related errors."""
def __init__(self, message: str, **kwargs):
kwargs.setdefault(
"context",
ErrorContext(component="resource_monitor", operation="resource_check", data={}),
)
super().__init__(message, **kwargs)
class ResourceExhaustedError(ResourceError):
"""Raised when system resources are depleted."""
def __init__(self, resource_type: str, current_usage: float, limit: float):
message = (
f"Resource '{resource_type}' exhausted: {current_usage:.1%} used (limit: {limit:.1%})"
)
suggestions = [
"Close other applications to free up resources",
"Try using a smaller model",
"Wait for resources to become available",
"Consider upgrading system resources",
]
super().__init__(
message,
error_code="RESOURCE_EXHAUSTED",
context=ErrorContext(
component="resource_monitor",
operation="check_resources",
data={
"resource_type": resource_type,
"current_usage": current_usage,
"limit": limit,
"excess": current_usage - limit,
},
),
suggestions=suggestions,
)
self.resource_type = resource_type
self.current_usage = current_usage
self.limit = limit
class ResourceMonitorError(ResourceError):
"""Raised when resource monitoring fails."""
def __init__(self, operation: str, error_details: Optional[str] = None):
message = f"Resource monitoring failed during {operation}"
if error_details:
message += f": {error_details}"
suggestions = [
"Check if monitoring dependencies are installed",
"Verify system permissions for resource access",
"Try using fallback monitoring methods",
"Restart the application",
]
super().__init__(
message,
error_code="RESOURCE_MONITOR_FAILED",
context=ErrorContext(
component="resource_monitor",
operation=operation,
data={"error_details": error_details},
),
suggestions=suggestions,
)
self.operation = operation
self.error_details = error_details
class InsufficientMemoryError(ResourceError):
"""Raised when insufficient memory for operation."""
def __init__(self, required_memory: int, available_memory: int, operation: str):
message = f"Insufficient memory for {operation}: need {required_memory}MB, have {available_memory}MB"
suggestions = [
"Close other applications to free memory",
"Try with a smaller model or context",
"Increase swap space if available",
"Consider using a model with lower memory requirements",
]
super().__init__(
message,
error_code="INSUFFICIENT_MEMORY",
context=ErrorContext(
component="memory_manager",
operation="allocate_memory",
data={
"required_memory": required_memory,
"available_memory": available_memory,
"shortfall": required_memory - available_memory,
"operation": operation,
},
),
suggestions=suggestions,
)
self.required_memory = required_memory
self.available_memory = available_memory
self.operation = operation
class ContextError(MaiError):
"""Base class for context-related errors."""
def __init__(self, message: str, **kwargs):
kwargs.setdefault(
"context",
ErrorContext(component="context_manager", operation="context_operation", data={}),
)
super().__init__(message, **kwargs)
class ContextTooLongError(ContextError):
"""Raised when conversation exceeds context window limits."""
def __init__(self, current_tokens: int, max_tokens: int, model_name: str):
message = (
f"Conversation too long for {model_name}: {current_tokens} tokens (max: {max_tokens})"
)
suggestions = [
"Enable context compression",
"Remove older messages from conversation",
"Use a model with larger context window",
"Split conversation into smaller parts",
]
super().__init__(
message,
error_code="CONTEXT_TOO_LONG",
context=ErrorContext(
component="context_compressor",
operation="validate_context",
data={
"current_tokens": current_tokens,
"max_tokens": max_tokens,
"excess": current_tokens - max_tokens,
"model_name": model_name,
},
),
suggestions=suggestions,
)
self.current_tokens = current_tokens
self.max_tokens = max_tokens
self.model_name = model_name
class ContextCompressionError(ContextError):
"""Raised when context compression fails."""
def __init__(
self, original_tokens: int, target_ratio: float, error_details: Optional[str] = None
):
message = (
f"Context compression failed: {original_tokens} tokens → target {target_ratio:.1%}"
)
if error_details:
message += f": {error_details}"
suggestions = [
"Try with a higher compression ratio",
"Check if conversation contains valid text",
"Verify compression quality thresholds",
"Use manual message removal instead",
]
super().__init__(
message,
error_code="CONTEXT_COMPRESSION_FAILED",
context=ErrorContext(
component="context_compressor",
operation="compress",
data={
"original_tokens": original_tokens,
"target_ratio": target_ratio,
"error_details": error_details,
},
),
suggestions=suggestions,
)
self.original_tokens = original_tokens
self.target_ratio = target_ratio
self.error_details = error_details
class ContextCorruptionError(ContextError):
"""Raised when context data is invalid or corrupted."""
def __init__(self, context_type: str, corruption_details: Optional[str] = None):
message = f"Context corruption detected in {context_type}"
if corruption_details:
message += f": {corruption_details}"
suggestions = [
"Clear conversation history and start fresh",
"Verify context serialization format",
"Check for data encoding issues",
"Rebuild context from valid messages",
]
super().__init__(
message,
error_code="CONTEXT_CORRUPTION",
context=ErrorContext(
component="context_manager",
operation="validate_context",
data={"context_type": context_type, "corruption_details": corruption_details},
),
suggestions=suggestions,
)
self.context_type = context_type
self.corruption_details = corruption_details
class GitError(MaiError):
"""Base class for Git-related errors."""
def __init__(self, message: str, **kwargs):
kwargs.setdefault(
"context", ErrorContext(component="git_interface", operation="git_operation", data={})
)
super().__init__(message, **kwargs)
class GitRepositoryError(GitError):
"""Raised for Git repository issues."""
def __init__(self, repo_path: str, error_details: Optional[str] = None):
message = f"Git repository error in {repo_path}"
if error_details:
message += f": {error_details}"
suggestions = [
"Verify directory is a Git repository",
"Check Git repository permissions",
"Run 'git status' to diagnose issues",
"Initialize repository with 'git init' if needed",
]
super().__init__(
message,
error_code="GIT_REPOSITORY_ERROR",
context=ErrorContext(
component="git_interface",
operation="validate_repository",
data={"repo_path": repo_path, "error_details": error_details},
),
suggestions=suggestions,
)
self.repo_path = repo_path
self.error_details = error_details
class GitCommitError(GitError):
"""Raised when commit operation fails."""
def __init__(
self, operation: str, files: Optional[List[str]] = None, error_details: Optional[str] = None
):
message = f"Git {operation} failed"
if error_details:
message += f": {error_details}"
suggestions = [
"Check if files exist and are readable",
"Verify write permissions for repository",
"Run 'git status' to check repository state",
"Stage files with 'git add' before committing",
]
super().__init__(
message,
error_code="GIT_COMMIT_FAILED",
context=ErrorContext(
component="git_committer",
operation=operation,
data={"files": files or [], "error_details": error_details},
),
suggestions=suggestions,
)
self.operation = operation
self.files = files or []
self.error_details = error_details
class GitMergeError(GitError):
"""Raised for merge conflicts or failures."""
def __init__(
self,
branch_name: str,
conflict_files: Optional[List[str]] = None,
error_details: Optional[str] = None,
):
message = f"Git merge failed for branch '{branch_name}'"
if error_details:
message += f": {error_details}"
suggestions = [
"Resolve merge conflicts manually",
"Use 'git status' to see conflicted files",
"Consider using 'git merge --abort' to cancel",
"Pull latest changes before merging",
]
super().__init__(
message,
error_code="GIT_MERGE_FAILED",
context=ErrorContext(
component="git_workflow",
operation="merge",
data={
"branch_name": branch_name,
"conflict_files": conflict_files or [],
"error_details": error_details,
},
),
suggestions=suggestions,
)
self.branch_name = branch_name
self.conflict_files = conflict_files or []
self.error_details = error_details
class ConfigurationError(MaiError):
"""Base class for configuration-related errors."""
def __init__(self, message: str, **kwargs):
kwargs.setdefault(
"context",
ErrorContext(component="config_manager", operation="config_operation", data={}),
)
super().__init__(message, **kwargs)
class ConfigFileError(ConfigurationError):
"""Raised for configuration file issues."""
def __init__(self, file_path: str, operation: str, error_details: Optional[str] = None):
message = f"Configuration file error during {operation}: {file_path}"
if error_details:
message += f": {error_details}"
suggestions = [
"Verify file path and permissions",
"Check file format (YAML/JSON)",
"Ensure file contains valid configuration",
"Create default configuration file if missing",
]
super().__init__(
message,
error_code="CONFIG_FILE_ERROR",
context=ErrorContext(
component="config_manager",
operation=operation,
data={"file_path": file_path, "error_details": error_details},
),
suggestions=suggestions,
)
self.file_path = file_path
self.operation = operation
self.error_details = error_details
class ConfigValidationError(ConfigurationError):
"""Raised for invalid configuration values."""
def __init__(self, field_name: str, field_value: Any, validation_error: str):
message = (
f"Invalid configuration value for '{field_name}': {field_value} - {validation_error}"
)
suggestions = [
"Check configuration documentation for valid values",
"Verify value type and range constraints",
"Use default configuration values",
"Check for typos in field names",
]
super().__init__(
message,
error_code="CONFIG_VALIDATION_FAILED",
context=ErrorContext(
component="config_manager",
operation="validate_config",
data={
"field_name": field_name,
"field_value": str(field_value),
"validation_error": validation_error,
},
),
suggestions=suggestions,
)
self.field_name = field_name
self.field_value = field_value
self.validation_error = validation_error
class ConfigMissingError(ConfigurationError):
"""Raised when required configuration is missing."""
def __init__(self, missing_keys: List[str], config_section: Optional[str] = None):
section_msg = f" in section '{config_section}'" if config_section else ""
message = f"Required configuration missing{section_msg}: {', '.join(missing_keys)}"
suggestions = [
"Add missing keys to configuration file",
"Check configuration documentation for required fields",
"Use default configuration as template",
"Verify configuration file is being loaded correctly",
]
super().__init__(
message,
error_code="CONFIG_MISSING_REQUIRED",
context=ErrorContext(
component="config_manager",
operation="check_requirements",
data={"missing_keys": missing_keys, "config_section": config_section},
),
suggestions=suggestions,
)
self.missing_keys = missing_keys
self.config_section = config_section
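# Illustrative sketch (not part of the original API): raising these configuration
# errors from a validator. The field and section names below are examples only.
def _example_validate_similarity_threshold(value: Any) -> float:
    """Raise structured configuration errors for a single retrieval setting."""
    if value is None:
        raise ConfigMissingError(["similarity_threshold"], config_section="retrieval.search")
    threshold = float(value)
    if not 0.0 <= threshold <= 1.0:
        raise ConfigValidationError(
            field_name="similarity_threshold",
            field_value=value,
            validation_error="must be between 0.0 and 1.0",
        )
    return threshold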
# Error handling utilities
def format_error_for_user(error: MaiError) -> str:
"""
Convert technical error to user-friendly message.
Args:
error: MaiError instance
Returns:
User-friendly error message
"""
if not isinstance(error, MaiError):
return f"Unexpected error: {str(error)}"
# Use the message if it's user-friendly
if error.context.user_friendly:
return str(error)
# Create user-friendly version
friendly_message = error.message
# Remove technical terms without altering the casing of the rest of the message
technical_terms = ["traceback", "exception", "error_code", "context"]
for term in technical_terms:
for variant in (term, term.capitalize(), term.upper()):
friendly_message = friendly_message.replace(variant, "")
# Add top suggestion
if error.suggestions:
friendly_message += f"\n\nSuggestion: {error.suggestions[0]}"
return friendly_message.strip()
def is_retriable_error(error: Exception) -> bool:
"""
Determine if error can be retried.
Args:
error: Exception instance
Returns:
True if error is retriable
"""
if isinstance(error, MaiError):
retriable_codes = [
"MODEL_CONNECTION_FAILED",
"RESOURCE_MONITOR_FAILED",
"CONTEXT_COMPRESSION_FAILED",
]
return error.error_code in retriable_codes
# Non-Mai errors: only retry network/connection issues
error_str = str(error).lower()
retriable_patterns = ["connection", "timeout", "network", "temporary", "unavailable"]
return any(pattern in error_str for pattern in retriable_patterns)
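# Illustrative sketch (not part of the original API): a simple retry wrapper
# built on is_retriable_error. Attempt counts and delays are example values.
def retry_on_transient_error(operation, max_attempts: int = 3, delay_seconds: float = 1.0):
    """Call `operation()` and retry only the failures classified as retriable."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return operation()
        except Exception as exc:
            if not is_retriable_error(exc) or attempt == max_attempts:
                raise
            time.sleep(delay_seconds * attempt)  # linear backoff before the next attempt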
def get_error_severity(error: Exception) -> str:
"""
Classify error severity.
Args:
error: Exception instance
Returns:
Severity level: 'warning', 'error', or 'critical'
"""
if isinstance(error, MaiError):
return error.severity
# Classify non-Mai errors
error_str = str(error).lower()
if any(pattern in error_str for pattern in ["critical", "fatal"]):
return "critical"
elif any(pattern in error_str for pattern in ["warning"]):
return "warning"
else:
return "error"
def create_error_context(component: str, operation: str, **data) -> ErrorContext:
"""
Create error context with current timestamp.
Args:
component: Component name
operation: Operation name
**data: Additional context data
Returns:
ErrorContext instance
"""
return ErrorContext(component=component, operation=operation, data=data, timestamp=time.time())
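# Illustrative sketch: attaching a freshly stamped context when raising a MaiError.
# The component, operation, and token count below are example values.
def _example_raise_with_context(prompt_tokens: int) -> None:
    raise MaiError(
        "Prompt exceeds the available context window",
        error_code="CONTEXT_COMPRESSION_FAILED",
        context=create_error_context("context_manager", "build_prompt", prompt_tokens=prompt_tokens),
        suggestions=["Compress older messages before retrying"],
    )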
# Exception handler for logging and monitoring
class ErrorHandler:
"""
Central error handler for Mai components.
Provides consistent error logging, metrics, and user notification.
"""
def __init__(self, logger=None):
"""
Initialize error handler.
Args:
logger: Logger instance for error reporting
"""
self.logger = logger
self.error_counts = {}
self.last_errors = {}
def handle_error(self, error: Exception, component: str = "unknown"):
"""
Handle error with logging and metrics.
Args:
error: Exception to handle
component: Component where error occurred
"""
# Count errors
error_type = error.__class__.__name__
self.error_counts[error_type] = self.error_counts.get(error_type, 0) + 1
self.last_errors[error_type] = {
"error": error,
"component": component,
"timestamp": time.time(),
}
# Log error
if self.logger:
severity = get_error_severity(error)
if severity == "critical":
self.logger.critical(f"Critical error in {component}: {error}")
elif severity == "error":
self.logger.error(f"Error in {component}: {error}")
else:
self.logger.warning(f"Warning in {component}: {error}")
# Return formatted error for user
if isinstance(error, MaiError):
return format_error_for_user(error)
else:
return f"An error occurred in {component}: {str(error)}"
def get_error_stats(self) -> Dict[str, Any]:
"""
Get error statistics.
Returns:
Dictionary with error statistics
"""
return {
"error_counts": self.error_counts.copy(),
"last_errors": {
k: {
"error": str(v["error"]),
"component": v["component"],
"timestamp": v["timestamp"],
}
for k, v in self.last_errors.items()
},
"total_errors": sum(self.error_counts.values()),
}
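Taken together, a component can funnel its failures through a shared ErrorHandler instance. A minimal usage sketch, assuming this module is importable as src.mai.core.exceptions (the path used by other modules in this commit) and using an example config file path:

import logging

from src.mai.core.exceptions import ConfigFileError, ErrorHandler

logging.basicConfig(level=logging.INFO)
handler = ErrorHandler(logger=logging.getLogger("mai"))

try:
    raise ConfigFileError("config/memory.yaml", "load", error_details="file not found")
except Exception as exc:
    # Logs at the classified severity and returns a user-facing message
    user_message = handler.handle_error(exc, component="config_manager")
    print(user_message)

print(handler.get_error_stats()["total_errors"])  # 1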

File diff suppressed because it is too large

View File

@@ -1,12 +0,0 @@
"""
Git workflow management for Mai's self-improvement system.
Provides staging branch management, validation, and cleanup
capabilities for safe code improvements.
"""
from .workflow import StagingWorkflow
from .committer import AutoCommitter
from .health_check import HealthChecker
__all__ = ["StagingWorkflow", "AutoCommitter", "HealthChecker"]
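A hypothetical end-to-end sketch of how these pieces could be composed, assuming the package is importable as src.mai.git, that a "main" branch exists, and that the improvement description and impact text are illustrative:

from src.mai.git import AutoCommitter, StagingWorkflow

workflow = StagingWorkflow(project_path=".")
committer = AutoCommitter(project_path=".")

# Create an isolated staging branch, commit the improvement, then clean up
branch = workflow.create_staging_branch("optimization", "faster context retrieval")
workflow.switch_to_branch(branch["branch_name"])

staged = committer.stage_changes(group_by="system")
message = committer.generate_commit_message(
    staged["staged_files"], impact_description="Faster answers in long conversations"
)
result = committer.commit_changes(message)
print(result["message"])

workflow.switch_to_branch("main")
workflow.cleanup_staging_branch(branch["branch_name"])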

View File

@@ -1,499 +0,0 @@
"""
Automated commit generation and management for Mai's self-improvement system.
Handles staging changes, generating user-focused commit messages,
and managing commit history with proper validation.
"""
import os
import re
import logging
from datetime import datetime
from typing import List, Dict, Optional, Any, Set
from pathlib import Path
try:
from git import Repo, InvalidGitRepositoryError, GitCommandError, Diff, GitError
except ImportError:
raise ImportError("GitPython is required. Install with: pip install GitPython")
from ..core import MaiError, ConfigurationError
class AutoCommitterError(MaiError):
"""Raised when automated commit operations fail."""
pass
class AutoCommitter:
"""
Automates commit generation and management for Mai's improvements.
Provides staging, commit message generation, and history management
with user-focused impact descriptions.
"""
def __init__(self, project_path: str = "."):
"""
Initialize auto committer.
Args:
project_path: Path to git repository
Raises:
ConfigurationError: If not a git repository
"""
self.project_path = Path(project_path).resolve()
self.logger = logging.getLogger(__name__)
try:
self.repo = Repo(self.project_path)
except InvalidGitRepositoryError:
raise ConfigurationError(f"Not a git repository: {self.project_path}")
# Commit message templates and patterns
self.templates = {
"performance": "Faster {operation} for {scenario}",
"bugfix": "Fixed {issue} - {impact on user}",
"feature": "Added {capability} - now you can {user benefit}",
"optimization": "Improved {system} - {performance gain}",
"refactor": "Cleaned up {component} - {improvement}",
"security": "Enhanced security for {area} - {protection}",
"compatibility": "Made Mai work better with {environment} - {benefit}",
}
# File patterns to ignore
self.ignore_patterns = {
"*.pyc",
"*.pyo",
"*.pyd",
"__pycache__",
".git",
".pytest_cache",
".coverage",
"htmlcov",
"*.log",
".env",
"*.tmp",
"*.temp",
"*.bak",
".DS_Store",
"*.swp",
"*~",
}
# Group patterns by system
self.group_patterns = {
"model": ["src/mai/model/", "*.model.*"],
"git": ["src/mai/git/", "*.git.*"],
"core": ["src/mai/core/", "*.core.*"],
"memory": ["src/mai/memory/", "*.memory.*"],
"safety": ["src/mai/safety/", "*.safety.*"],
"personality": ["src/mai/personality/", "*.personality.*"],
"interface": ["src/mai/interface/", "*.interface.*"],
"config": ["*.toml", "*.yaml", "*.yml", "*.conf", ".env*"],
}
# Initialize user information
self._init_user_info()
self.logger.info(f"Auto committer initialized for {self.project_path}")
def stage_changes(
self, file_patterns: Optional[List[str]] = None, group_by: str = "system"
) -> Dict[str, Any]:
"""
Stage changed files for commit with optional grouping.
Args:
file_patterns: Specific file patterns to stage
group_by: How to group changes ("system", "directory", "none")
Returns:
Dictionary with staging results and groups
"""
try:
# Get changed files
changed_files = self._get_changed_files()
# Filter by patterns if specified
if file_patterns:
changed_files = [
f for f in changed_files if self._matches_pattern(f, file_patterns)
]
# Filter out ignored files
staged_files = [f for f in changed_files if not self._should_ignore_file(f)]
# Stage the files
self.repo.index.add(staged_files)
# Group changes
groups = self._group_changes(staged_files, group_by) if group_by != "none" else {}
self.logger.info(f"Staged {len(staged_files)} files in {len(groups)} groups")
return {
"staged_files": staged_files,
"groups": groups,
"total_files": len(staged_files),
"message": f"Staged {len(staged_files)} files for commit",
}
except (GitError, GitCommandError) as e:
raise AutoCommitterError(f"Failed to stage changes: {e}")
def generate_commit_message(
self, changes: List[str], impact_description: str, improvement_type: str = "feature"
) -> str:
"""
Generate user-focused commit message.
Args:
changes: List of changed files
impact_description: Description of impact on user
improvement_type: Type of improvement
Returns:
User-focused commit message
"""
# Try to use template
if improvement_type in self.templates:
template = self.templates[improvement_type]
# Extract context from changes
context = self._extract_context_from_files(changes)
# Fill template
try:
message = template.format(**context, **{"user benefit": impact_description})
except KeyError:
# Fall back to impact description
message = impact_description
else:
message = impact_description
# Ensure user-focused language
message = self._make_user_focused(message)
# Add technical details as second line
if len(changes) <= 5:
tech_details = f"Files: {', '.join([Path(f).name for f in changes[:3]])}"
if len(changes) > 3:
tech_details += f" (+{len(changes) - 3} more)"
message = f"{message}\n\n{tech_details}"
# Limit length
if len(message) > 100:
message = message[:97] + "..."
return message
def commit_changes(
self, message: str, files: Optional[List[str]] = None, validate_before: bool = True
) -> Dict[str, Any]:
"""
Create commit with generated message and optional validation.
Args:
message: Commit message
files: Specific files to commit (stages all if None)
validate_before: Run validation before committing
Returns:
Dictionary with commit results
"""
try:
# Validate if requested
if validate_before:
validation = self._validate_commit(message, files)
if not validation["valid"]:
return {
"success": False,
"message": "Commit validation failed",
"validation": validation,
"commit_hash": None,
}
# Stage files if specified
if files:
self.repo.index.add(files)
# Check if there are staged changes
if not self.repo.is_dirty(untracked_files=True) and not self.repo.index.diff("HEAD"):
return {"success": False, "message": "No changes to commit", "commit_hash": None}
# Create commit with metadata (GitPython's IndexFile.commit takes commit_date)
commit = self.repo.index.commit(
message=message, author_date=datetime.now(), commit_date=datetime.now()
)
commit_hash = commit.hexsha
self.logger.info(f"Created commit: {commit_hash[:8]} - {message[:50]}")
return {
"success": True,
"message": f"Committed {commit_hash[:8]}",
"commit_hash": commit_hash,
"short_hash": commit_hash[:8],
"full_message": message,
}
except (GitError, GitCommandError) as e:
raise AutoCommitterError(f"Failed to create commit: {e}")
def get_commit_history(
self, limit: int = 10, filter_by: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Retrieve commit history with metadata.
Args:
limit: Maximum number of commits to retrieve
filter_by: Filter criteria (author, date range, patterns)
Returns:
List of commit information
"""
try:
commits = []
for commit in self.repo.iter_commits(max_count=limit):
# Apply filters
if filter_by:
if "author" in filter_by and filter_by["author"] not in commit.author.name:
continue
if "since" in filter_by and commit.committed_date < filter_by["since"]:
continue
if "until" in filter_by and commit.committed_date > filter_by["until"]:
continue
if "pattern" in filter_by and not re.search(
filter_by["pattern"], commit.message
):
continue
commits.append(
{
"hash": commit.hexsha,
"short_hash": commit.hexsha[:8],
"message": commit.message.strip(),
"author": commit.author.name,
"date": datetime.fromtimestamp(commit.committed_date).isoformat(),
"files_changed": len(commit.stats.files),
"insertions": commit.stats.total["insertions"],
"deletions": commit.stats.total["deletions"],
"impact": self._extract_impact_from_message(commit.message),
}
)
return commits
except (GitError, GitCommandError) as e:
raise AutoCommitterError(f"Failed to get commit history: {e}")
def revert_commit(self, commit_hash: str, create_branch: bool = True) -> Dict[str, Any]:
"""
Safely revert specified commit.
Args:
commit_hash: Hash of commit to revert
create_branch: Create backup branch before reverting
Returns:
Dictionary with revert results
"""
try:
# Validate commit exists
try:
commit = self.repo.commit(commit_hash)
except Exception:
return {
"success": False,
"message": f"Commit {commit_hash[:8]} not found",
"commit_hash": None,
}
# Create backup branch if requested
backup_branch = None
if create_branch:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
backup_branch = f"backup/before-revert-{commit_hash[:8]}-{timestamp}"
self.repo.create_head(backup_branch, self.repo.active_branch.commit)
self.logger.info(f"Created backup branch: {backup_branch}")
# Perform revert
revert_commit = self.repo.git.revert("--no-edit", commit_hash)
# Get new commit hash
new_commit_hash = self.repo.head.commit.hexsha
self.logger.info(f"Reverted commit {commit_hash[:8]} -> {new_commit_hash[:8]}")
return {
"success": True,
"message": f"Reverted {commit_hash[:8]} successfully",
"original_commit": commit_hash,
"new_commit_hash": new_commit_hash,
"new_short_hash": new_commit_hash[:8],
"backup_branch": backup_branch,
"original_message": commit.message.strip(),
}
except (GitError, GitCommandError) as e:
raise AutoCommitterError(f"Failed to revert commit: {e}")
def _get_changed_files(self) -> List[str]:
"""Get list of changed files in working directory."""
changed_files = set()
# Unstaged changes
for item in self.repo.index.diff(None):
changed_files.add(item.a_path)
# Staged changes
for item in self.repo.index.diff("HEAD"):
changed_files.add(item.a_path)
# Untracked files
changed_files.update(self.repo.untracked_files)
return list(changed_files)
def _should_ignore_file(self, file_path: str) -> bool:
"""Check if file should be ignored."""
file_name = Path(file_path).name
for pattern in self.ignore_patterns:
if self._matches_pattern(file_path, [pattern]):
return True
return False
def _matches_pattern(self, file_path: str, patterns: List[str]) -> bool:
"""Check if file path matches any pattern."""
import fnmatch
for pattern in patterns:
if fnmatch.fnmatch(file_path, pattern) or fnmatch.fnmatch(
Path(file_path).name, pattern
):
return True
return False
def _group_changes(self, files: List[str], group_by: str) -> Dict[str, List[str]]:
"""Group files by system or directory."""
groups = {}
if group_by == "system":
for file_path in files:
group = "other"
for system, patterns in self.group_patterns.items():
if self._matches_pattern(file_path, patterns):
group = system
break
if group not in groups:
groups[group] = []
groups[group].append(file_path)
elif group_by == "directory":
for file_path in files:
directory = str(Path(file_path).parent)
if directory not in groups:
groups[directory] = []
groups[directory].append(file_path)
return groups
def _extract_context_from_files(self, files: List[str]) -> Dict[str, str]:
"""Extract context from changed files."""
context = {}
# Analyze file paths for context
model_files = [f for f in files if "model" in f.lower()]
git_files = [f for f in files if "git" in f.lower()]
core_files = [f for f in files if "core" in f.lower()]
if model_files:
context["system"] = "model interface"
context["operation"] = "model operations"
elif git_files:
context["system"] = "git workflows"
context["operation"] = "version control"
elif core_files:
context["system"] = "core functionality"
context["operation"] = "system stability"
else:
context["system"] = "Mai"
context["operation"] = "functionality"
# Default scenario
context["scenario"] = "your conversations"
context["area"] = "Mai's capabilities"
return context
def _make_user_focused(self, message: str) -> str:
"""Convert message to be user-focused."""
# Remove technical jargon
replacements = {
"feat:": "",
"fix:": "",
"refactor:": "",
"optimize:": "",
"implementation": "new capability",
"functionality": "features",
"module": "component",
"code": "improvements",
"api": "interface",
"backend": "core system",
}
for old, new in replacements.items():
message = message.replace(old, new)
# Start with action verb if needed
if not message[0].isupper():
message = message[0].upper() + message[1:]
return message.strip()
def _validate_commit(self, message: str, files: Optional[List[str]]) -> Dict[str, Any]:
"""Validate commit before creation."""
issues = []
# Check message length
if len(message) > 100:
issues.append("Commit message too long (>100 characters)")
# Check message has content
if not message.strip():
issues.append("Empty commit message")
# Flag an explicitly provided but empty file list
if files is not None and len(files) == 0:
issues.append("No files specified for commit")
return {"valid": len(issues) == 0, "issues": issues}
def _extract_impact_from_message(self, message: str) -> str:
"""Extract impact description from commit message."""
# Split by lines and take first non-empty line
lines = message.strip().split("\n")
for line in lines:
line = line.strip()
if line and not line.startswith("Files:"):
return line
return message
def _init_user_info(self) -> None:
"""Initialize user information from git config."""
try:
config = self.repo.config_reader()
self.user_name = config.get_value("user", "name", "Mai")
self.user_email = config.get_value("user", "email", "mai@local")
except Exception:
self.user_name = "Mai"
self.user_email = "mai@local"
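A usage sketch for the history and revert helpers above, assuming the module is importable as src.mai.git.committer; the message pattern is only an example filter:

from src.mai.git.committer import AutoCommitter

committer = AutoCommitter(project_path=".")

# Inspect recent commits and roll one back behind a backup branch
history = committer.get_commit_history(limit=5, filter_by={"pattern": r"Faster|Fixed|Added"})
for entry in history:
    print(entry["short_hash"], entry["impact"])

if history:
    result = committer.revert_commit(history[0]["hash"], create_branch=True)
    print(result["message"], "backup:", result.get("backup_branch"))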

File diff suppressed because it is too large

View File

@@ -1,399 +0,0 @@
"""
Staging workflow management for Mai's self-improvement system.
Handles branch creation, management, and cleanup for testing improvements
before merging to main codebase.
"""
import os
import time
import logging
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple, Any
from pathlib import Path
try:
from git import Repo, InvalidGitRepositoryError, GitCommandError, Head
from git.exc import GitError
except ImportError:
raise ImportError("GitPython is required. Install with: pip install GitPython")
from ..core import MaiError, ConfigurationError
class StagingWorkflowError(MaiError):
"""Raised when staging workflow operations fail."""
pass
class StagingWorkflow:
"""
Manages staging branches for safe code improvements.
Provides branch creation, validation, and cleanup capabilities
with proper error handling and recovery.
"""
def __init__(self, project_path: str = ".", timeout: int = 30):
"""
Initialize staging workflow.
Args:
project_path: Path to git repository
timeout: Timeout for git operations in seconds
Raises:
ConfigurationError: If not a git repository
"""
self.project_path = Path(project_path).resolve()
self.timeout = timeout
self.logger = logging.getLogger(__name__)
try:
self.repo = Repo(self.project_path)
except InvalidGitRepositoryError:
raise ConfigurationError(f"Not a git repository: {self.project_path}")
# Configure retry logic for git operations
self.max_retries = 3
self.retry_delay = 1
# Branch naming pattern
self.branch_prefix = "staging"
# Initialize health check integration (will be connected later)
self.health_checker = None
self.logger.info(f"Staging workflow initialized for {self.project_path}")
def create_staging_branch(self, improvement_type: str, description: str) -> Dict[str, Any]:
"""
Create a staging branch for improvements.
Args:
improvement_type: Type of improvement (e.g., 'optimization', 'feature', 'bugfix')
description: Description of improvement
Returns:
Dictionary with branch information
Raises:
StagingWorkflowError: If branch creation fails
"""
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
# Sanitize description for branch name
description_safe = "".join(c for c in description[:20] if c.isalnum() or c in "-_").lower()
branch_name = f"{self.branch_prefix}/{improvement_type}-{timestamp}-{description_safe}"
try:
# Ensure we're on main/develop branch
self._ensure_main_branch()
# Check if branch already exists
if branch_name in [ref.name for ref in self.repo.refs]:
self.logger.warning(f"Branch {branch_name} already exists")
existing_branch = self.repo.refs[branch_name]
return {
"branch_name": branch_name,
"branch": existing_branch,
"created": False,
"message": f"Branch {branch_name} already exists",
}
# Create new branch
current_branch = self.repo.active_branch
new_branch = self.repo.create_head(branch_name, current_branch.commit.hexsha)
# Simple metadata handling - just log for now
self.logger.info(f"Branch metadata: type={improvement_type}, desc={description}")
self.logger.info(f"Created staging branch: {branch_name}")
return {
"branch_name": branch_name,
"branch": new_branch,
"created": True,
"timestamp": timestamp,
"improvement_type": improvement_type,
"description": description,
"message": f"Created staging branch {branch_name}",
}
except (GitError, GitCommandError) as e:
raise StagingWorkflowError(f"Failed to create branch {branch_name}: {e}")
def switch_to_branch(self, branch_name: str) -> Dict[str, Any]:
"""
Safely switch to specified branch.
Args:
branch_name: Name of branch to switch to
Returns:
Dictionary with switch result
Raises:
StagingWorkflowError: If switch fails
"""
try:
# Check for uncommitted changes
if self.repo.is_dirty(untracked_files=True):
return {
"success": False,
"branch_name": branch_name,
"message": "Working directory has uncommitted changes. Commit or stash first.",
"uncommitted": True,
}
# Verify branch exists
if branch_name not in [ref.name for ref in self.repo.refs]:
return {
"success": False,
"branch_name": branch_name,
"message": f"Branch {branch_name} does not exist",
"exists": False,
}
# Switch to branch
branch = self.repo.refs[branch_name]
branch.checkout()
self.logger.info(f"Switched to branch: {branch_name}")
return {
"success": True,
"branch_name": branch_name,
"message": f"Switched to {branch_name}",
"current_commit": str(self.repo.active_branch.commit),
}
except (GitError, GitCommandError) as e:
raise StagingWorkflowError(f"Failed to switch to branch {branch_name}: {e}")
def get_active_staging_branches(self) -> List[Dict[str, Any]]:
"""
List all staging branches with metadata.
Returns:
List of dictionaries with branch information
"""
staging_branches = []
current_time = datetime.now()
for ref in self.repo.refs:
if ref.name.startswith(self.branch_prefix + "/"):
try:
# Get branch age
commit_time = datetime.fromtimestamp(ref.commit.committed_date)
age = current_time - commit_time
# Check if branch is stale (> 7 days)
is_stale = age > timedelta(days=7)
# Simple metadata for now
metadata = {
"improvement_type": "unknown",
"description": "no description",
"created": "unknown",
}
staging_branches.append(
{
"name": ref.name,
"commit": str(ref.commit),
"commit_message": ref.commit.message.strip(),
"created": commit_time.isoformat(),
"age_days": age.days,
"age_hours": age.total_seconds() / 3600,
"is_stale": is_stale,
"metadata": metadata,
"is_current": ref.name == self.repo.active_branch.name,
}
)
except Exception as e:
self.logger.warning(f"Error processing branch {ref.name}: {e}")
continue
# Sort by creation time (newest first)
staging_branches.sort(key=lambda x: x["created"], reverse=True)
return staging_branches
def validate_branch_state(self, branch_name: str) -> Dict[str, Any]:
"""
Validate branch state for safe merging.
Args:
branch_name: Name of branch to validate
Returns:
Dictionary with validation results
"""
try:
if branch_name not in [ref.name for ref in self.repo.refs]:
return {
"valid": False,
"branch_name": branch_name,
"issues": [f"Branch {branch_name} does not exist"],
"can_merge": False,
}
# Switch to branch temporarily if not already there
original_branch = self.repo.active_branch.name
if original_branch != branch_name:
switch_result = self.switch_to_branch(branch_name)
if not switch_result["success"]:
return {
"valid": False,
"branch_name": branch_name,
"issues": [switch_result["message"]],
"can_merge": False,
}
issues = []
# Check for uncommitted changes
if self.repo.is_dirty(untracked_files=True):
issues.append("Working directory has uncommitted changes")
# Check for merge conflicts with main branch
try:
# Try to simulate merge without actually merging
main_branch = self._get_main_branch()
if main_branch and branch_name != main_branch:
merge_base = self.repo.merge_base(branch_name, main_branch)
if not merge_base:
issues.append("No common ancestor with main branch")
except Exception as e:
issues.append(f"Cannot determine merge compatibility: {e}")
# Switch back to original branch
if original_branch != branch_name:
self.switch_to_branch(original_branch)
return {
"valid": len(issues) == 0,
"branch_name": branch_name,
"issues": issues,
"can_merge": len(issues) == 0,
"metadata": {"improvement_type": "unknown", "description": "no description"},
}
except Exception as e:
return {
"valid": False,
"branch_name": branch_name,
"issues": [f"Validation failed: {e}"],
"can_merge": False,
}
def cleanup_staging_branch(
self, branch_name: str, keep_if_failed: bool = False
) -> Dict[str, Any]:
"""
Clean up staging branch after merge or when abandoned.
Args:
branch_name: Name of branch to cleanup
keep_if_failed: Keep branch if validation failed
Returns:
Dictionary with cleanup result
"""
try:
if branch_name not in [ref.name for ref in self.repo.refs]:
return {
"success": False,
"branch_name": branch_name,
"message": f"Branch {branch_name} does not exist",
}
# Check validation result if keep_if_failed is True
if keep_if_failed:
validation = self.validate_branch_state(branch_name)
if not validation["can_merge"]:
return {
"success": False,
"branch_name": branch_name,
"message": "Keeping branch due to validation failures",
"validation": validation,
}
# Don't delete current branch
if branch_name == self.repo.active_branch.name:
return {
"success": False,
"branch_name": branch_name,
"message": "Cannot delete currently active branch",
}
# Delete branch
self.repo.delete_head(branch_name, force=True)
self.logger.info(f"Cleaned up staging branch: {branch_name}")
return {
"success": True,
"branch_name": branch_name,
"message": f"Deleted staging branch {branch_name}",
"deleted": True,
}
except (GitError, GitCommandError) as e:
self.logger.error(f"Failed to cleanup branch {branch_name}: {e}")
return {
"success": False,
"branch_name": branch_name,
"message": f"Failed to delete branch: {e}",
"deleted": False,
}
def cleanup_old_staging_branches(self, days_old: int = 7) -> Dict[str, Any]:
"""
Clean up old staging branches.
Args:
days_old: Age threshold in days
Returns:
Dictionary with cleanup results
"""
staging_branches = self.get_active_staging_branches()
old_branches = [b for b in staging_branches if b["age_days"] > days_old]
cleanup_results = []
for branch_info in old_branches:
result = self.cleanup_staging_branch(branch_info["name"])
cleanup_results.append(result)
successful = sum(1 for r in cleanup_results if r["success"])
return {
"total_old_branches": len(old_branches),
"cleaned_up": successful,
"failed": len(old_branches) - successful,
"results": cleanup_results,
}
def _ensure_main_branch(self) -> None:
"""Ensure we're on main or develop branch."""
current = self.repo.active_branch.name
main_branch = self._get_main_branch()
if main_branch and current != main_branch:
try:
self.repo.refs[main_branch].checkout()
except (GitError, GitCommandError) as e:
self.logger.warning(f"Cannot switch to {main_branch}: {e}")
def _get_main_branch(self) -> Optional[str]:
"""Get main/develop branch name."""
for branch_name in ["main", "develop", "master"]:
if branch_name in [ref.name for ref in self.repo.refs]:
return branch_name
return None
def set_health_checker(self, health_checker) -> None:
"""Set health checker integration."""
self.health_checker = health_checker
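A housekeeping sketch built on the methods above, assuming the module is importable as src.mai.git.workflow:

from src.mai.git.workflow import StagingWorkflow

workflow = StagingWorkflow(project_path=".", timeout=30)

# Report staging branches and prune anything older than a week
for branch in workflow.get_active_staging_branches():
    marker = "*" if branch["is_current"] else " "
    print(f"{marker} {branch['name']}  {branch['age_days']}d  stale={branch['is_stale']}")

summary = workflow.cleanup_old_staging_branches(days_old=7)
print(f"Cleaned {summary['cleaned_up']} of {summary['total_old_branches']} old branches")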

View File

@@ -1,95 +0,0 @@
"""
Mai Memory Module
Provides persistent storage and retrieval of conversations
with semantic search capabilities.
This module serves as the foundation for Mai's memory system,
enabling conversation retention and intelligent context retrieval.
"""
import importlib.util
# Version information
__version__ = "0.1.0"
__author__ = "Mai Team"
# Core exports
from .storage import MemoryStorage
# Optional exports (may not be available if dependencies missing)
try:
from .storage import (
MemoryStorageError,
VectorSearchError,
DatabaseConnectionError,
)
__all__ = [
"MemoryStorage",
"MemoryStorageError",
"VectorSearchError",
"DatabaseConnectionError",
]
except ImportError:
__all__ = ["MemoryStorage"]
# Module metadata
__module_info__ = {
"name": "Mai Memory Module",
"description": "Persistent memory storage with semantic search",
"version": __version__,
"features": {
"sqlite_storage": True,
"vector_search": "sqlite-vec" in globals(),
"embeddings": "sentence-transformers" in globals(),
"fallback_search": True,
},
"dependencies": {
"required": ["sqlite3"],
"optional": {
"sqlite-vec": "Vector similarity search",
"sentence-transformers": "Text embeddings",
},
},
}
def get_module_info():
"""Get module information and capabilities."""
return __module_info__
def is_vector_search_available() -> bool:
"""Check if vector search is available."""
try:
import sqlite_vec
from sentence_transformers import SentenceTransformer
return True
except ImportError:
return False
def is_embeddings_available() -> bool:
"""Check if text embeddings are available."""
try:
from sentence_transformers import SentenceTransformer
return True
except ImportError:
return False
def get_memory_storage(*args, **kwargs):
"""
Factory function to create MemoryStorage instances.
Args:
*args: Positional arguments to pass to MemoryStorage
**kwargs: Keyword arguments to pass to MemoryStorage
Returns:
Configured MemoryStorage instance
"""
from .storage import MemoryStorage
return MemoryStorage(*args, **kwargs)
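A usage sketch, assuming the package is importable as src.mai.memory and relying on the default database location handled by MemoryStorage:

from src.mai.memory import (
    get_memory_storage,
    get_module_info,
    is_vector_search_available,
)

# Inspect declared capabilities before relying on semantic search
print(get_module_info()["features"])

storage = get_memory_storage()  # defaults to ./data/mai_memory.db
if is_vector_search_available():
    print("vector similarity search is enabled")
else:
    print("falling back to keyword-based search")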

View File

@@ -1,780 +0,0 @@
"""
Memory Compression Implementation for Mai
Intelligent conversation compression with AI-powered summarization
and pattern preservation for long-term memory efficiency.
"""
import logging
import json
import hashlib
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from pathlib import Path
# Import Mai components
try:
from src.mai.core.exceptions import (
MaiError,
ContextError,
create_error_context,
)
from src.mai.core.config import get_config
from src.mai.model.ollama_client import OllamaClient
from src.mai.memory.storage import MemoryStorage
except ImportError:
# Define fallbacks if modules not available
class MaiError(Exception):
def __init__(self, message: str, **kwargs):
super().__init__(message)
class ContextError(MaiError):
pass
def create_error_context(component: str, operation: str, **data):
return {"component": component, "operation": operation, "data": data}
def get_config():
return None
class OllamaClient:
def __init__(self, *args, **kwargs):
pass
def generate_response(self, prompt: str, model: str = "llama2") -> str:
return ""
class MemoryStorage:
def __init__(self, *args, **kwargs):
pass
def retrieve_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]:
return None
def update_conversation(self, conversation_id: str, **kwargs) -> bool:
return True
logger = logging.getLogger(__name__)
class MemoryCompressionError(ContextError):
"""Memory compression specific errors."""
def __init__(self, message: str, conversation_id: str = None, **kwargs):
context = create_error_context(
component="memory_compressor",
operation="compression",
conversation_id=conversation_id,
**kwargs,
)
super().__init__(message, context=context)
self.conversation_id = conversation_id
@dataclass
class CompressionThresholds:
"""Configuration for compression triggers."""
message_count: int = 50
age_days: int = 30
memory_limit_mb: int = 500
def should_compress(self, conversation: Dict[str, Any], current_memory_mb: float) -> bool:
"""
Check if conversation should be compressed.
Args:
conversation: Conversation data
current_memory_mb: Current memory usage in MB
Returns:
True if compression should be triggered
"""
# Check message count
message_count = len(conversation.get("messages", []))
if message_count >= self.message_count:
return True
# Check age
try:
created_at = datetime.fromisoformat(conversation.get("created_at", ""))
age_days = (datetime.now() - created_at).days
if age_days >= self.age_days:
return True
except (ValueError, TypeError):
pass
# Check memory limit
if current_memory_mb >= self.memory_limit_mb:
return True
return False
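# Illustrative sketch: how the thresholds evaluate a conversation record. The
# dict below only carries the fields that should_compress inspects.
def _example_should_compress() -> bool:
    thresholds = CompressionThresholds(message_count=50, age_days=30, memory_limit_mb=500)
    conversation = {
        "created_at": (datetime.now() - timedelta(days=45)).isoformat(),
        "messages": [{"role": "user", "content": "hello"}] * 12,
    }
    # Triggers on age (45 days >= 30) even though only 12 messages are present
    return thresholds.should_compress(conversation, current_memory_mb=120.0)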
@dataclass
class CompressionResult:
"""Result of compression operation."""
success: bool
original_messages: int
compressed_messages: int
compression_ratio: float
summary: str
patterns: List[Dict[str, Any]] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
error: Optional[str] = None
class MemoryCompressor:
"""
Intelligent conversation compression with AI summarization.
Automatically compresses growing conversations while preserving
important information, user patterns, and conversation continuity.
"""
def __init__(
self,
storage: Optional[MemoryStorage] = None,
ollama_client: Optional[OllamaClient] = None,
config: Optional[Dict[str, Any]] = None,
):
"""
Initialize memory compressor.
Args:
storage: Memory storage instance
ollama_client: Ollama client for AI summarization
config: Compression configuration
"""
self.storage = storage or MemoryStorage()
self.ollama_client = ollama_client or OllamaClient()
# Load configuration
self.config = config or self._load_default_config()
self.thresholds = CompressionThresholds(**self.config.get("thresholds", {}))
# Compression history tracking
self.compression_history: Dict[str, List[Dict[str, Any]]] = {}
logger.info("MemoryCompressor initialized")
def _load_default_config(self) -> Dict[str, Any]:
"""Load default compression configuration."""
return {
"thresholds": {"message_count": 50, "age_days": 30, "memory_limit_mb": 500},
"summarization": {
"model": "llama2",
"preserve_elements": ["preferences", "decisions", "patterns", "key_facts"],
"min_quality_score": 0.7,
},
"adaptive_weighting": {
"importance_decay_days": 90,
"pattern_weight": 1.5,
"technical_weight": 1.2,
},
}
def check_compression_needed(self, conversation_id: str) -> bool:
"""
Check if conversation needs compression.
Args:
conversation_id: ID of conversation to check
Returns:
True if compression is needed
"""
try:
# Get conversation data
conversation = self.storage.retrieve_conversation(conversation_id)
if not conversation:
logger.warning(f"Conversation {conversation_id} not found")
return False
# Get current memory usage
storage_stats = self.storage.get_storage_stats()
current_memory_mb = storage_stats.get("database_size_mb", 0)
# Check thresholds
return self.thresholds.should_compress(conversation, current_memory_mb)
except Exception as e:
logger.error(f"Error checking compression need for {conversation_id}: {e}")
return False
def compress_conversation(self, conversation_id: str) -> CompressionResult:
"""
Compress a conversation using AI summarization.
Args:
conversation_id: ID of conversation to compress
Returns:
CompressionResult with operation details
"""
try:
# Get conversation data
conversation = self.storage.retrieve_conversation(conversation_id)
if not conversation:
return CompressionResult(
success=False,
original_messages=0,
compressed_messages=0,
compression_ratio=0.0,
summary="",
error=f"Conversation {conversation_id} not found",
)
messages = conversation.get("messages", [])
original_count = len(messages)
if original_count < self.thresholds.message_count:
return CompressionResult(
success=False,
original_messages=original_count,
compressed_messages=original_count,
compression_ratio=1.0,
summary="",
error="Conversation below compression threshold",
)
# Analyze conversation for compression strategy
compression_strategy = self._analyze_conversation(messages)
# Generate AI summary
summary = self._generate_summary(messages, compression_strategy)
# Extract patterns
patterns = self._extract_patterns(messages, compression_strategy)
# Create compressed conversation structure
compressed_messages = self._create_compressed_structure(
messages, summary, patterns, compression_strategy
)
# Update conversation in storage
success = self._update_compressed_conversation(
conversation_id, compressed_messages, summary, patterns
)
if not success:
return CompressionResult(
success=False,
original_messages=original_count,
compressed_messages=original_count,
compression_ratio=1.0,
summary=summary,
error="Failed to update compressed conversation",
)
# Calculate compression ratio
compressed_count = len(compressed_messages)
compression_ratio = compressed_count / original_count if original_count > 0 else 1.0
# Track compression history
self._track_compression(
conversation_id,
{
"timestamp": datetime.now().isoformat(),
"original_messages": original_count,
"compressed_messages": compressed_count,
"compression_ratio": compression_ratio,
"strategy": compression_strategy,
},
)
logger.info(
f"Compressed conversation {conversation_id}: {original_count}{compressed_count} messages"
)
return CompressionResult(
success=True,
original_messages=original_count,
compressed_messages=compressed_count,
compression_ratio=compression_ratio,
summary=summary,
patterns=patterns,
metadata={
"strategy": compression_strategy,
"timestamp": datetime.now().isoformat(),
},
)
except Exception as e:
logger.error(f"Error compressing conversation {conversation_id}: {e}")
return CompressionResult(
success=False,
original_messages=0,
compressed_messages=0,
compression_ratio=0.0,
summary="",
error=str(e),
)
def _analyze_conversation(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Analyze conversation to determine compression strategy.
Args:
messages: List of conversation messages
Returns:
Compression strategy dictionary
"""
strategy = {
"keep_recent_count": 10, # Keep most recent messages
"importance_weights": {},
"conversation_type": "general",
"key_topics": [],
"user_preferences": [],
}
# Analyze message patterns
user_messages = [m for m in messages if m.get("role") == "user"]
assistant_messages = [m for m in messages if m.get("role") == "assistant"]
# Detect conversation type
if self._is_technical_conversation(messages):
strategy["conversation_type"] = "technical"
strategy["keep_recent_count"] = 15 # Keep more technical context
elif self._is_planning_conversation(messages):
strategy["conversation_type"] = "planning"
strategy["keep_recent_count"] = 12
# Identify key topics (simple keyword extraction)
all_content = " ".join([m.get("content", "") for m in messages])
strategy["key_topics"] = self._extract_key_topics(all_content)
# Calculate importance weights based on recency and content
for i, message in enumerate(messages):
# More recent messages get higher weight
recency_weight = (i + 1) / len(messages)
# Content-based weighting
content_weight = 1.0
content = message.get("content", "").lower()
# Boost weight for messages containing key information
if any(
keyword in content
for keyword in ["prefer", "want", "should", "decide", "important"]
):
content_weight *= 1.5
# Technical content gets boost in technical conversations
if strategy["conversation_type"] == "technical":
if any(
keyword in content
for keyword in ["code", "function", "implement", "fix", "error"]
):
content_weight *= 1.2
strategy["importance_weights"][message.get("id", f"msg_{i}")] = (
recency_weight * content_weight
)
return strategy
def _is_technical_conversation(self, messages: List[Dict[str, Any]]) -> bool:
"""Detect if conversation is technical in nature."""
technical_keywords = [
"code",
"function",
"implement",
"debug",
"error",
"fix",
"programming",
"development",
"api",
"database",
"algorithm",
]
tech_message_count = 0
total_messages = len(messages)
for message in messages:
content = message.get("content", "").lower()
if any(keyword in content for keyword in technical_keywords):
tech_message_count += 1
return (tech_message_count / total_messages) > 0.3 if total_messages > 0 else False
def _is_planning_conversation(self, messages: List[Dict[str, Any]]) -> bool:
"""Detect if conversation is about planning."""
planning_keywords = [
"plan",
"schedule",
"deadline",
"task",
"goal",
"objective",
"timeline",
"milestone",
"strategy",
"roadmap",
]
plan_message_count = 0
total_messages = len(messages)
for message in messages:
content = message.get("content", "").lower()
if any(keyword in content for keyword in planning_keywords):
plan_message_count += 1
return (plan_message_count / total_messages) > 0.25 if total_messages > 0 else False
def _extract_key_topics(self, content: str) -> List[str]:
"""Extract key topics from content (simple implementation)."""
# This is a simplified topic extraction
# In a real implementation, you might use NLP techniques
common_topics = [
"development",
"design",
"testing",
"deployment",
"maintenance",
"security",
"performance",
"user interface",
"database",
"api",
]
topics = []
content_lower = content.lower()
for topic in common_topics:
if topic in content_lower:
topics.append(topic)
return topics[:5] # Return top 5 topics
def _generate_summary(self, messages: List[Dict[str, Any]], strategy: Dict[str, Any]) -> str:
"""
Generate AI summary of conversation.
Args:
messages: Messages to summarize
strategy: Compression strategy information
Returns:
Generated summary text
"""
try:
# Prepare summarization prompt
preserve_elements = self.config.get("summarization", {}).get("preserve_elements", [])
prompt = f"""Please summarize this conversation while preserving important information:
Conversation type: {strategy.get("conversation_type", "general")}
Key topics: {", ".join(strategy.get("key_topics", []))}
Please preserve:
- {", ".join(preserve_elements)}
Create a concise summary that maintains conversation continuity and captures the most important points.
Conversation:
"""
# Add conversation context (limit to avoid token limits)
for message in messages[-30:]: # Include last 30 messages for context
role = message.get("role", "unknown")
content = message.get("content", "")[:500] # Truncate long messages
prompt += f"\n{role}: {content}"
prompt += "\n\nSummary:"
# Generate summary using Ollama
model = self.config.get("summarization", {}).get("model", "llama2")
summary = self.ollama_client.generate_response(prompt, model)
# Clean up summary
summary = summary.strip()
if len(summary) > 1000:
summary = summary[:1000] + "..."
return summary
except Exception as e:
logger.error(f"Error generating summary: {e}")
# Fallback to simple summary
return f"Conversation with {len(messages)} messages about {', '.join(strategy.get('key_topics', ['various topics']))}."
def _extract_patterns(
self, messages: List[Dict[str, Any]], strategy: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""
Extract patterns from conversation for future learning.
Args:
messages: Messages to analyze
strategy: Compression strategy
Returns:
List of extracted patterns
"""
patterns = []
try:
# Extract user preferences
user_preferences = self._extract_user_preferences(messages)
patterns.extend(user_preferences)
# Extract interaction patterns
interaction_patterns = self._extract_interaction_patterns(messages)
patterns.extend(interaction_patterns)
# Extract topic preferences
topic_patterns = self._extract_topic_patterns(messages, strategy)
patterns.extend(topic_patterns)
except Exception as e:
logger.error(f"Error extracting patterns: {e}")
return patterns
def _extract_user_preferences(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Extract user preferences from messages."""
preferences = []
preference_keywords = {
"like": "positive_preference",
"prefer": "preference",
"want": "desire",
"don't like": "negative_preference",
"avoid": "avoidance",
"should": "expectation",
}
for message in messages:
if message.get("role") != "user":
continue
content = message.get("content", "").lower()
for keyword, pref_type in preference_keywords.items():
if keyword in content:
# Extract the preference context (simplified)
preferences.append(
{
"type": pref_type,
"keyword": keyword,
"context": content[:200], # Truncate for storage
"timestamp": message.get("timestamp"),
"confidence": 0.7, # Simplified confidence score
}
)
return preferences
def _extract_interaction_patterns(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Extract interaction patterns from conversation."""
patterns = []
# Analyze response patterns
user_messages = [m for m in messages if m.get("role") == "user"]
assistant_messages = [m for m in messages if m.get("role") == "assistant"]
if len(user_messages) > 0 and len(assistant_messages) > 0:
# Calculate average message lengths
avg_user_length = sum(len(m.get("content", "")) for m in user_messages) / len(
user_messages
)
avg_assistant_length = sum(len(m.get("content", "")) for m in assistant_messages) / len(
assistant_messages
)
patterns.append(
{
"type": "communication_style",
"avg_user_message_length": avg_user_length,
"avg_assistant_message_length": avg_assistant_length,
"message_count": len(messages),
"user_to_assistant_ratio": len(user_messages) / len(assistant_messages),
}
)
return patterns
def _extract_topic_patterns(
self, messages: List[Dict[str, Any]], strategy: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""Extract topic preferences from conversation."""
patterns = []
key_topics = strategy.get("key_topics", [])
if key_topics:
patterns.append(
{
"type": "topic_preference",
"topics": key_topics,
"conversation_type": strategy.get("conversation_type", "general"),
"message_count": len(messages),
}
)
return patterns
def _create_compressed_structure(
self,
messages: List[Dict[str, Any]],
summary: str,
patterns: List[Dict[str, Any]],
strategy: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""
Create compressed conversation structure.
Args:
messages: Original messages
summary: Generated summary
patterns: Extracted patterns
strategy: Compression strategy
Returns:
Compressed message list
"""
compressed = []
# Add compression marker as system message
compressed.append(
{
"id": "compression_marker",
"role": "system",
"content": f"[COMPRESSED] Original conversation had {len(messages)} messages",
"timestamp": datetime.now().isoformat(),
"token_count": 0,
}
)
# Add summary
compressed.append(
{
"id": "conversation_summary",
"role": "assistant",
"content": f"Summary: {summary}",
"timestamp": datetime.now().isoformat(),
"token_count": len(summary.split()), # Rough estimate
}
)
# Add extracted patterns if any
if patterns:
patterns_text = "Key patterns extracted:\n"
for pattern in patterns[:5]: # Limit to 5 patterns
patterns_text += f"- {pattern.get('type', 'unknown')}: {str(pattern.get('context', pattern))[:100]}\n"
compressed.append(
{
"id": "extracted_patterns",
"role": "assistant",
"content": patterns_text,
"timestamp": datetime.now().isoformat(),
"token_count": len(patterns_text.split()),
}
)
# Keep most recent messages based on strategy
keep_count = strategy.get("keep_recent_count", 10)
recent_messages = messages[-keep_count:] if len(messages) > keep_count else messages
for message in recent_messages:
compressed.append(
{
"id": message.get("id", f"compressed_{len(compressed)}"),
"role": message.get("role"),
"content": message.get("content"),
"timestamp": message.get("timestamp"),
"token_count": message.get("token_count", 0),
}
)
return compressed
def _update_compressed_conversation(
self,
conversation_id: str,
compressed_messages: List[Dict[str, Any]],
summary: str,
patterns: List[Dict[str, Any]],
) -> bool:
"""
Update conversation with compressed content.
Args:
conversation_id: Conversation ID
compressed_messages: Compressed message list
summary: Generated summary
patterns: Extracted patterns
Returns:
True if update successful
"""
try:
# This would use the storage interface to update the conversation
# For now, we'll simulate the update
# In a real implementation, you would:
# 1. Update the messages in the database
# 2. Store compression metadata
# 3. Update conversation metadata
logger.info(f"Updated conversation {conversation_id} with compressed content")
return True
except Exception as e:
logger.error(f"Error updating compressed conversation: {e}")
return False
def _track_compression(self, conversation_id: str, compression_data: Dict[str, Any]) -> None:
"""
Track compression history for analytics.
Args:
conversation_id: Conversation ID
compression_data: Compression operation data
"""
if conversation_id not in self.compression_history:
self.compression_history[conversation_id] = []
self.compression_history[conversation_id].append(compression_data)
# Limit history size
if len(self.compression_history[conversation_id]) > 10:
self.compression_history[conversation_id] = self.compression_history[conversation_id][
-10:
]
def get_compression_stats(self) -> Dict[str, Any]:
"""
Get compression statistics.
Returns:
Dictionary with compression statistics
"""
total_compressions = sum(len(history) for history in self.compression_history.values())
if total_compressions == 0:
return {
"total_compressions": 0,
"average_compression_ratio": 0.0,
"conversations_compressed": 0,
}
# Calculate average compression ratio
total_ratio = 0.0
ratio_count = 0
for conversation_id, history in self.compression_history.items():
for compression in history:
ratio = compression.get("compression_ratio", 1.0)
total_ratio += ratio
ratio_count += 1
avg_ratio = total_ratio / ratio_count if ratio_count > 0 else 1.0
return {
"total_compressions": total_compressions,
"average_compression_ratio": avg_ratio,
"conversations_compressed": len(self.compression_history),
"compression_history": dict(self.compression_history),
}
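A usage sketch for the compressor, assuming this module is importable as src.mai.memory.compression and that the conversation already exists in storage (the conversation ID below is hypothetical):

from src.mai.memory.compression import MemoryCompressor
from src.mai.memory.storage import MemoryStorage

storage = MemoryStorage()
compressor = MemoryCompressor(storage=storage)

conversation_id = "conv-2025-001"  # hypothetical ID
if compressor.check_compression_needed(conversation_id):
    result = compressor.compress_conversation(conversation_id)
    if result.success:
        print(
            f"{result.original_messages} -> {result.compressed_messages} messages "
            f"(ratio {result.compression_ratio:.2f})"
        )
    else:
        print("compression skipped:", result.error)

print(compressor.get_compression_stats()["total_compressions"])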

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -1,822 +0,0 @@
"""
Memory Storage Implementation for Mai
Provides SQLite-based persistent storage with vector similarity search
for conversation retention and semantic retrieval.
"""
import os
import sqlite3
import json
import logging
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from pathlib import Path
# Import dependencies
try:
import sqlite_vec # type: ignore
except ImportError:
# Fallback if sqlite-vec not installed
sqlite_vec = None
try:
from sentence_transformers import SentenceTransformer
except ImportError:
# Fallback if sentence-transformers not installed
SentenceTransformer = None
# Import Mai components
try:
from src.mai.core.exceptions import (
MaiError,
ContextError,
create_error_context,
)
from src.mai.core.config import get_config
except ImportError:
# Define fallbacks if modules not available
class MaiError(Exception):
def __init__(self, message: str, **kwargs):
super().__init__(message)
class ContextError(MaiError):
pass
def create_error_context(component: str, operation: str, **data):
return {"component": component, "operation": operation, "data": data}
def get_config():
return None
logger = logging.getLogger(__name__)
class MemoryStorageError(ContextError):
"""Memory storage specific errors."""
def __init__(self, message: str, operation: str = None, **kwargs):
context = create_error_context(
component="memory_storage", operation=operation or "storage_operation", **kwargs
)
super().__init__(message, context=context)
self.operation = operation
class VectorSearchError(MemoryStorageError):
"""Vector similarity search errors."""
def __init__(self, query: str, error_details: str = None):
message = f"Vector search failed for query: '{query}'"
if error_details:
message += f": {error_details}"
super().__init__(
message=message, operation="vector_search", query=query, error_details=error_details
)
class DatabaseConnectionError(MemoryStorageError):
"""Database connection and operation errors."""
def __init__(self, db_path: str, error_details: str = None):
message = f"Database connection error: {db_path}"
if error_details:
message += f": {error_details}"
super().__init__(
message=message,
operation="database_connection",
db_path=db_path,
error_details=error_details,
)
class MemoryStorage:
"""
SQLite-based memory storage with vector similarity search.
Handles persistent storage of conversations, messages, and embeddings
with semantic search capabilities using sqlite-vec extension.
"""
def __init__(self, db_path: Optional[str] = None, embedding_model: str = "all-MiniLM-L6-v2"):
"""
Initialize memory storage with database and embedding model.
Args:
db_path: Path to SQLite database file (default: ./data/mai_memory.db)
embedding_model: Name of sentence-transformers model to use
"""
# Set database path
if db_path is None:
# Default to ./data/mai_memory.db
db_path = os.path.join(os.getcwd(), "data", "mai_memory.db")
self.db_path = Path(db_path)
self.embedding_model_name = embedding_model
# Ensure database directory exists
self.db_path.parent.mkdir(parents=True, exist_ok=True)
# Initialize components
self._db: Optional[sqlite3.Connection] = None
self._embedding_model: Optional[SentenceTransformer] = None
self._embedding_dim: Optional[int] = None
self._config = get_config()
# Initialize embedding model first (needed for database schema)
self._initialize_embedding_model()
# Then initialize database
self._initialize_database()
logger.info(f"MemoryStorage initialized with database: {self.db_path}")
def _initialize_database(self) -> None:
"""Initialize SQLite database with schema and vector extension."""
try:
# Connect to database
self._db = sqlite3.connect(str(self.db_path))
self._db.row_factory = sqlite3.Row # Enable dict-like row access
# Enable foreign keys
self._db.execute("PRAGMA foreign_keys = ON")
# Load sqlite-vec extension if available
if sqlite_vec is not None:
try:
self._db.enable_load_extension(True)
# sqlite_vec.load() resolves the platform-specific loadable library
# (vec0.so / .dylib / .dll) and loads it into this connection
sqlite_vec.load(self._db)
self._db.enable_load_extension(False)
logger.info("sqlite-vec extension loaded successfully")
self._vector_enabled = True
except Exception as e:
logger.warning(f"Failed to load sqlite-vec extension: {e}")
# Fall back to loading the extension by name from the library path
try:
self._db.load_extension("vec0")
logger.info("sqlite-vec extension loaded successfully (fallback)")
self._vector_enabled = True
except Exception as e2:
logger.warning(f"Failed to load sqlite-vec extension (fallback): {e2}")
self._vector_enabled = False
else:
logger.warning("sqlite-vec not available - vector features disabled")
self._vector_enabled = False
# Create tables
self._create_tables()
# Verify schema
self._verify_schema()
except Exception as e:
raise DatabaseConnectionError(db_path=str(self.db_path), error_details=str(e))
def _initialize_embedding_model(self) -> None:
"""Initialize sentence-transformers embedding model."""
try:
if SentenceTransformer is not None:
# Load embedding model (download if needed)
logger.info(f"Loading embedding model: {self.embedding_model_name}")
self._embedding_model = SentenceTransformer(self.embedding_model_name)
# Test embedding generation
test_embedding = self._embedding_model.encode("test")
self._embedding_dim = len(test_embedding)
logger.info(
f"Embedding model loaded: {self.embedding_model_name} (dim: {self._embedding_dim})"
)
else:
logger.warning("sentence-transformers not available - embeddings disabled")
self._embedding_model = None
self._embedding_dim = None
except Exception as e:
logger.error(f"Failed to initialize embedding model: {e}")
self._embedding_model = None
self._embedding_dim = None
def _create_tables(self) -> None:
"""Create database schema for conversations, messages, and embeddings."""
cursor = self._db.cursor()
try:
# Conversations table
cursor.execute("""
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
metadata TEXT DEFAULT '{}'
)
""")
# Messages table
cursor.execute("""
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
content TEXT NOT NULL,
timestamp TEXT NOT NULL,
token_count INTEGER DEFAULT 0,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
)
""")
# Vector embeddings table (if sqlite-vec available)
if self._vector_enabled and self._embedding_dim:
cursor.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS message_embeddings
USING vec0(
embedding float[{self._embedding_dim}]
)
""")
# Regular table for embedding metadata
cursor.execute("""
CREATE TABLE IF NOT EXISTS embedding_metadata (
rowid INTEGER PRIMARY KEY,
message_id TEXT NOT NULL,
conversation_id TEXT NOT NULL,
created_at TEXT NOT NULL,
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
)
""")
# Create indexes for performance
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON messages(timestamp)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_conversations_updated ON conversations(updated_at)"
)
# Commit schema changes
self._db.commit()
logger.info("Database schema created successfully")
except Exception as e:
self._db.rollback()
raise MemoryStorageError(
message=f"Failed to create database schema: {e}", operation="create_schema"
)
finally:
cursor.close()
def _verify_schema(self) -> None:
"""Verify that database schema is correct and up-to-date."""
cursor = self._db.cursor()
try:
# Check if required tables exist
cursor.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND name IN ('conversations', 'messages')
""")
required_tables = [row[0] for row in cursor.fetchall()]
if len(required_tables) != 2:
raise MemoryStorageError(
message="Required tables missing from database", operation="verify_schema"
)
# Check vector table if vector search is enabled
if self._vector_enabled:
cursor.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND name='message_embeddings'
""")
vector_tables = [row[0] for row in cursor.fetchall()]
if not vector_tables:
logger.warning("Vector table not found - vector search disabled")
self._vector_enabled = False
logger.info("Database schema verification passed")
except Exception as e:
raise MemoryStorageError(
message=f"Schema verification failed: {e}", operation="verify_schema"
)
finally:
cursor.close()
def store_conversation(
self,
conversation_id: str,
title: str,
messages: List[Dict[str, Any]],
metadata: Optional[Dict[str, Any]] = None,
) -> bool:
"""
Store a complete conversation with all messages.
Args:
conversation_id: Unique identifier for the conversation
title: Human-readable title for the conversation
messages: List of messages with 'role', 'content', and optional 'timestamp'
metadata: Additional metadata to store with conversation
Returns:
True if stored successfully
Raises:
MemoryStorageError: If storage operation fails
"""
if self._db is None:
raise DatabaseConnectionError(db_path=str(self.db_path))
cursor = self._db.cursor()
now = datetime.now().isoformat()
try:
# Insert conversation
cursor.execute(
"""
INSERT OR REPLACE INTO conversations
(id, title, created_at, updated_at, metadata)
VALUES (?, ?, ?, ?, ?)
""",
[conversation_id, title, now, now, json.dumps(metadata or {})],
)
# Insert messages
for i, message in enumerate(messages):
message_id = f"{conversation_id}_{i}"
role = message.get("role", "user")
content = message.get("content", "")
timestamp = message.get("timestamp", now)
# Basic validation
if role not in ["user", "assistant", "system"]:
role = "user"
cursor.execute(
"""
INSERT OR REPLACE INTO messages
(id, conversation_id, role, content, timestamp)
VALUES (?, ?, ?, ?, ?)
""",
[message_id, conversation_id, role, content, timestamp],
)
# Generate and store embedding if available
if self._embedding_model and self._vector_enabled:
try:
embedding = self._embedding_model.encode(content)
# Insert embedding metadata first so SQLite assigns a unique rowid
# that the vector table can share
cursor.execute(
"""
INSERT INTO embedding_metadata
(message_id, conversation_id, created_at)
VALUES (?, ?, ?)
""",
[message_id, conversation_id, now],
)
vector_rowid = cursor.lastrowid
# Store the vector under the same rowid; the embedding is bound as
# JSON text, which sqlite-vec accepts for float[] columns
cursor.execute(
"""
INSERT INTO message_embeddings (rowid, embedding)
VALUES (?, ?)
""",
[vector_rowid, json.dumps(embedding.tolist())],
)
except Exception as e:
logger.warning(
f"Failed to generate embedding for message {message_id}: {e}"
)
# Continue without embedding - don't fail the whole operation
self._db.commit()
logger.info(f"Stored conversation '{conversation_id}' with {len(messages)} messages")
return True
except Exception as e:
self._db.rollback()
raise MemoryStorageError(
message=f"Failed to store conversation: {e}",
operation="store_conversation",
conversation_id=conversation_id,
)
finally:
cursor.close()
def retrieve_conversation(self, conversation_id: str) -> Optional[Dict[str, Any]]:
"""
Retrieve a complete conversation by ID.
Args:
conversation_id: ID of conversation to retrieve
Returns:
Dictionary with conversation data or None if not found
Raises:
MemoryStorageError: If retrieval operation fails
"""
if self._db is None:
raise DatabaseConnectionError(db_path=str(self.db_path))
cursor = self._db.cursor()
try:
# Get conversation info
cursor.execute(
"""
SELECT id, title, created_at, updated_at, metadata
FROM conversations
WHERE id = ?
""",
[conversation_id],
)
conversation_row = cursor.fetchone()
if not conversation_row:
return None
# Get messages
cursor.execute(
"""
SELECT id, role, content, timestamp, token_count
FROM messages
WHERE conversation_id = ?
ORDER BY timestamp
""",
[conversation_id],
)
message_rows = cursor.fetchall()
# Build result
conversation = {
"id": conversation_row["id"],
"title": conversation_row["title"],
"created_at": conversation_row["created_at"],
"updated_at": conversation_row["updated_at"],
"metadata": json.loads(conversation_row["metadata"]),
"messages": [
{
"id": msg["id"],
"role": msg["role"],
"content": msg["content"],
"timestamp": msg["timestamp"],
"token_count": msg["token_count"],
}
for msg in message_rows
],
}
logger.debug(
f"Retrieved conversation '{conversation_id}' with {len(message_rows)} messages"
)
return conversation
except Exception as e:
raise MemoryStorageError(
message=f"Failed to retrieve conversation: {e}",
operation="retrieve_conversation",
conversation_id=conversation_id,
)
finally:
cursor.close()
def search_conversations(
self, query: str, limit: int = 5, include_content: bool = False
) -> List[Dict[str, Any]]:
"""
Search conversations using semantic similarity.
Args:
query: Search query text
limit: Maximum number of results to return
include_content: Whether to include full message content in results
Returns:
List of matching conversations with similarity scores
Raises:
VectorSearchError: If search operation fails
"""
if self._db is None:
raise DatabaseConnectionError(db_path=str(self.db_path))
if not self._vector_enabled or self._embedding_model is None:
logger.warning("Vector search not available - falling back to text search")
return self._text_search_fallback(query, limit, include_content)
cursor = self._db.cursor()
try:
# For now, use text search as vector search needs sqlite-vec syntax fixes
logger.info("Using text search fallback temporarily")
return self._text_search_fallback(query, limit, include_content)
# TODO: Fix sqlite-vec query syntax for proper vector search
# Generate query embedding
# query_embedding = self._embedding_model.encode(query)
#
# # Perform vector similarity search using sqlite-vec syntax
# cursor.execute(
# """
# SELECT
# em.conversation_id,
# em.message_id,
# em.created_at,
# m.role,
# m.content,
# c.title,
# vec_distance_l2(e.embedding, ?) as distance
# FROM message_embeddings e
# JOIN embedding_metadata em ON e.rowid = em.rowid
# JOIN messages m ON em.message_id = m.id
# JOIN conversations c ON em.conversation_id = c.id
# WHERE e.embedding MATCH ?
# ORDER BY distance
# LIMIT ?
# """,
# [query_embedding.tolist(), query_embedding.tolist(), limit],
# )
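# NOTE: the result handling below is unreachable until the vector query
# above is restored; the early text-search fallback return short-circuits it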
results = []
seen_conversations = set()
for row in cursor.fetchall():
conv_id = row["conversation_id"]
if conv_id not in seen_conversations:
conversation = {
"conversation_id": conv_id,
"title": row["title"],
"similarity_score": 1.0 - row["distance"], # Convert distance to similarity
"matched_message": {
"role": row["role"],
"content": row["content"]
if include_content
else row["content"][:200] + "..."
if len(row["content"]) > 200
else row["content"],
"timestamp": row["created_at"],
},
}
results.append(conversation)
seen_conversations.add(conv_id)
logger.debug(f"Vector search found {len(results)} conversations for query: '{query}'")
return results
except Exception as e:
raise VectorSearchError(query=query, error_details=str(e))
finally:
cursor.close()
def _text_search_fallback(
self, query: str, limit: int, include_content: bool = False
) -> List[Dict[str, Any]]:
"""
Fallback text search when vector search is unavailable.
Args:
query: Search query text
limit: Maximum number of results
include_content: Whether to include full message content
Returns:
List of matching conversations
"""
cursor = self._db.cursor()
try:
# Simple text search in message content
cursor.execute(
"""
SELECT DISTINCT
c.id as conversation_id,
c.title,
m.role,
m.content,
m.timestamp
FROM conversations c
JOIN messages m ON c.id = m.conversation_id
WHERE m.content LIKE ?
ORDER BY m.timestamp DESC
LIMIT ?
""",
[f"%{query}%", limit],
)
results = []
seen_conversations = set()
for row in cursor.fetchall():
conv_id = row["conversation_id"]
if conv_id not in seen_conversations:
conversation = {
"conversation_id": conv_id,
"title": row["title"],
"similarity_score": 0.5, # Default score for text search
"matched_message": {
"role": row["role"],
"content": row["content"]
if include_content
else row["content"][:200] + "..."
if len(row["content"]) > 200
else row["content"],
"timestamp": row["timestamp"],
},
}
results.append(conversation)
seen_conversations.add(conv_id)
logger.debug(
f"Text search fallback found {len(results)} conversations for query: '{query}'"
)
return results
except Exception as e:
logger.error(f"Text search fallback failed: {e}")
return []
finally:
cursor.close()
def get_conversation_list(self, limit: int = 50, offset: int = 0) -> List[Dict[str, Any]]:
"""
Get a list of all conversations with basic info.
Args:
limit: Maximum number of conversations to return
offset: Number of conversations to skip
Returns:
List of conversation summaries
Raises:
MemoryStorageError: If operation fails
"""
if self._db is None:
raise DatabaseConnectionError(db_path=str(self.db_path))
cursor = self._db.cursor()
try:
cursor.execute(
"""
SELECT
c.id,
c.title,
c.created_at,
c.updated_at,
c.metadata,
COUNT(m.id) as message_count
FROM conversations c
LEFT JOIN messages m ON c.id = m.conversation_id
GROUP BY c.id
ORDER BY c.updated_at DESC
LIMIT ? OFFSET ?
""",
[limit, offset],
)
conversations = []
for row in cursor.fetchall():
conversation = {
"id": row["id"],
"title": row["title"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"metadata": json.loads(row["metadata"]),
"message_count": row["message_count"],
}
conversations.append(conversation)
return conversations
except Exception as e:
raise MemoryStorageError(
message=f"Failed to get conversation list: {e}", operation="get_conversation_list"
)
finally:
cursor.close()
def delete_conversation(self, conversation_id: str) -> bool:
"""
Delete a conversation and all its messages.
Args:
conversation_id: ID of conversation to delete
Returns:
True if deleted successfully
Raises:
MemoryStorageError: If deletion fails
"""
if self._db is None:
raise DatabaseConnectionError(db_path=str(self.db_path))
cursor = self._db.cursor()
try:
# Delete conversation (cascade will delete messages and embeddings)
cursor.execute(
"""
DELETE FROM conversations WHERE id = ?
""",
[conversation_id],
)
self._db.commit()
deleted_count = cursor.rowcount
if deleted_count > 0:
logger.info(f"Deleted conversation '{conversation_id}'")
return True
else:
logger.warning(f"Conversation '{conversation_id}' not found for deletion")
return False
except Exception as e:
self._db.rollback()
raise MemoryStorageError(
message=f"Failed to delete conversation: {e}",
operation="delete_conversation",
conversation_id=conversation_id,
)
finally:
cursor.close()
def get_storage_stats(self) -> Dict[str, Any]:
"""
Get storage statistics and health information.
Returns:
Dictionary with storage statistics
Raises:
MemoryStorageError: If operation fails
"""
if self._db is None:
raise DatabaseConnectionError(db_path=str(self.db_path))
cursor = self._db.cursor()
try:
stats = {}
# Count conversations
cursor.execute("SELECT COUNT(*) as count FROM conversations")
stats["conversation_count"] = cursor.fetchone()["count"]
# Count messages
cursor.execute("SELECT COUNT(*) as count FROM messages")
stats["message_count"] = cursor.fetchone()["count"]
# Database file size
if self.db_path.exists():
stats["database_size_bytes"] = self.db_path.stat().st_size
stats["database_size_mb"] = stats["database_size_bytes"] / (1024 * 1024)
else:
stats["database_size_bytes"] = 0
stats["database_size_mb"] = 0
# Vector search capability
stats["vector_search_enabled"] = self._vector_enabled
stats["embedding_model"] = self.embedding_model_name
stats["embedding_dim"] = self._embedding_dim
# Database path
stats["database_path"] = str(self.db_path)
return stats
except Exception as e:
raise MemoryStorageError(
message=f"Failed to get storage stats: {e}", operation="get_storage_stats"
)
finally:
cursor.close()
def close(self) -> None:
"""Close database connection and cleanup resources."""
if self._db:
self._db.close()
self._db = None
logger.info("MemoryStorage database connection closed")
def __enter__(self):
"""Context manager entry."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.close()
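# Illustrative usage sketch (not exercised by the class above): the
# context-manager flow end to end. The constructor signature is not shown in
# this excerpt, so the db_path keyword below is an assumption; adjust to the
# real initializer.
if __name__ == "__main__":
    with MemoryStorage(db_path="mai_memory.db") as storage:
        storage.store_conversation(
            conversation_id="demo-001",
            title="Editor preferences",
            messages=[
                {"role": "user", "content": "I prefer 4-space indentation for Python."},
                {"role": "assistant", "content": "Noted - I'll use 4 spaces in examples."},
            ],
            metadata={"topic": "preferences"},
        )
        for hit in storage.search_conversations("indentation", limit=3):
            print(hit["conversation_id"], round(hit["similarity_score"], 2))
        print(storage.get_storage_stats()["conversation_count"], "conversations stored")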

View File

@@ -1,14 +0,0 @@
"""
Mai Model Interface Module
This module provides the core interface for interacting with various AI models,
with a focus on local Ollama models. It handles model discovery, capability
detection, and provides a unified interface for model switching and inference.
The model interface is designed to be extensible, allowing future support
for additional model providers while maintaining a consistent API.
"""
from .ollama_client import OllamaClient
__all__ = ["OllamaClient"]

View File

@@ -1,522 +0,0 @@
"""
Context compression and token management for Mai.
Handles conversation context within model token limits while preserving
important information and conversation quality.
"""
import re
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass
from collections import deque
import hashlib
import json
import time
@dataclass
class TokenInfo:
"""Token counting information."""
count: int
model_name: str
accuracy: float = 0.95 # Confidence in token count accuracy
@dataclass
class CompressionResult:
"""Result of context compression."""
compressed_conversation: List[Dict[str, Any]]
original_tokens: int
compressed_tokens: int
compression_ratio: float
quality_score: float
preserved_elements: List[str]
@dataclass
class BudgetEnforcement:
"""Token budget enforcement result."""
action: str # 'proceed', 'compress', 'reject'
token_count: int
budget_limit: int
urgency: float # 0.0 to 1.0
message: str
class ContextCompressor:
"""
Handles context compression and token management for conversations.
Features:
- Token counting with model-specific accuracy
- Intelligent compression preserving key information
- Budget enforcement to prevent exceeding context windows
- Quality metrics and validation
"""
def __init__(self):
"""Initialize the context compressor."""
self.tiktoken_available = self._check_tiktoken()
if self.tiktoken_available:
import tiktoken
self.encoders = {
"gpt-3.5-turbo": tiktoken.encoding_for_model("gpt-3.5-turbo"),
"gpt-4": tiktoken.encoding_for_model("gpt-4"),
"gpt-4-turbo": tiktoken.encoding_for_model("gpt-4-turbo"),
"text-davinci-003": tiktoken.encoding_for_model("text-davinci-003"),
}
else:
self.encoders = {}
print("Warning: tiktoken not available, using approximate token counting")
# Compression thresholds
self.warning_threshold = 0.75 # Warn at 75% of context window
self.critical_threshold = 0.90 # Critical at 90% of context window
self.budget_ratio = 0.9 # Budget at 90% of context window
# Compression cache
self.compression_cache = {}
self.cache_ttl = 3600 # 1 hour
self.performance_cache = deque(maxlen=100)
# Quality metrics
self.min_quality_score = 0.7
self.preservation_patterns = [
r"\b(install|configure|set up|create|build|implement)\b",
r"\b(error|bug|issue|problem|fix)\b",
r"\b(decision|choice|prefer|selected)\b",
r"\b(important|critical|essential|must)\b",
r"\b(key|main|primary|core)\b",
]
def _check_tiktoken(self) -> bool:
"""Check if tiktoken is available."""
try:
import tiktoken
return True
except ImportError:
return False
def count_tokens(self, text: str, model_name: str = "gpt-3.5-turbo") -> TokenInfo:
"""
Count tokens in text with model-specific accuracy.
Args:
text: Text to count tokens for
model_name: Model name for tokenization
Returns:
TokenInfo with count and accuracy
"""
if not text:
return TokenInfo(0, model_name, 1.0)
if self.tiktoken_available and model_name in self.encoders:
encoder = self.encoders[model_name]
try:
tokens = encoder.encode(text)
return TokenInfo(len(tokens), model_name, 0.99)
except Exception as e:
print(f"Tiktoken error: {e}, falling back to approximation")
# Fallback: approximate token counting
# Rough approximation: ~4 characters per token for English
# Slightly better approach using word and punctuation patterns
words = re.findall(r"\w+|[^\w\s]", text)
# Adjust for model families
model_multipliers = {
"gpt-3.5": 1.0,
"gpt-4": 0.9, # More efficient tokenization
"claude": 1.1, # Less efficient
"llama": 1.2, # Even less efficient
}
# Determine model family
model_family = "gpt-3.5"
for family in model_multipliers:
if family in model_name.lower():
model_family = family
break
multiplier = model_multipliers.get(model_family, 1.0)
token_count = int(len(words) * 1.3 * multiplier) # 1.3 is base conversion
return TokenInfo(token_count, model_name, 0.85) # Lower accuracy for approximation
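# Worked example of the approximate path (tiktoken unavailable): the text
# "Install the package and configure it." splits into 7 units (6 words plus
# the period), so the estimate is int(7 * 1.3 * 1.0) = 9 tokens under the
# default gpt-3.5 multiplier, reported at 0.85 confidence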
def should_compress(
self, conversation: List[Dict[str, Any]], model_context_window: int
) -> Tuple[bool, float, str]:
"""
Determine if conversation should be compressed.
Args:
conversation: List of message dictionaries
model_context_window: Model's context window size
Returns:
Tuple of (should_compress, urgency, message)
"""
total_tokens = sum(self.count_tokens(msg.get("content", "")).count for msg in conversation)
usage_ratio = total_tokens / model_context_window
if usage_ratio >= self.critical_threshold:
return True, 1.0, f"Critical: {usage_ratio:.1%} of context window used"
elif usage_ratio >= self.warning_threshold:
return True, 0.7, f"Warning: {usage_ratio:.1%} of context window used"
elif len(conversation) > 50: # Conversation length consideration
return True, 0.5, "Long conversation: consider compression for performance"
else:
return False, 0.0, "Context within acceptable limits"
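# Worked example: a conversation totalling 6,500 tokens against an 8,192-token
# window has a usage ratio of about 0.79, which crosses the 0.75 warning
# threshold and yields (True, 0.7, "Warning: 79.3% of context window used")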
def preserve_key_elements(self, conversation: List[Dict[str, Any]]) -> List[str]:
"""
Extract and preserve critical information from conversation.
Args:
conversation: List of message dictionaries
Returns:
List of critical elements to preserve
"""
key_elements = []
for msg in conversation:
content = msg.get("content", "")
role = msg.get("role", "")
# Look for important patterns
for pattern in self.preservation_patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
if matches:
# Extract surrounding context
for match in matches:
# Find the sentence containing the match
sentences = re.split(r"[.!?]+", content)
for sentence in sentences:
if match.lower() in sentence.lower():
key_elements.append(f"{role}: {sentence.strip()}")
break
# Also preserve system messages and instructions
for msg in conversation:
if msg.get("role") in ["system", "instruction"]:
key_elements.append(f"system: {msg.get('content', '')}")
return key_elements
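# Worked example: for a user message "Please install Docker and configure the
# firewall.", the first preservation pattern matches both "install" and
# "configure", so the containing sentence is appended once per matched keyword
# as "user: Please install Docker and configure the firewall"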
def compress_conversation(
self, conversation: List[Dict[str, Any]], target_token_ratio: float = 0.5
) -> CompressionResult:
"""
Compress conversation while preserving key information.
Args:
conversation: List of message dictionaries
target_token_ratio: Target ratio of original tokens to keep
Returns:
CompressionResult with compressed conversation and metrics
"""
if not conversation:
return CompressionResult([], 0, 0, 1.0, 1.0, [])
# Calculate current token usage
original_tokens = sum(
self.count_tokens(msg.get("content", "")).count for msg in conversation
)
target_tokens = int(original_tokens * target_token_ratio)
# Check cache
cache_key = self._get_cache_key(conversation, target_token_ratio)
if cache_key in self.compression_cache:
cached_result = self.compression_cache[cache_key]
if time.time() - cached_result["timestamp"] < self.cache_ttl:
return CompressionResult(**cached_result["result"])
# Preserve key elements
key_elements = self.preserve_key_elements(conversation)
# Split conversation: keep recent messages, compress older ones
split_point = max(0, len(conversation) // 2) # Keep second half
recent_messages = conversation[split_point:]
older_messages = conversation[:split_point]
compressed_messages = []
# Summarize older messages
if older_messages:
summary = self._create_summary(older_messages, target_tokens // 2)
compressed_messages.append(
{
"role": "system",
"content": f"[Compressed context: {summary}]",
"metadata": {
"compressed": True,
"original_count": len(older_messages),
"summary_token_count": self.count_tokens(summary).count,
},
}
)
# Add recent messages
compressed_messages.extend(recent_messages)
# Add key elements if they might be lost
if key_elements:
key_content = "\n\nKey information to remember:\n" + "\n".join(key_elements[:5])
compressed_messages.append(
{
"role": "system",
"content": key_content,
"metadata": {"type": "key_elements", "preserved_count": len(key_elements)},
}
)
# Calculate metrics
compressed_tokens = sum(
self.count_tokens(msg.get("content", "")).count for msg in compressed_messages
)
compression_ratio = compressed_tokens / original_tokens if original_tokens > 0 else 1.0
quality_score = self._calculate_quality_score(
conversation, compressed_messages, key_elements
)
result = CompressionResult(
compressed_conversation=compressed_messages,
original_tokens=original_tokens,
compressed_tokens=compressed_tokens,
compression_ratio=compression_ratio,
quality_score=quality_score,
preserved_elements=key_elements,
)
# Cache result
self.compression_cache[cache_key] = {"result": result.__dict__, "timestamp": time.time()}
return result
def _create_summary(self, messages: List[Dict[str, Any]], target_tokens: int) -> str:
"""
Create a summary of older messages.
Args:
messages: List of message dictionaries
target_tokens: Target token count for summary
Returns:
Summary string
"""
# Extract key points from messages
key_points = []
for msg in messages:
content = msg.get("content", "")
role = msg.get("role", "")
# Extract first sentence or important parts
sentences = re.split(r"[.!?]+", content)
if sentences:
first_sentence = sentences[0].strip()
if len(first_sentence) > 10: # Skip very short fragments
key_points.append(f"{role}: {first_sentence}")
# Join and truncate to target length
summary = " | ".join(key_points)
# Truncate if too long
while len(summary) > target_tokens * 4 and key_points: # Rough character estimate
key_points.pop()
summary = " | ".join(key_points)
return summary if summary else "Previous conversation context"
def _calculate_quality_score(
self,
original: List[Dict[str, Any]],
compressed: List[Dict[str, Any]],
preserved_elements: List[str],
) -> float:
"""
Calculate quality score for compression.
Args:
original: Original conversation
compressed: Compressed conversation
preserved_elements: Elements preserved
Returns:
Quality score between 0.0 and 1.0
"""
# Base score from token preservation
original_tokens = sum(self.count_tokens(msg.get("content", "")).count for msg in original)
compressed_tokens = sum(
self.count_tokens(msg.get("content", "")).count for msg in compressed
)
preservation_score = min(1.0, compressed_tokens / max(1, original_tokens))
# Bonus for preserved elements
element_bonus = min(0.2, len(preserved_elements) * 0.02)
# Penalty for too aggressive compression
if compressed_tokens < original_tokens * 0.3:
preservation_score *= 0.8
quality_score = min(1.0, preservation_score + element_bonus)
return quality_score
def enforce_token_budget(
self,
conversation: List[Dict[str, Any]],
model_context_window: int,
budget_ratio: Optional[float] = None,
) -> BudgetEnforcement:
"""
Enforce token budget before model call.
Args:
conversation: List of message dictionaries
model_context_window: Model's context window size
budget_ratio: Budget ratio (default from config)
Returns:
BudgetEnforcement with action and details
"""
if budget_ratio is None:
budget_ratio = self.budget_ratio
budget_limit = int(model_context_window * budget_ratio)
current_tokens = sum(
self.count_tokens(msg.get("content", "")).count for msg in conversation
)
usage_ratio = current_tokens / model_context_window
if current_tokens > budget_limit:
if usage_ratio >= 0.95:
return BudgetEnforcement(
action="reject",
token_count=current_tokens,
budget_limit=budget_limit,
urgency=1.0,
message=f"Conversation too long: {current_tokens} tokens exceeds budget of {budget_limit}",
)
else:
return BudgetEnforcement(
action="compress",
token_count=current_tokens,
budget_limit=budget_limit,
urgency=min(1.0, usage_ratio),
message=f"Compression needed: {current_tokens} tokens exceeds budget of {budget_limit}",
)
else:
urgency = max(0.0, usage_ratio - 0.7) / 0.2 # Normalize between 0.7-0.9
return BudgetEnforcement(
action="proceed",
token_count=current_tokens,
budget_limit=budget_limit,
urgency=urgency,
message=f"Within budget: {current_tokens} tokens of {budget_limit}",
)
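# Worked example: with an 8,192-token window and the default 0.9 budget ratio
# the budget is int(8192 * 0.9) = 7,372 tokens. A 7,500-token conversation
# exceeds the budget but stays below 95% usage (7500 / 8192 ~ 0.92), so the
# action is "compress" with urgency ~0.92; at 6,000 tokens the action is
# "proceed" with urgency (6000 / 8192 - 0.7) / 0.2 ~ 0.16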
def validate_compression(
self, original: List[Dict[str, Any]], compressed: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Validate compression quality and information preservation.
Args:
original: Original conversation
compressed: Compressed conversation
Returns:
Dictionary with validation metrics
"""
# Token-based metrics
original_tokens = sum(self.count_tokens(msg.get("content", "")).count for msg in original)
compressed_tokens = sum(
self.count_tokens(msg.get("content", "")).count for msg in compressed
)
# Semantic similarity (simplified)
original_text = " ".join(msg.get("content", "") for msg in original).lower()
compressed_text = " ".join(msg.get("content", "") for msg in compressed).lower()
# Word overlap as simple similarity metric
original_words = set(re.findall(r"\w+", original_text))
compressed_words = set(re.findall(r"\w+", compressed_text))
if original_words:
similarity = len(original_words & compressed_words) / len(original_words)
else:
similarity = 1.0
# Key information preservation
original_key = self.preserve_key_elements(original)
compressed_key = self.preserve_key_elements(compressed)
key_preservation = len(compressed_key) / max(1, len(original_key))
return {
"token_preservation": compressed_tokens / max(1, original_tokens),
"semantic_similarity": similarity,
"key_information_preservation": key_preservation,
"overall_quality": (similarity + key_preservation) / 2,
"recommendations": self._get_validation_recommendations(
similarity, key_preservation, compressed_tokens / max(1, original_tokens)
),
}
def _get_validation_recommendations(
self, similarity: float, key_preservation: float, token_ratio: float
) -> List[str]:
"""Get recommendations based on validation metrics."""
recommendations = []
if similarity < 0.7:
recommendations.append("Low semantic similarity - consider preserving more context")
if key_preservation < 0.8:
recommendations.append(
"Key information not well preserved - adjust preservation patterns"
)
if token_ratio > 0.8:
recommendations.append("Compression too conservative - can reduce more")
elif token_ratio < 0.3:
recommendations.append("Compression too aggressive - losing too much content")
if not recommendations:
recommendations.append("Compression quality is acceptable")
return recommendations
def _get_cache_key(self, conversation: List[Dict[str, Any]], target_ratio: float) -> str:
"""Generate cache key for compression result."""
# Create hash of conversation and target ratio
content = json.dumps([msg.get("content", "") for msg in conversation], sort_keys=True)
content_hash = hashlib.md5(content.encode()).hexdigest()
return f"{content_hash}_{target_ratio}"
def get_performance_stats(self) -> Dict[str, Any]:
"""Get performance statistics for the compressor."""
return {
"cache_size": len(self.compression_cache),
"cache_hit_ratio": len(self.performance_cache) / max(1, len(self.compression_cache)),
"tiktoken_available": self.tiktoken_available,
"supported_models": list(self.encoders.keys()) if self.tiktoken_available else [],
"compression_thresholds": {
"warning": self.warning_threshold,
"critical": self.critical_threshold,
"budget": self.budget_ratio,
},
}
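# Illustrative usage sketch (not exercised by the class above): the typical
# check-then-compress flow. Message contents and the 8,192-token window are
# made-up values for demonstration.
if __name__ == "__main__":
    compressor = ContextCompressor()
    conversation = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "We decided to install PostgreSQL and configure nightly backups."},
        {"role": "assistant", "content": "Understood - I'll remember the backup decision."},
    ] * 30  # simulate a long conversation
    needed, urgency, note = compressor.should_compress(conversation, model_context_window=8192)
    print(note)
    if needed:
        result = compressor.compress_conversation(conversation, target_token_ratio=0.5)
        print(f"{result.original_tokens} -> {result.compressed_tokens} tokens "
              f"(quality {result.quality_score:.2f})")
        report = compressor.validate_compression(conversation, result.compressed_conversation)
        print(report["recommendations"])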

View File

@@ -1,316 +0,0 @@
"""
Ollama Client Wrapper
Provides a robust wrapper around the Ollama Python client with model discovery,
capability detection, caching, and error handling.
"""
import logging
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import ollama
from src.mai.core import ModelError, ConfigurationError
logger = logging.getLogger(__name__)
class OllamaClient:
"""
Robust wrapper for Ollama API with model discovery and caching.
This client handles connection management, model discovery, capability
detection, and graceful error handling for Ollama operations.
"""
def __init__(self, host: str = "http://localhost:11434", timeout: int = 30):
"""
Initialize Ollama client with connection settings.
Args:
host: Ollama server URL
timeout: Connection timeout in seconds
"""
self.host = host
self.timeout = timeout
self._client = None
self._model_cache: Dict[str, Dict[str, Any]] = {}
self._cache_timestamp: Optional[datetime] = None
self._cache_duration = timedelta(minutes=30)
# Initialize client (may fail if Ollama not running)
self._initialize_client()
def _initialize_client(self) -> None:
"""Initialize Ollama client with error handling."""
try:
self._client = ollama.Client(host=self.host, timeout=self.timeout)
logger.info(f"Ollama client initialized for {self.host}")
except Exception as e:
logger.warning(f"Failed to initialize Ollama client: {e}")
self._client = None
def _check_client(self) -> None:
"""Check if client is initialized, attempt reconnection if needed."""
if self._client is None:
logger.info("Attempting to reconnect to Ollama...")
self._initialize_client()
if self._client is None:
raise ModelError("Cannot connect to Ollama. Is it running?")
def list_models(self) -> List[Dict[str, Any]]:
"""
List all available models with basic metadata.
Returns:
List of models with name and basic info
"""
try:
self._check_client()
if self._client is None:
logger.warning("Ollama client not available")
return []
# Get raw model list from Ollama
response = self._client.list()
models = response.get("models", [])
# Extract relevant information
model_list = []
for model in models:
# Handle both dict and object responses from ollama
if isinstance(model, dict):
model_name = model.get("name", "")
model_size = model.get("size", 0)
model_digest = model.get("digest", "")
model_modified = model.get("modified_at", "")
else:
# Ollama returns model objects with 'model' attribute
model_name = getattr(model, "model", "")
model_size = getattr(model, "size", 0)
model_digest = getattr(model, "digest", "")
model_modified = getattr(model, "modified_at", "")
model_info = {
"name": model_name,
"size": model_size,
"digest": model_digest,
"modified_at": model_modified,
}
model_list.append(model_info)
logger.info(f"Found {len(model_list)} models")
return model_list
except ConnectionError as e:
logger.error(f"Connection error listing models: {e}")
return []
except Exception as e:
logger.error(f"Error listing models: {e}")
return []
def get_model_info(self, model_name: str) -> Dict[str, Any]:
"""
Get detailed information about a specific model.
Args:
model_name: Name of the model
Returns:
Dictionary with model details
"""
# Check cache first
if model_name in self._model_cache:
cache_entry = self._model_cache[model_name]
if (
self._cache_timestamp
and datetime.now() - self._cache_timestamp < self._cache_duration
):
logger.debug(f"Returning cached info for {model_name}")
return cache_entry
try:
self._check_client()
if self._client is None:
raise ModelError("Cannot connect to Ollama")
# Get model details from Ollama
response = self._client.show(model_name)
# Extract key information
model_info = {
"name": model_name,
"parameter_size": response.get("details", {}).get("parameter_size", ""),
"context_window": response.get("details", {}).get("context_length", 0),
"model_family": response.get("details", {}).get("families", []),
"model_format": response.get("details", {}).get("format", ""),
"quantization": response.get("details", {}).get("quantization_level", ""),
"size": response.get("details", {}).get("size", 0),
"modelfile": response.get("modelfile", ""),
"template": response.get("template", ""),
"parameters": response.get("parameters", {}),
}
# Cache the result
self._model_cache[model_name] = model_info
self._cache_timestamp = datetime.now()
logger.debug(f"Retrieved info for {model_name}: {model_info['parameter_size']} params")
return model_info
except Exception as e:
error_msg = f"Error getting model info for {model_name}: {e}"
logger.error(error_msg)
raise ModelError(error_msg)
def is_model_available(self, model_name: str) -> bool:
"""
Check if a model is available and can be queried.
Args:
model_name: Name of the model to check
Returns:
True if model exists and is accessible
"""
try:
# First check if model exists in list
models = self.list_models()
model_names = [m["name"] for m in models]
if model_name not in model_names:
logger.debug(f"Model {model_name} not found in available models")
return False
# Try to get model info to verify accessibility
self.get_model_info(model_name)
return True
except (ModelError, Exception) as e:
logger.debug(f"Model {model_name} not accessible: {e}")
return False
def refresh_models(self) -> None:
"""
Force refresh of model list and clear cache.
This method clears all cached information and forces a fresh
query to Ollama for all operations.
"""
logger.info("Refreshing model information...")
# Clear cache
self._model_cache.clear()
self._cache_timestamp = None
# Reinitialize client if needed
if self._client is None:
self._initialize_client()
logger.info("Model cache cleared")
def get_connection_status(self) -> Dict[str, Any]:
"""
Get current connection status and diagnostics.
Returns:
Dictionary with connection status information
"""
status = {
"connected": False,
"host": self.host,
"timeout": self.timeout,
"models_count": 0,
"cache_size": len(self._model_cache),
"cache_valid": False,
"error": None,
}
try:
if self._client is None:
status["error"] = "Client not initialized"
return status
# Try to list models to verify connection
models = self.list_models()
status["connected"] = True
status["models_count"] = len(models)
# Check cache validity
if self._cache_timestamp:
age = datetime.now() - self._cache_timestamp
status["cache_valid"] = age < self._cache_duration
status["cache_age_minutes"] = age.total_seconds() / 60
except Exception as e:
status["error"] = str(e)
logger.debug(f"Connection status check failed: {e}")
return status
def generate_response(
self, prompt: str, model: str, context: Optional[List[Dict[str, Any]]] = None
) -> str:
"""
Generate a response from the specified model.
Args:
prompt: User prompt/message
model: Model name to use
context: Optional conversation context
Returns:
Generated response text
Raises:
ModelError: If generation fails
"""
try:
self._check_client()
if self._client is None:
raise ModelError("Cannot connect to Ollama")
if not model:
raise ModelError("No model specified")
# Build the full prompt with context if provided
if context:
messages = context + [{"role": "user", "content": prompt}]
else:
messages = [{"role": "user", "content": prompt}]
# Generate response using Ollama
response = self._client.chat(model=model, messages=messages, stream=False)
# Extract the response text
result = response.get("message", {}).get("content", "")
if not result:
logger.warning(f"Empty response from {model}")
return "I apologize, but I couldn't generate a response."
logger.debug(f"Generated response from {model}")
return result
except ModelError:
raise
except Exception as e:
error_msg = f"Error generating response from {model}: {e}"
logger.error(error_msg)
raise ModelError(error_msg)
# Convenience function for creating a client
def create_client(host: Optional[str] = None, timeout: int = 30) -> OllamaClient:
"""
Create an OllamaClient with optional configuration.
Args:
host: Optional Ollama server URL
timeout: Connection timeout in seconds
Returns:
Configured OllamaClient instance
"""
return OllamaClient(host=host or "http://localhost:11434", timeout=timeout)
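# Illustrative usage sketch (not exercised by the module above): connect,
# check status, and generate a reply. The model name "llama3.2:3b" is a
# placeholder - use whatever list_models() reports on your machine.
if __name__ == "__main__":
    client = create_client()
    status = client.get_connection_status()
    if not status["connected"]:
        print(f"Ollama unreachable: {status['error']}")
    else:
        for model in client.list_models():
            print(model["name"])
        reply = client.generate_response(
            prompt="Summarize what a context window is in one sentence.",
            model="llama3.2:3b",
        )
        print(reply)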

View File

@@ -1,497 +0,0 @@
"""
Resource monitoring for Mai.
Monitors system resources (CPU, RAM, GPU) and provides
resource-aware model selection capabilities.
"""
import time
import platform
import subprocess
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple, Any
from collections import deque
@dataclass
class ResourceInfo:
"""Current system resource state"""
cpu_percent: float
memory_total_gb: float
memory_available_gb: float
memory_percent: float
gpu_available: bool
gpu_memory_gb: Optional[float] = None
gpu_usage_percent: Optional[float] = None
timestamp: float = 0.0
@dataclass
class MemoryTrend:
"""Memory usage trend analysis"""
current: float
trend: str # 'stable', 'increasing', 'decreasing'
rate: float # GB per minute
confidence: float # 0.0 to 1.0
class ResourceDetector:
"""System resource monitoring with trend analysis"""
def __init__(self):
"""Initialize resource monitoring"""
self.memory_threshold_warning = 80.0 # 80% for warning
self.memory_threshold_critical = 90.0 # 90% for critical
self.history_window = 60 # seconds
self.history_size = 60 # data points
# Resource history tracking
self.memory_history: deque = deque(maxlen=self.history_size)
self.cpu_history: deque = deque(maxlen=self.history_size)
self.timestamps: deque = deque(maxlen=self.history_size)
# GPU detection
self.gpu_available = self._detect_gpu()
self.gpu_info = self._get_gpu_info()
# Initialize psutil if available
self._init_psutil()
def _init_psutil(self):
"""Initialize psutil with fallback"""
try:
import psutil
self.psutil = psutil
self.has_psutil = True
except ImportError:
print("Warning: psutil not available. Resource monitoring will be limited.")
self.psutil = None
self.has_psutil = False
def _detect_gpu(self) -> bool:
"""Detect GPU availability"""
try:
# Try NVIDIA GPU detection
result = subprocess.run(
["nvidia-smi", "--query-gpu=name", "--format=csv,noheader,nounits"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0 and result.stdout.strip():
return True
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
try:
# Try AMD GPU detection
result = subprocess.run(
["rocm-smi", "--showproductname"], capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
return True
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
# Apple Silicon detection
if platform.system() == "Darwin" and platform.machine() in ["arm64", "arm"]:
return True
return False
def _get_gpu_info(self) -> Dict[str, Any]:
"""Get GPU information"""
info: Dict[str, Any] = {"type": None, "memory_gb": None, "name": None}
try:
# NVIDIA GPU
result = subprocess.run(
["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
lines = result.stdout.strip().split("\n")
if lines and lines[0]:
parts = lines[0].split(", ")
if len(parts) >= 2:
info["type"] = "nvidia"
info["name"] = parts[0].strip()
info["memory_gb"] = float(parts[1].strip()) / 1024 # Convert MB to GB
except (subprocess.TimeoutExpired, FileNotFoundError, ValueError):
pass
# Apple Silicon
if (
not info["type"]
and platform.system() == "Darwin"
and platform.machine() in ["arm64", "arm"]
):
info["type"] = "apple_silicon"
# Unified memory, estimate based on system memory
if self.has_psutil and self.psutil is not None:
memory = self.psutil.virtual_memory()
info["memory_gb"] = memory.total / (1024**3)
return info
def detect_resources(self) -> ResourceInfo:
"""Get current system resource state (alias for get_current_resources)"""
return self.get_current_resources()
def get_current_resources(self) -> ResourceInfo:
"""Get current system resource state"""
if not self.has_psutil:
# Fallback to basic monitoring
return self._get_fallback_resources()
# CPU usage
cpu_percent = self.psutil.cpu_percent(interval=1) if self.psutil else 0.0
# Memory information
if self.psutil:
memory = self.psutil.virtual_memory()
memory_total_gb = memory.total / (1024**3)
memory_available_gb = memory.available / (1024**3)
memory_percent = memory.percent
else:
# Use fallback values
memory_total_gb = 8.0 # Default assumption
memory_available_gb = 4.0
memory_percent = 50.0
# GPU information
gpu_usage_percent = None
gpu_memory_gb = None
if self.gpu_info["type"] == "nvidia":
try:
result = subprocess.run(
[
"nvidia-smi",
"--query-gpu=utilization.gpu,memory.used",
"--format=csv,noheader,nounits",
],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
lines = result.stdout.strip().split("\n")
if lines and lines[0]:
parts = lines[0].split(", ")
if len(parts) >= 2:
gpu_usage_percent = float(parts[0].strip())
gpu_memory_gb = float(parts[1].strip()) / 1024
except (subprocess.TimeoutExpired, ValueError):
pass
current_time = time.time()
resource_info = ResourceInfo(
cpu_percent=cpu_percent,
memory_total_gb=memory_total_gb,
memory_available_gb=memory_available_gb,
memory_percent=memory_percent,
gpu_available=self.gpu_available,
gpu_memory_gb=gpu_memory_gb,
gpu_usage_percent=gpu_usage_percent,
timestamp=current_time,
)
# Update history
self._update_history(resource_info)
return resource_info
def _get_fallback_resources(self) -> ResourceInfo:
"""Fallback resource detection without psutil"""
# Basic resource detection using /proc filesystem on Linux
cpu_percent = 0.0
memory_total_gb = 0.0
memory_available_gb = 0.0
memory_percent = 0.0
try:
# Read memory info from /proc/meminfo
with open("/proc/meminfo", "r") as f:
meminfo = {}
for line in f:
if ":" in line:
key, value = line.split(":", 1)
meminfo[key.strip()] = int(value.split()[0])
if "MemTotal" in meminfo:
memory_total_gb = meminfo["MemTotal"] / (1024**2)
if "MemAvailable" in meminfo:
memory_available_gb = meminfo["MemAvailable"] / (1024**2)
if memory_total_gb > 0:
memory_percent = (
(memory_total_gb - memory_available_gb) / memory_total_gb
) * 100
except (IOError, KeyError, ValueError):
pass
current_time = time.time()
return ResourceInfo(
cpu_percent=cpu_percent,
memory_total_gb=memory_total_gb,
memory_available_gb=memory_available_gb,
memory_percent=memory_percent,
gpu_available=self.gpu_available,
gpu_memory_gb=self.gpu_info.get("memory_gb"),
gpu_usage_percent=None,
timestamp=current_time,
)
def _update_history(self, resource_info: ResourceInfo):
"""Update resource history for trend analysis"""
current_time = time.time()
self.memory_history.append(resource_info.memory_percent)
self.cpu_history.append(resource_info.cpu_percent)
self.timestamps.append(current_time)
def is_memory_constrained(self) -> Tuple[bool, str]:
"""Check if memory is constrained"""
if not self.memory_history:
resources = self.get_current_resources()
current_memory = resources.memory_percent
else:
current_memory = self.memory_history[-1]
# Check current memory usage
if current_memory >= self.memory_threshold_critical:
return True, "critical"
elif current_memory >= self.memory_threshold_warning:
return True, "warning"
# Check trend
trend = self.get_memory_trend()
if trend.trend == "increasing" and trend.rate > 5.0: # 5GB/min increase
return True, "trend_warning"
return False, "normal"
def get_memory_trend(self) -> MemoryTrend:
"""Analyze memory usage trend over last minute"""
if len(self.memory_history) < 10:
return MemoryTrend(
current=self.memory_history[-1] if self.memory_history else 0.0,
trend="stable",
rate=0.0,
confidence=0.0,
)
# Get recent data points (last 10 measurements)
recent_memory = list(self.memory_history)[-10:]
recent_times = list(self.timestamps)[-10:]
# Calculate trend
if len(recent_memory) >= 2 and len(recent_times) >= 2:
time_span = recent_times[-1] - recent_times[0]
memory_change = recent_memory[-1] - recent_memory[0]
# Convert to GB per minute if we have memory info
rate = 0.0
if self.has_psutil and time_span > 0 and self.psutil is not None:
# Use psutil to get total memory for conversion
total_memory = self.psutil.virtual_memory().total / (1024**3)
rate = (memory_change / 100.0) * total_memory * (60.0 / time_span)
# Determine trend
if abs(memory_change) < 2.0: # Less than 2% change
trend = "stable"
elif memory_change > 0:
trend = "increasing"
else:
trend = "decreasing"
# Confidence based on data consistency
confidence = min(1.0, len(recent_memory) / 10.0)
return MemoryTrend(
current=recent_memory[-1], trend=trend, rate=rate, confidence=confidence
)
return MemoryTrend(
current=recent_memory[-1] if recent_memory else 0.0,
trend="stable",
rate=0.0,
confidence=0.0,
)
def get_performance_degradation(self) -> Dict:
"""Analyze performance degradation metrics"""
if len(self.memory_history) < 20 or len(self.cpu_history) < 20:
return {
"status": "insufficient_data",
"memory_trend": "unknown",
"cpu_trend": "unknown",
"overall": "stable",
}
# Memory trend
memory_trend = self.get_memory_trend()
# CPU trend
recent_cpu = list(self.cpu_history)[-10:]
older_cpu = list(self.cpu_history)[-20:-10]
avg_recent_cpu = sum(recent_cpu) / len(recent_cpu)
avg_older_cpu = sum(older_cpu) / len(older_cpu)
cpu_increase = avg_recent_cpu - avg_older_cpu
# Overall assessment
if memory_trend.trend == "increasing" and memory_trend.rate > 5.0:
memory_status = "worsening"
elif memory_trend.trend == "increasing":
memory_status = "concerning"
else:
memory_status = "stable"
if cpu_increase > 20:
cpu_status = "worsening"
elif cpu_increase > 10:
cpu_status = "concerning"
else:
cpu_status = "stable"
# Overall status
if memory_status == "worsening" or cpu_status == "worsening":
overall = "critical"
elif memory_status == "concerning" or cpu_status == "concerning":
overall = "degrading"
else:
overall = "stable"
return {
"status": "analyzed",
"memory_trend": memory_status,
"cpu_trend": cpu_status,
"cpu_increase": cpu_increase,
"memory_rate": memory_trend.rate,
"overall": overall,
}
def estimate_model_requirements(self, model_size: str) -> Dict:
"""Estimate memory requirements for model size"""
# Conservative estimates based on model parameter count
requirements = {
"1b": {"memory_gb": 2.0, "memory_warning_gb": 2.5, "memory_critical_gb": 3.0},
"3b": {"memory_gb": 4.0, "memory_warning_gb": 5.0, "memory_critical_gb": 6.0},
"7b": {"memory_gb": 8.0, "memory_warning_gb": 10.0, "memory_critical_gb": 12.0},
"13b": {"memory_gb": 16.0, "memory_warning_gb": 20.0, "memory_critical_gb": 24.0},
"70b": {"memory_gb": 80.0, "memory_warning_gb": 100.0, "memory_critical_gb": 120.0},
}
size_key = model_size.lower()
if size_key not in requirements:
# Default to 7B requirements for unknown models
size_key = "7b"
base_req = requirements[size_key]
# Add buffer for context and processing overhead (50%)
context_overhead = base_req["memory_gb"] * 0.5
return {
"size_category": size_key,
"base_memory_gb": base_req["memory_gb"],
"context_overhead_gb": context_overhead,
"total_required_gb": base_req["memory_gb"] + context_overhead,
"warning_threshold_gb": base_req["memory_warning_gb"],
"critical_threshold_gb": base_req["memory_critical_gb"],
}
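# Worked example: a "7b" model is budgeted at 8.0 GB base memory plus a 50%
# context/processing overhead (4.0 GB), i.e. 12.0 GB total required, with a
# 10.0 GB warning threshold and a 12.0 GB critical threshold; unknown size
# strings fall back to the 7b profile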
def can_fit_model(self, model_info: Dict) -> Dict:
"""Check if model fits in current resources"""
# Extract model size info
model_size = model_info.get("size", "7b")
if isinstance(model_size, str):
# Extract numeric size from strings like "7B", "13B", etc.
import re
match = re.search(r"(\d+\.?\d*)[Bb]", model_size)
if match:
size_num = float(match.group(1))
if size_num <= 2:
size_key = "1b"
elif size_num <= 4:
size_key = "3b"
elif size_num <= 10:
size_key = "7b"
elif size_num <= 20:
size_key = "13b"
else:
size_key = "70b"
else:
size_key = "7b"
else:
size_key = str(model_size).lower()
# Get requirements
requirements = self.estimate_model_requirements(size_key)
# Get current resources
current_resources = self.get_current_resources()
# Check memory fit
available_memory = current_resources.memory_available_gb
required_memory = requirements["total_required_gb"]
memory_fit_score = min(1.0, available_memory / required_memory)
# Check performance trends
degradation = self.get_performance_degradation()
# Adjust confidence based on trends
trend_adjustment = 1.0
if degradation["overall"] == "critical":
trend_adjustment = 0.5
elif degradation["overall"] == "degrading":
trend_adjustment = 0.8
confidence = memory_fit_score * trend_adjustment
# GPU consideration
gpu_factor = 1.0
if self.gpu_available and self.gpu_info.get("memory_gb"):
gpu_memory = self.gpu_info["memory_gb"]
if gpu_memory < required_memory:
gpu_factor = 0.5 # GPU might not have enough memory
final_confidence = confidence * gpu_factor
return {
"can_fit": final_confidence >= 0.8,
"confidence": final_confidence,
"memory_fit_score": memory_fit_score,
"trend_adjustment": trend_adjustment,
"gpu_factor": gpu_factor,
"available_memory_gb": available_memory,
"required_memory_gb": required_memory,
"memory_deficit_gb": max(0, required_memory - available_memory),
"recommendation": self._get_fitting_recommendation(final_confidence, requirements),
}
def _get_fitting_recommendation(self, confidence: float, requirements: Dict) -> str:
"""Get recommendation based on fitting assessment"""
if confidence >= 0.9:
return "Excellent fit - model should run smoothly"
elif confidence >= 0.8:
return "Good fit - model should work well"
elif confidence >= 0.6:
return "Possible fit - may experience performance issues"
elif confidence >= 0.4:
return "Tight fit - expect significant slowdowns"
else:
return f"Insufficient resources - need at least {requirements['total_required_gb']:.1f}GB available"
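# Illustrative usage sketch (not exercised by the class above): take a
# resource snapshot and ask whether a 7B-parameter model is likely to fit
# right now.
if __name__ == "__main__":
    detector = ResourceDetector()
    snapshot = detector.get_current_resources()
    print(f"CPU {snapshot.cpu_percent:.0f}%, "
          f"RAM {snapshot.memory_available_gb:.1f}/{snapshot.memory_total_gb:.1f} GB available, "
          f"GPU: {snapshot.gpu_available}")
    constrained, level = detector.is_memory_constrained()
    print(f"Memory constrained: {constrained} ({level})")
    fit = detector.can_fit_model({"size": "7B"})
    print(fit["recommendation"], f"(confidence {fit['confidence']:.2f})")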

View File

@@ -1,594 +0,0 @@
"""
Model selection and switching logic for Mai.
Intelligently selects and switches between models based on
available resources and conversation requirements.
"""
import time
import asyncio
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from enum import Enum
class ModelSize(Enum):
"""Model size categories"""
TINY = "1b"
SMALL = "3b"
MEDIUM = "7b"
LARGE = "13b"
HUGE = "70b"
class SwitchReason(Enum):
"""Reasons for model switching"""
RESOURCE_CONSTRAINT = "resource_constraint"
PERFORMANCE_DEGRADATION = "performance_degradation"
TASK_COMPLEXITY = "task_complexity"
USER_REQUEST = "user_request"
ERROR_RECOVERY = "error_recovery"
PROACTIVE_OPTIMIZATION = "proactive_optimization"
@dataclass
class ModelInfo:
"""Information about an available model"""
name: str
size: str # '7b', '13b', etc.
parameters: int # parameter count
context_window: int # context window size
quantization: str # 'q4_0', 'q8_0', etc.
modified_at: Optional[str] = None
digest: Optional[str] = None
def __post_init__(self):
"""Post-processing for model info"""
# Extract parameter count from size string
if isinstance(self.size, str):
import re
match = re.search(r"(\d+\.?\d*)", self.size.lower())
if match:
self.parameters = int(float(match.group(1)) * 1e9)
# Determine size category
if self.parameters <= 2e9:
self.size_category = ModelSize.TINY
elif self.parameters <= 4e9:
self.size_category = ModelSize.SMALL
elif self.parameters <= 10e9:
self.size_category = ModelSize.MEDIUM
elif self.parameters <= 20e9:
self.size_category = ModelSize.LARGE
else:
self.size_category = ModelSize.HUGE
@dataclass
class SwitchMetrics:
"""Metrics for model switching performance"""
switch_time: float
context_transfer_time: float
context_compression_ratio: float
success: bool
error_message: Optional[str] = None
@dataclass
class SwitchRecommendation:
"""Recommendation for model switching"""
should_switch: bool
target_model: Optional[str]
reason: SwitchReason
confidence: float
expected_benefit: str
estimated_cost: Dict[str, float]
class ModelSwitcher:
"""Intelligent model selection and switching"""
def __init__(self, ollama_client, resource_detector):
"""Initialize model switcher with dependencies"""
self.client = ollama_client
self.resource_detector = resource_detector
# Current state
self.current_model: Optional[str] = None
self.current_model_info: Optional[ModelInfo] = None
self.conversation_context: List[Dict] = []
# Switching history and performance
self.switch_history: List[Dict] = []
self.performance_metrics: Dict[str, Any] = {
"switch_count": 0,
"successful_switches": 0,
"average_switch_time": 0.0,
"last_switch_time": None,
}
# Model capability mappings
self.model_requirements = {
ModelSize.TINY: {"memory_gb": 2, "cpu_cores": 2, "context_preference": 0.3},
ModelSize.SMALL: {"memory_gb": 4, "cpu_cores": 4, "context_preference": 0.5},
ModelSize.MEDIUM: {"memory_gb": 8, "cpu_cores": 6, "context_preference": 0.7},
ModelSize.LARGE: {"memory_gb": 16, "cpu_cores": 8, "context_preference": 0.8},
ModelSize.HUGE: {"memory_gb": 80, "cpu_cores": 16, "context_preference": 1.0},
}
# Initialize available models
self.available_models: Dict[str, ModelInfo] = {}
# Models will be refreshed when needed
# Note: _refresh_model_list is async and should be called from async context
async def _refresh_model_list(self):
"""Refresh the list of available models"""
try:
# This would use the actual ollama client
# For now, create mock models
self.available_models = {
"llama3.2:1b": ModelInfo(
name="llama3.2:1b",
size="1b",
parameters=1_000_000_000,
context_window=2048,
quantization="q4_0",
),
"llama3.2:3b": ModelInfo(
name="llama3.2:3b",
size="3b",
parameters=3_000_000_000,
context_window=4096,
quantization="q4_0",
),
"llama3.2:7b": ModelInfo(
name="llama3.2:7b",
size="7b",
parameters=7_000_000_000,
context_window=8192,
quantization="q4_0",
),
"llama3.2:13b": ModelInfo(
name="llama3.2:13b",
size="13b",
parameters=13_000_000_000,
context_window=8192,
quantization="q4_0",
),
}
except Exception as e:
print(f"Error refreshing model list: {e}")
async def select_best_model(
self, task_complexity: str = "medium", conversation_length: int = 0
) -> Tuple[str, float]:
"""Select the best model based on current conditions"""
if not self.available_models:
# Refresh if no models available
await self._refresh_model_list()
if not self.available_models:
raise ValueError("No models available for selection")
# Get current resources
resources = self.resource_detector.get_current_resources()
# Get performance degradation
degradation = self.resource_detector.get_performance_degradation()
# Filter models that can fit
suitable_models = []
for model_name, model_info in self.available_models.items():
# Check if model fits in resources
can_fit_result = self.resource_detector.can_fit_model(
{"size": model_info.size, "parameters": model_info.parameters}
)
if can_fit_result["can_fit"]:
# Calculate score based on capability and efficiency
score = self._calculate_model_score(
model_info, resources, degradation, task_complexity, conversation_length
)
suitable_models.append((model_name, model_info, score))
# Sort by score (descending)
suitable_models.sort(key=lambda x: x[2], reverse=True)
if not suitable_models:
# No suitable models, return the smallest available as fallback
smallest_model = min(self.available_models.items(), key=lambda x: x[1].parameters)
return smallest_model[0], 0.5
best_model_name, best_model_info, best_score = suitable_models[0]
return best_model_name, best_score
def _calculate_model_score(
self,
model_info: ModelInfo,
resources: Any,
degradation: Dict,
task_complexity: str,
conversation_length: int,
) -> float:
"""Calculate score for model selection"""
score = 0.0
# Base score from model capability (size)
capability_scores = {
ModelSize.TINY: 0.3,
ModelSize.SMALL: 0.5,
ModelSize.MEDIUM: 0.7,
ModelSize.LARGE: 0.85,
ModelSize.HUGE: 1.0,
}
score += capability_scores.get(model_info.size_category, 0.7)
# Resource fit bonus
if hasattr(model_info, "size_category"):
resource_fit = self.resource_detector.can_fit_model(
{"size": model_info.size_category.value, "parameters": model_info.parameters}
)
score += resource_fit["confidence"] * 0.3
# Performance degradation penalty
if degradation["overall"] == "critical":
score -= 0.3
elif degradation["overall"] == "degrading":
score -= 0.15
# Task complexity adjustment
complexity_multipliers = {
"simple": {"tiny": 1.2, "small": 1.1, "medium": 0.9, "large": 0.7, "huge": 0.5},
"medium": {"tiny": 0.8, "small": 0.9, "medium": 1.0, "large": 1.1, "huge": 0.9},
"complex": {"tiny": 0.5, "small": 0.7, "medium": 0.9, "large": 1.2, "huge": 1.3},
}
# Use the category name ("tiny"..."huge") so it matches the multiplier keys above
size_key = (
model_info.size_category.name.lower()
if hasattr(model_info, "size_category")
else "medium"
)
mult = complexity_multipliers.get(task_complexity, {}).get(size_key, 1.0)
score *= mult
# Conversation length adjustment (larger context for longer conversations)
if conversation_length > 50:
if model_info.context_window >= 8192:
score += 0.1
elif model_info.context_window < 4096:
score -= 0.2
elif conversation_length > 20:
if model_info.context_window >= 4096:
score += 0.05
return max(0.0, min(1.0, score))
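# Worked example: a MEDIUM (7b) model starts at 0.7 capability; a resource-fit
# confidence of 0.9 adds 0.27; stable resources add no penalty; a "medium"
# task keeps the 1.0 multiplier and a short conversation adds nothing, so the
# score is clamped to 0.97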
def should_switch_model(self, current_performance_metrics: Dict) -> SwitchRecommendation:
"""Determine if model should be switched"""
if not self.current_model:
# No current model, select best available
return SwitchRecommendation(
should_switch=True,
target_model=None, # Will be selected
reason=SwitchReason.PROACTIVE_OPTIMIZATION,
confidence=1.0,
expected_benefit="Initialize with optimal model",
estimated_cost={"time": 0.0, "memory": 0.0},
)
# Check resource constraints
memory_constrained, constraint_level = self.resource_detector.is_memory_constrained()
if memory_constrained and constraint_level in ["warning", "critical"]:
# Need to switch to smaller model
smaller_model = self._find_smaller_model()
if smaller_model:
benefit = f"Reduce memory usage during {constraint_level} constraint"
return SwitchRecommendation(
should_switch=True,
target_model=smaller_model,
reason=SwitchReason.RESOURCE_CONSTRAINT,
confidence=0.9,
expected_benefit=benefit,
estimated_cost={"time": 2.0, "memory": -4.0},
)
# Check performance degradation
degradation = self.resource_detector.get_performance_degradation()
if degradation["overall"] in ["critical", "degrading"]:
# Consider switching to smaller model
smaller_model = self._find_smaller_model()
if smaller_model and self.current_model_info:
benefit = "Improve responsiveness during performance degradation"
return SwitchRecommendation(
should_switch=True,
target_model=smaller_model,
reason=SwitchReason.PERFORMANCE_DEGRADATION,
confidence=0.8,
expected_benefit=benefit,
estimated_cost={"time": 2.0, "memory": -4.0},
)
# Check if resources are available for larger model
if not memory_constrained and degradation["overall"] == "stable":
# Can we switch to a larger model?
larger_model = self._find_larger_model()
if larger_model:
benefit = "Increase capability with available resources"
return SwitchRecommendation(
should_switch=True,
target_model=larger_model,
reason=SwitchReason.PROACTIVE_OPTIMIZATION,
confidence=0.7,
expected_benefit=benefit,
estimated_cost={"time": 3.0, "memory": 4.0},
)
return SwitchRecommendation(
should_switch=False,
target_model=None,
reason=SwitchReason.PROACTIVE_OPTIMIZATION,
confidence=1.0,
expected_benefit="Current model is optimal",
estimated_cost={"time": 0.0, "memory": 0.0},
)
def _find_smaller_model(self) -> Optional[str]:
"""Find a smaller model than current"""
if not self.current_model_info or not self.available_models:
return None
current_size = getattr(self.current_model_info, "size_category", ModelSize.MEDIUM)
        size_order = [
            ModelSize.TINY,
            ModelSize.SMALL,
            ModelSize.MEDIUM,
            ModelSize.LARGE,
            ModelSize.HUGE,
        ]
        current_index = size_order.index(current_size)
        # Look for models in smaller categories (smallest first)
        for size in size_order[:current_index]:
for model_name, model_info in self.available_models.items():
if hasattr(model_info, "size_category") and model_info.size_category == size:
# Check if it fits
can_fit = self.resource_detector.can_fit_model(
{"size": size.value, "parameters": model_info.parameters}
)
if can_fit["can_fit"]:
return model_name
return None
def _find_larger_model(self) -> Optional[str]:
"""Find a larger model than current"""
if not self.current_model_info or not self.available_models:
return None
current_size = getattr(self.current_model_info, "size_category", ModelSize.MEDIUM)
        size_order = [
            ModelSize.TINY,
            ModelSize.SMALL,
            ModelSize.MEDIUM,
            ModelSize.LARGE,
            ModelSize.HUGE,
        ]
        current_index = size_order.index(current_size)
        # Look for models in larger categories (next size up first)
        for size in size_order[current_index + 1 :]:
for model_name, model_info in self.available_models.items():
if hasattr(model_info, "size_category") and model_info.size_category == size:
# Check if it fits
can_fit = self.resource_detector.can_fit_model(
{"size": size.value, "parameters": model_info.parameters}
)
if can_fit["can_fit"]:
return model_name
return None
async def switch_model(
self, new_model_name: str, conversation_context: Optional[List[Dict]] = None
) -> SwitchMetrics:
"""Switch to a new model with context preservation"""
start_time = time.time()
try:
# Validate new model is available
if new_model_name not in self.available_models:
raise ValueError(f"Model {new_model_name} not available")
# Compress conversation context if provided
context_transfer_time = 0.0
compression_ratio = 1.0
compressed_context = conversation_context
if conversation_context:
compress_start = time.time()
compressed_context = self._compress_context(conversation_context)
context_transfer_time = time.time() - compress_start
compression_ratio = len(conversation_context) / max(1, len(compressed_context))
# Perform the switch (mock implementation)
            # In a real implementation, this would use the ollama client
old_model = self.current_model
self.current_model = new_model_name
self.current_model_info = self.available_models[new_model_name]
if conversation_context and compressed_context is not None:
self.conversation_context = compressed_context
switch_time = time.time() - start_time
# Update performance metrics
self._update_switch_metrics(True, switch_time)
# Record switch in history
self.switch_history.append(
{
"timestamp": time.time(),
"from_model": old_model,
"to_model": new_model_name,
"switch_time": switch_time,
"context_transfer_time": context_transfer_time,
"compression_ratio": compression_ratio,
"success": True,
}
)
return SwitchMetrics(
switch_time=switch_time,
context_transfer_time=context_transfer_time,
context_compression_ratio=compression_ratio,
success=True,
)
except Exception as e:
switch_time = time.time() - start_time
self._update_switch_metrics(False, switch_time)
return SwitchMetrics(
switch_time=switch_time,
context_transfer_time=0.0,
context_compression_ratio=1.0,
success=False,
error_message=str(e),
)
def _compress_context(self, context: List[Dict]) -> List[Dict]:
"""Compress conversation context for transfer"""
# Simple compression strategy - keep recent messages and summaries
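        # Example (illustrative): a 30-message history becomes the first 2
        # messages, one summary marker, and the last 8 messages - 11 entries,
        # giving switch_model a compression ratio of roughly 30 / 11 ≈ 2.7.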
if len(context) <= 10:
return context
# Keep first 2 and last 8 messages
compressed = context[:2] + context[-8:]
# Add a summary if we removed significant content
if len(context) > len(compressed):
summary_msg = {
"role": "system",
"content": f"[{len(context) - len(compressed)} earlier messages summarized for context compression]",
}
compressed.insert(2, summary_msg)
return compressed
def _update_switch_metrics(self, success: bool, switch_time: float):
"""Update performance metrics for switching"""
self.performance_metrics["switch_count"] += 1
if success:
self.performance_metrics["successful_switches"] += 1
# Update average switch time
if self.performance_metrics["switch_count"] == 1:
self.performance_metrics["average_switch_time"] = switch_time
else:
current_avg = self.performance_metrics["average_switch_time"]
n = self.performance_metrics["switch_count"]
new_avg = ((n - 1) * current_avg + switch_time) / n
self.performance_metrics["average_switch_time"] = new_avg
self.performance_metrics["last_switch_time"] = time.time()
def get_model_recommendations(self) -> List[Dict]:
"""Get model recommendations based on current state"""
recommendations = []
# Get current resources
resources = self.resource_detector.get_current_resources()
# Get performance degradation
degradation = self.resource_detector.get_performance_degradation()
for model_name, model_info in self.available_models.items():
# Check if model fits
can_fit_result = self.resource_detector.can_fit_model(
{"size": model_info.size, "parameters": model_info.parameters}
)
if can_fit_result["can_fit"]:
# Calculate recommendation score
score = self._calculate_model_score(model_info, resources, degradation, "medium", 0)
recommendation = {
"model": model_name,
"model_info": asdict(model_info),
"can_fit": True,
"fit_confidence": can_fit_result["confidence"],
"performance_score": score,
"memory_deficit_gb": can_fit_result.get("memory_deficit_gb", 0),
"recommendation": can_fit_result.get("recommendation", ""),
"reason": self._get_recommendation_reason(score, resources, model_info),
}
recommendations.append(recommendation)
else:
# Model doesn't fit, but include with explanation
recommendation = {
"model": model_name,
"model_info": asdict(model_info),
"can_fit": False,
"fit_confidence": can_fit_result["confidence"],
"performance_score": 0.0,
"memory_deficit_gb": can_fit_result.get("memory_deficit_gb", 0),
"recommendation": can_fit_result.get("recommendation", ""),
"reason": f"Insufficient memory - need {can_fit_result.get('memory_deficit_gb', 0):.1f}GB more",
}
recommendations.append(recommendation)
# Sort by performance score
recommendations.sort(key=lambda x: x["performance_score"], reverse=True)
return recommendations
def _get_recommendation_reason(
self, score: float, resources: Any, model_info: ModelInfo
) -> str:
"""Get reason for recommendation"""
if score >= 0.8:
return "Excellent fit for current conditions"
elif score >= 0.6:
return "Good choice, should work well"
elif score >= 0.4:
return "Possible fit, may have performance issues"
else:
return "Not recommended for current conditions"
def estimate_switching_cost(
self, from_model: str, to_model: str, context_size: int
) -> Dict[str, float]:
"""Estimate the cost of switching between models"""
# Base time cost
base_time = 1.0 # 1 second base switching time
# Context transfer cost
context_time = context_size * 0.01 # 10ms per message
# Model loading cost (based on size difference)
from_info = self.available_models.get(from_model)
to_info = self.available_models.get(to_model)
if from_info and to_info:
size_diff = abs(to_info.parameters - from_info.parameters) / 1e9
loading_cost = size_diff * 0.5 # 0.5s per billion parameters difference
else:
loading_cost = 2.0 # Default 2 seconds
total_time = base_time + context_time + loading_cost
# Memory cost (temporary increase during switch)
memory_cost = 2.0 if context_size > 20 else 1.0 # GB
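        # Worked example (illustrative): switching a 7B model to a 13B model
        # with 30 context messages costs 1.0 + 30 * 0.01 + |13 - 7| * 0.5
        # = 4.3 seconds, plus a temporary ~2 GB memory overhead (>20 messages).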
return {
"time_seconds": total_time,
"memory_gb": memory_cost,
"context_transfer_time": context_time,
"model_loading_time": loading_cost,
}
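# --- Usage sketch (illustrative, not part of the original file) ---
# Assumes `manager` is an instance of the model manager class defined above,
# with `available_models` populated and a working `resource_detector`.
async def _example_pick_and_switch(manager, history):
    # Pick the best model for a complex task, then switch if recommended,
    # carrying the recent conversation history across the switch.
    name, score = await manager.select_best_model(
        task_complexity="complex", conversation_length=len(history)
    )
    recommendation = manager.should_switch_model(current_performance_metrics={})
    if recommendation.should_switch:
        metrics = await manager.switch_model(recommendation.target_model or name, history)
        print(
            f"Switched in {metrics.switch_time:.2f}s "
            f"(context compression {metrics.context_compression_ratio:.1f}x)"
        )
    return name, score
# Run with: asyncio.run(_example_pick_and_switch(manager, history))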

View File

@@ -1,40 +0,0 @@
"""
Mai data models package.
Exports all Pydantic models for conversations, memory, and related data structures.
"""
from .conversation import (
Message,
Conversation,
ConversationSummary,
ConversationFilter,
)
from .memory import (
ConversationType,
RelevanceType,
SearchQuery,
RetrievalResult,
MemoryContext,
ContextWeight,
ConversationPattern,
ContextPlacement,
)
__all__ = [
# Conversation models
"Message",
"Conversation",
"ConversationSummary",
"ConversationFilter",
# Memory models
"ConversationType",
"RelevanceType",
"SearchQuery",
"RetrievalResult",
"MemoryContext",
"ContextWeight",
"ConversationPattern",
"ContextPlacement",
]

View File

@@ -1,172 +0,0 @@
"""
Conversation data models for Mai memory system.
Provides Pydantic models for conversations, messages, and related
data structures with proper validation and serialization.
"""
from typing import List, Dict, Any, Optional
from datetime import datetime
from pydantic import BaseModel, Field, validator
import json
class Message(BaseModel):
"""Individual message within a conversation."""
id: str = Field(..., description="Unique message identifier")
role: str = Field(..., description="Message role: 'user', 'assistant', or 'system'")
content: str = Field(..., description="Message content text")
timestamp: str = Field(..., description="ISO timestamp of message")
token_count: Optional[int] = Field(0, description="Token count for message")
metadata: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Additional message metadata"
)
@validator("role")
def validate_role(cls, v):
"""Validate that role is one of the allowed values."""
allowed_roles = ["user", "assistant", "system"]
if v not in allowed_roles:
raise ValueError(f"Role must be one of: {allowed_roles}")
return v
@validator("timestamp")
def validate_timestamp(cls, v):
"""Validate timestamp format and ensure it's ISO format."""
try:
# Try to parse the timestamp to ensure it's valid
dt = datetime.fromisoformat(v.replace("Z", "+00:00"))
# Return in standard ISO format
return dt.isoformat()
except (ValueError, AttributeError) as e:
raise ValueError(f"Invalid timestamp format: {v}. Must be ISO format.") from e
class Config:
"""Pydantic configuration for Message model."""
json_encoders = {datetime: lambda v: v.isoformat()}
class Conversation(BaseModel):
"""Complete conversation with messages and metadata."""
id: str = Field(..., description="Unique conversation identifier")
title: str = Field(..., description="Human-readable conversation title")
created_at: str = Field(..., description="ISO timestamp when conversation was created")
updated_at: str = Field(..., description="ISO timestamp when conversation was last updated")
messages: List[Message] = Field(
default_factory=list, description="List of messages in chronological order"
)
metadata: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Additional conversation metadata"
)
message_count: Optional[int] = Field(0, description="Total number of messages")
@validator("messages")
def validate_message_order(cls, v):
"""Ensure messages are in chronological order."""
if not v:
return v
# Sort by timestamp to ensure chronological order
try:
sorted_messages = sorted(
v, key=lambda m: datetime.fromisoformat(m.timestamp.replace("Z", "+00:00"))
)
return sorted_messages
except (ValueError, AttributeError) as e:
raise ValueError("Messages have invalid timestamps") from e
@validator("updated_at")
def validate_updated_timestamp(cls, v, values):
"""Ensure updated_at is not earlier than created_at."""
if "created_at" in values:
try:
created = datetime.fromisoformat(values["created_at"].replace("Z", "+00:00"))
updated = datetime.fromisoformat(v.replace("Z", "+00:00"))
if updated < created:
raise ValueError("updated_at cannot be earlier than created_at")
except (ValueError, AttributeError) as e:
raise ValueError(f"Invalid timestamp comparison: {e}") from e
return v
def add_message(self, message: Message) -> None:
"""
Add a message to the conversation and update timestamps.
Args:
message: Message to add
"""
self.messages.append(message)
self.message_count = len(self.messages)
# Update the updated_at timestamp
self.updated_at = datetime.now().isoformat()
def get_message_count(self) -> int:
"""Get the actual message count."""
return len(self.messages)
def get_latest_message(self) -> Optional[Message]:
"""Get the most recent message in the conversation."""
if not self.messages:
return None
# Return the message with the latest timestamp
return max(
self.messages, key=lambda m: datetime.fromisoformat(m.timestamp.replace("Z", "+00:00"))
)
class Config:
"""Pydantic configuration for Conversation model."""
json_encoders = {datetime: lambda v: v.isoformat()}
class ConversationSummary(BaseModel):
"""Summary of a conversation for search results."""
id: str = Field(..., description="Conversation identifier")
title: str = Field(..., description="Conversation title")
created_at: str = Field(..., description="Creation timestamp")
updated_at: str = Field(..., description="Last update timestamp")
message_count: int = Field(..., description="Total messages in conversation")
preview: Optional[str] = Field(None, description="Short preview of conversation content")
tags: Optional[List[str]] = Field(
default_factory=list, description="Tags or keywords for conversation"
)
class Config:
"""Pydantic configuration for ConversationSummary model."""
pass
class ConversationFilter(BaseModel):
"""Filter criteria for searching conversations."""
role: Optional[str] = Field(None, description="Filter by message role")
start_date: Optional[str] = Field(
None, description="Filter messages after this date (ISO format)"
)
end_date: Optional[str] = Field(
None, description="Filter messages before this date (ISO format)"
)
keywords: Optional[List[str]] = Field(None, description="Filter by keywords in message content")
min_message_count: Optional[int] = Field(None, description="Minimum message count")
max_message_count: Optional[int] = Field(None, description="Maximum message count")
@validator("start_date", "end_date")
def validate_date_filters(cls, v):
"""Validate date filter format."""
if v is None:
return v
try:
datetime.fromisoformat(v.replace("Z", "+00:00"))
return v
except (ValueError, AttributeError) as e:
raise ValueError(f"Invalid date format: {v}. Must be ISO format.") from e

View File

@@ -1,256 +0,0 @@
"""
Memory system data models for Mai context retrieval.
Provides Pydantic models for memory context, search queries,
retrieval results, and related data structures.
"""
from typing import List, Dict, Any, Optional, Union
from datetime import datetime
from pydantic import BaseModel, Field, validator
from enum import Enum
from .conversation import Conversation, Message
class ConversationType(str, Enum):
"""Enumeration of conversation types for adaptive weighting."""
TECHNICAL = "technical"
PERSONAL = "personal"
PLANNING = "planning"
GENERAL = "general"
QUESTION = "question"
CREATIVE = "creative"
ANALYSIS = "analysis"
class RelevanceType(str, Enum):
"""Enumeration of relevance types for search results."""
SEMANTIC = "semantic"
KEYWORD = "keyword"
RECENCY = "recency"
PATTERN = "pattern"
HYBRID = "hybrid"
class SearchQuery(BaseModel):
"""Query model for context search operations."""
text: str = Field(..., description="Search query text")
conversation_type: Optional[ConversationType] = Field(
None, description="Detected conversation type"
)
filters: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Search filters")
weights: Optional[Dict[str, float]] = Field(
default_factory=dict, description="Search weight overrides"
)
limits: Optional[Dict[str, int]] = Field(default_factory=dict, description="Search limits")
# Default limits
max_results: int = Field(5, description="Maximum number of results to return")
max_tokens: int = Field(2000, description="Maximum tokens in returned context")
# Search facet controls
include_semantic: bool = Field(True, description="Include semantic similarity search")
include_keywords: bool = Field(True, description="Include keyword matching")
include_recency: bool = Field(True, description="Include recency weighting")
include_patterns: bool = Field(True, description="Include pattern matching")
@validator("text")
def validate_text(cls, v):
"""Validate search text is not empty."""
if not v or not v.strip():
raise ValueError("Search text cannot be empty")
return v.strip()
@validator("max_results")
def validate_max_results(cls, v):
"""Validate max results is reasonable."""
if v < 1:
raise ValueError("max_results must be at least 1")
if v > 20:
raise ValueError("max_results cannot exceed 20")
return v
class RetrievalResult(BaseModel):
"""Single result from context retrieval operation."""
conversation_id: str = Field(..., description="ID of the conversation")
title: str = Field(..., description="Title of the conversation")
similarity_score: float = Field(..., description="Similarity score (0.0 to 1.0)")
relevance_type: RelevanceType = Field(..., description="Type of relevance")
excerpt: str = Field(..., description="Relevant excerpt from conversation")
context_type: Optional[ConversationType] = Field(None, description="Type of conversation")
matched_message_id: Optional[str] = Field(None, description="ID of the best matching message")
metadata: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Additional result metadata"
)
# Component scores for hybrid results
semantic_score: Optional[float] = Field(None, description="Semantic similarity score")
keyword_score: Optional[float] = Field(None, description="Keyword matching score")
recency_score: Optional[float] = Field(None, description="Recency-based score")
pattern_score: Optional[float] = Field(None, description="Pattern matching score")
@validator("similarity_score")
def validate_similarity_score(cls, v):
"""Validate similarity score is in valid range."""
if not 0.0 <= v <= 1.0:
raise ValueError("similarity_score must be between 0.0 and 1.0")
return v
@validator("excerpt")
def validate_excerpt(cls, v):
"""Validate excerpt is not empty."""
if not v or not v.strip():
raise ValueError("excerpt cannot be empty")
return v.strip()
class MemoryContext(BaseModel):
"""Complete memory context for current query."""
current_query: SearchQuery = Field(..., description="The search query")
relevant_conversations: List[RetrievalResult] = Field(
default_factory=list, description="Retrieved conversations"
)
patterns: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Extracted patterns"
)
metadata: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Additional context metadata"
)
# Context statistics
total_conversations: int = Field(0, description="Total conversations found")
total_tokens: int = Field(0, description="Total tokens in retrieved context")
context_quality_score: Optional[float] = Field(
None, description="Quality assessment of context"
)
# Weighting information
applied_weights: Optional[Dict[str, float]] = Field(
default_factory=dict, description="Weights applied to search"
)
conversation_type_detected: Optional[ConversationType] = Field(
None, description="Detected conversation type"
)
def add_result(self, result: RetrievalResult) -> None:
"""Add a retrieval result to the context."""
self.relevant_conversations.append(result)
self.total_conversations = len(self.relevant_conversations)
# Estimate tokens (rough approximation: 1 token ≈ 4 characters)
self.total_tokens += len(result.excerpt) // 4
def is_within_token_limit(self, max_tokens: Optional[int] = None) -> bool:
"""Check if context is within token limits."""
limit = max_tokens or self.current_query.max_tokens
return self.total_tokens <= limit
def get_summary_text(self, max_chars: int = 500) -> str:
"""Get a summary of the retrieved context."""
if not self.relevant_conversations:
return "No relevant conversations found."
summaries = []
total_chars = 0
for result in self.relevant_conversations[:3]: # Top 3 results
summary = f"{result.title}: {result.excerpt[:200]}..."
if total_chars + len(summary) > max_chars:
break
summaries.append(summary)
total_chars += len(summary)
return " | ".join(summaries)
class Config:
"""Pydantic configuration for MemoryContext model."""
pass
class ContextWeight(BaseModel):
"""Weight configuration for different search facets."""
semantic: float = Field(0.4, description="Weight for semantic similarity")
keyword: float = Field(0.3, description="Weight for keyword matching")
recency: float = Field(0.2, description="Weight for recency")
pattern: float = Field(0.1, description="Weight for pattern matching")
@validator("semantic", "keyword", "recency", "pattern")
def validate_weights(cls, v):
"""Validate individual weights are non-negative."""
if v < 0:
raise ValueError("Weights cannot be negative")
return v
@validator("semantic", "keyword", "recency", "pattern")
def validate_weight_range(cls, v):
"""Validate weights are reasonable."""
if v > 2.0:
raise ValueError("Individual weights cannot exceed 2.0")
return v
def normalize(self) -> "ContextWeight":
"""Normalize weights so they sum to 1.0."""
total = self.semantic + self.keyword + self.recency + self.pattern
if total == 0:
return ContextWeight()
return ContextWeight(
semantic=self.semantic / total,
keyword=self.keyword / total,
recency=self.recency / total,
pattern=self.pattern / total,
)
class ConversationPattern(BaseModel):
"""Extracted pattern from conversations."""
pattern_type: str = Field(..., description="Type of pattern (preference, topic, style, etc.)")
pattern_value: str = Field(..., description="Pattern value or description")
confidence: float = Field(..., description="Confidence score for pattern")
frequency: int = Field(1, description="How often this pattern appears")
conversation_ids: List[str] = Field(
default_factory=list, description="Conversations where pattern appears"
)
last_seen: str = Field(..., description="ISO timestamp when pattern was last observed")
@validator("confidence")
def validate_confidence(cls, v):
"""Validate confidence score."""
if not 0.0 <= v <= 1.0:
raise ValueError("confidence must be between 0.0 and 1.0")
return v
class Config:
"""Pydantic configuration for ConversationPattern model."""
pass
class ContextPlacement(BaseModel):
"""Strategy for placing context to prevent 'lost in middle'."""
strategy: str = Field(..., description="Placement strategy name")
reasoning: str = Field(..., description="Why this strategy was chosen")
high_priority_items: List[int] = Field(
default_factory=list, description="Indices of high priority conversations"
)
distributed_items: List[int] = Field(
default_factory=list, description="Indices of distributed conversations"
)
token_allocation: Dict[str, int] = Field(
default_factory=dict, description="Token allocation per conversation"
)
class Config:
"""Pydantic configuration for ContextPlacement model."""
pass
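# --- Usage sketch (illustrative, not part of the original file) ---
# Normalizes custom facet weights and assembles a MemoryContext from a
# query plus a single retrieval result.
def _example_memory_context() -> MemoryContext:
    weights = ContextWeight(semantic=0.5, keyword=0.3, recency=0.15, pattern=0.05).normalize()
    query = SearchQuery(text="previous decisions about deployment", max_results=3)
    context = MemoryContext(current_query=query, applied_weights=weights.dict())
    context.add_result(
        RetrievalResult(
            conversation_id="c1",
            title="Deploy planning",
            similarity_score=0.82,
            relevance_type=RelevanceType.HYBRID,
            excerpt="We agreed to ship on Fridays only after smoke tests pass.",
        )
    )
    return context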

View File

@@ -1,29 +0,0 @@
"""
Mai Sandbox System - Safe Code Execution
This module provides the foundational safety infrastructure for Mai's code execution,
including risk analysis, resource enforcement, and audit logging.
"""
from .audit_logger import AuditLogger
from .approval_system import ApprovalSystem
from .docker_executor import ContainerConfig, ContainerResult, DockerExecutor
from .manager import ExecutionRequest, ExecutionResult, SandboxManager
from .resource_enforcer import ResourceEnforcer, ResourceLimits, ResourceUsage
from .risk_analyzer import RiskAnalyzer, RiskAssessment
__all__ = [
"SandboxManager",
"ExecutionRequest",
"ExecutionResult",
"RiskAnalyzer",
"RiskAssessment",
"ResourceEnforcer",
"ResourceLimits",
"ResourceUsage",
"AuditLogger",
"ApprovalSystem",
"DockerExecutor",
"ContainerConfig",
"ContainerResult",
]

View File

@@ -1,431 +0,0 @@
"""
Risk-based User Approval System
This module provides a sophisticated approval system that evaluates code execution
requests based on risk analysis and provides appropriate user interaction workflows.
"""
import logging
import json
import hashlib
from typing import Dict, List, Optional, Tuple, Any
from enum import Enum
from dataclasses import dataclass, asdict
from datetime import datetime
import sys
import re
from ..core.config import get_config
class RiskLevel(Enum):
"""Risk levels for code execution."""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
BLOCKED = "blocked"
class ApprovalResult(Enum):
"""Approval decision results."""
ALLOWED = "allowed"
DENIED = "denied"
BLOCKED = "blocked"
APPROVED = "approved"
@dataclass
class RiskAnalysis:
"""Risk analysis result."""
risk_level: RiskLevel
confidence: float
reasons: List[str]
affected_resources: List[str]
severity_score: float
@dataclass
class ApprovalRequest:
"""Approval request data."""
code: str
risk_analysis: RiskAnalysis
context: Dict[str, Any]
timestamp: datetime
request_id: str
user_preference: Optional[str] = None
@dataclass
class ApprovalDecision:
"""Approval decision record."""
request: ApprovalRequest
result: ApprovalResult
user_input: str
timestamp: datetime
trust_updated: bool = False
class ApprovalSystem:
"""Risk-based approval system for code execution."""
def __init__(self):
self.config = get_config()
self.logger = logging.getLogger(__name__)
self.approval_history: List[ApprovalDecision] = []
self.user_preferences: Dict[str, str] = {}
self.trust_patterns: Dict[str, int] = {}
# Risk thresholds - use defaults since sandbox config not yet in main Config
self.risk_thresholds = {
"low_threshold": 0.3,
"medium_threshold": 0.6,
"high_threshold": 0.8,
}
# Load saved preferences
self._load_preferences()
def _load_preferences(self):
"""Load user preferences from configuration."""
try:
# For now, preferences are stored locally only
# TODO: Integrate with Config class when sandbox config added
self.user_preferences = {}
except Exception as e:
self.logger.warning(f"Could not load user preferences: {e}")
def _save_preferences(self):
"""Save user preferences to configuration."""
try:
# Note: This would integrate with config hot-reload system
pass
except Exception as e:
self.logger.warning(f"Could not save user preferences: {e}")
def _generate_request_id(self, code: str) -> str:
"""Generate unique request ID for code."""
content = f"{code}_{datetime.now().isoformat()}"
return hashlib.md5(content.encode()).hexdigest()[:12]
def _analyze_code_risk(self, code: str, context: Dict[str, Any]) -> RiskAnalysis:
"""Analyze code for potential risks."""
risk_patterns = {
"HIGH": [
r"os\.system\s*\(",
r"subprocess\.call\s*\(",
r"exec\s*\(",
r"eval\s*\(",
r"__import__\s*\(",
r'open\s*\([\'"]\/',
r"shutil\.rmtree",
r"pickle\.loads?",
],
"MEDIUM": [
r"import\s+os",
r"import\s+subprocess",
r"import\s+sys",
r"open\s*\(",
r"file\s*\(",
r"\.write\s*\(",
r"\.read\s*\(",
],
}
risk_score = 0.0
reasons = []
affected_resources = []
# Check for high-risk patterns
for pattern in risk_patterns["HIGH"]:
if re.search(pattern, code, re.IGNORECASE):
risk_score += 0.4
reasons.append(f"High-risk pattern detected: {pattern}")
affected_resources.append("system_operations")
# Check for medium-risk patterns
for pattern in risk_patterns["MEDIUM"]:
if re.search(pattern, code, re.IGNORECASE):
risk_score += 0.2
reasons.append(f"Medium-risk pattern detected: {pattern}")
affected_resources.append("file_system")
# Analyze context
if context.get("user_level") == "new":
risk_score += 0.1
reasons.append("New user profile")
# Determine risk level
if risk_score >= self.risk_thresholds["high_threshold"]:
risk_level = RiskLevel.HIGH
elif risk_score >= self.risk_thresholds["medium_threshold"]:
risk_level = RiskLevel.MEDIUM
elif risk_score >= self.risk_thresholds["low_threshold"]:
risk_level = RiskLevel.LOW
else:
risk_level = RiskLevel.LOW # Default to low for very safe code
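        # Worked example (illustrative): code containing "import os" (+0.2)
        # and "os.system(" (+0.4) scores 0.6 with no extra context, meeting
        # the medium threshold, so it is classified as MEDIUM risk with
        # confidence min(0.95, 0.5 + 2 * 0.1) = 0.7.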
# Check for blocked operations
blocked_patterns = [
r"rm\s+-rf\s+\/",
r"dd\s+if=",
r"format\s+",
r"fdisk",
]
for pattern in blocked_patterns:
if re.search(pattern, code, re.IGNORECASE):
risk_level = RiskLevel.BLOCKED
reasons.append(f"Blocked operation detected: {pattern}")
break
confidence = min(0.95, 0.5 + len(reasons) * 0.1)
return RiskAnalysis(
risk_level=risk_level,
confidence=confidence,
reasons=reasons,
affected_resources=affected_resources,
severity_score=risk_score,
)
def _present_approval_request(self, request: ApprovalRequest) -> str:
"""Present approval request to user based on risk level."""
risk_level = request.risk_analysis.risk_level
if risk_level == RiskLevel.LOW:
return self._present_low_risk_request(request)
elif risk_level == RiskLevel.MEDIUM:
return self._present_medium_risk_request(request)
elif risk_level == RiskLevel.HIGH:
return self._present_high_risk_request(request)
else: # BLOCKED
return self._present_blocked_request(request)
def _present_low_risk_request(self, request: ApprovalRequest) -> str:
"""Present low-risk approval request."""
print(f"\n🟢 [LOW RISK] Execute {self._get_operation_type(request.code)}?")
print(f"Code: {request.code[:100]}{'...' if len(request.code) > 100 else ''}")
response = input("Allow? [Y/n/a(llow always)]: ").strip().lower()
if response in ["", "y", "yes"]:
return "allowed"
elif response == "a":
self.user_preferences[self._get_operation_type(request.code)] = "auto_allow"
return "allowed_always"
else:
return "denied"
def _present_medium_risk_request(self, request: ApprovalRequest) -> str:
"""Present medium-risk approval request with details."""
print(f"\n🟡 [MEDIUM RISK] Potentially dangerous operation detected")
print(f"Operation Type: {self._get_operation_type(request.code)}")
print(f"Affected Resources: {', '.join(request.risk_analysis.affected_resources)}")
print(f"Risk Factors: {len(request.risk_analysis.reasons)}")
print(f"\nCode Preview:")
print(request.code[:200] + ("..." if len(request.code) > 200 else ""))
if request.risk_analysis.reasons:
print(f"\nRisk Reasons:")
for reason in request.risk_analysis.reasons[:3]:
print(f"{reason}")
response = input("\nAllow this operation? [y/N/d(etails)/a(llow always)]: ").strip().lower()
if response == "y":
return "allowed"
elif response == "d":
return self._present_detailed_view(request)
elif response == "a":
self.user_preferences[self._get_operation_type(request.code)] = "auto_allow"
return "allowed_always"
else:
return "denied"
def _present_high_risk_request(self, request: ApprovalRequest) -> str:
"""Present high-risk approval request with full details."""
print(f"\n🔴 [HIGH RISK] Dangerous operation detected!")
print(f"Severity Score: {request.risk_analysis.severity_score:.2f}")
print(f"Confidence: {request.risk_analysis.confidence:.2f}")
print(f"\nAffected Resources: {', '.join(request.risk_analysis.affected_resources)}")
print(f"\nAll Risk Factors:")
for reason in request.risk_analysis.reasons:
print(f"{reason}")
print(f"\nFull Code:")
print("=" * 50)
print(request.code)
print("=" * 50)
print(f"\n⚠️ This operation could potentially harm your system or data.")
response = (
input("\nType 'confirm' to allow, 'cancel' to deny, 'details' for more info: ")
.strip()
.lower()
)
if response == "confirm":
return "allowed"
elif response == "details":
return self._present_detailed_analysis(request)
else:
return "denied"
def _present_blocked_request(self, request: ApprovalRequest) -> str:
"""Present blocked operation notification."""
print(f"\n🚫 [BLOCKED] Operation not permitted")
print(f"This operation is blocked for security reasons:")
for reason in request.risk_analysis.reasons:
print(f"{reason}")
print("\nThis operation cannot be executed.")
return "blocked"
def _present_detailed_view(self, request: ApprovalRequest) -> str:
"""Present detailed view of the request."""
print(f"\n📋 Detailed Analysis")
print(f"Request ID: {request.request_id}")
print(f"Timestamp: {request.timestamp}")
print(f"Risk Level: {request.risk_analysis.risk_level.value.upper()}")
print(f"Severity Score: {request.risk_analysis.severity_score:.2f}")
print(f"\nContext Information:")
for key, value in request.context.items():
print(f" {key}: {value}")
print(f"\nFull Code:")
print("=" * 50)
print(request.code)
print("=" * 50)
response = input("\nProceed with execution? [y/N]: ").strip().lower()
return "allowed" if response == "y" else "denied"
def _present_detailed_analysis(self, request: ApprovalRequest) -> str:
"""Present extremely detailed analysis for high-risk operations."""
print(f"\n🔬 Security Analysis Report")
print(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Request ID: {request.request_id}")
print(f"\nRisk Assessment:")
print(f" Level: {request.risk_analysis.risk_level.value.upper()}")
print(f" Score: {request.risk_analysis.severity_score:.2f}/1.0")
print(f" Confidence: {request.risk_analysis.confidence:.2f}")
print(f"\nThreat Analysis:")
for reason in request.risk_analysis.reasons:
print(f" ⚠️ {reason}")
print(f"\nResource Impact:")
for resource in request.risk_analysis.affected_resources:
print(f" 📁 {resource}")
print(
f"\nRecommendation: {'DENY' if request.risk_analysis.severity_score > 0.8 else 'REVIEW CAREFULLY'}"
)
response = input("\nFinal decision? [confirm/cancel]: ").strip().lower()
return "allowed" if response == "confirm" else "denied"
def _get_operation_type(self, code: str) -> str:
"""Extract operation type from code."""
if "import" in code:
return "module_import"
elif "os.system" in code or "subprocess" in code:
return "system_command"
elif "open(" in code:
return "file_operation"
elif "print(" in code:
return "output_operation"
else:
return "code_execution"
def request_approval(
self, code: str, context: Optional[Dict[str, Any]] = None
) -> Tuple[ApprovalResult, Optional[ApprovalDecision]]:
"""Request user approval for code execution."""
if context is None:
context = {}
# Analyze risk
risk_analysis = self._analyze_code_risk(code, context)
# Create request
request = ApprovalRequest(
code=code,
risk_analysis=risk_analysis,
context=context,
timestamp=datetime.now(),
request_id=self._generate_request_id(code),
)
# Check user preferences
operation_type = self._get_operation_type(code)
if (
self.user_preferences.get(operation_type) == "auto_allow"
and risk_analysis.risk_level == RiskLevel.LOW
):
decision = ApprovalDecision(
request=request,
result=ApprovalResult.ALLOWED,
user_input="auto_allowed",
timestamp=datetime.now(),
)
self.approval_history.append(decision)
return ApprovalResult.ALLOWED, decision
# Present request based on risk level
user_response = self._present_approval_request(request)
# Convert user response to approval result
if user_response == "blocked":
result = ApprovalResult.BLOCKED
elif user_response in ["allowed", "allowed_always"]:
result = ApprovalResult.APPROVED
else:
result = ApprovalResult.DENIED
# Create decision record
decision = ApprovalDecision(
request=request,
result=result,
user_input=user_response,
timestamp=datetime.now(),
trust_updated=("allowed_always" in user_response),
)
# Save decision
self.approval_history.append(decision)
if decision.trust_updated:
self._save_preferences()
return result, decision
def get_approval_history(self, limit: int = 10) -> List[ApprovalDecision]:
"""Get recent approval history."""
return self.approval_history[-limit:]
def get_trust_patterns(self) -> Dict[str, int]:
"""Get learned trust patterns."""
patterns = {}
for decision in self.approval_history:
op_type = self._get_operation_type(decision.request.code)
if decision.result == ApprovalResult.APPROVED:
patterns[op_type] = patterns.get(op_type, 0) + 1
return patterns
def reset_preferences(self):
"""Reset all user preferences."""
self.user_preferences.clear()
self._save_preferences()
def is_code_safe(self, code: str) -> bool:
"""Quick check if code is considered safe (no approval needed)."""
risk_analysis = self._analyze_code_risk(code, {})
return risk_analysis.risk_level == RiskLevel.LOW and len(risk_analysis.reasons) == 0
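# --- Usage sketch (illustrative, not part of the original file) ---
# Runs the quick safety check first and falls back to the interactive
# approval flow only for code that is not trivially safe.
def _example_gate_execution(code: str) -> bool:
    approvals = ApprovalSystem()
    if approvals.is_code_safe(code):
        return True
    result, _decision = approvals.request_approval(code, {"user_level": "new"})
    return result in (ApprovalResult.ALLOWED, ApprovalResult.APPROVED)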

View File

@@ -1,442 +0,0 @@
"""
Audit Logging for Mai Sandbox System
Provides immutable, append-only logging with sensitive data masking
and tamper detection for sandbox execution audit trails.
"""
import gzip
import hashlib
import json
import os
import re
import time
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
@dataclass
class AuditEntry:
"""Single audit log entry"""
timestamp: str
execution_id: str
code_hash: str
risk_score: int
patterns_detected: list[str]
execution_result: dict[str, Any]
resource_usage: dict[str, Any] | None = None
masked_data: dict[str, str] | None = None
integrity_hash: str | None = None
@dataclass
class LogIntegrity:
"""Log integrity verification result"""
is_valid: bool
tampered_entries: list[int]
hash_chain_valid: bool
last_verified: str
class AuditLogger:
"""
Provides immutable audit logging with sensitive data masking
and tamper detection for sandbox execution tracking.
"""
# Patterns for sensitive data masking
SENSITIVE_PATTERNS = [
(r"\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b", "[EMAIL_REDACTED]"),
(r"\\b(?:\\d{1,3}\\.){3}\\d{1,3}\\b", "[IP_REDACTED]"),
(r"password[\\s]*[:=][\\s]*[^\\s]+", "password=[PASSWORD_REDACTED]"),
(r"api[_-]?key[\\s]*[:=][\\s]*[^\\s]+", "api_key=[API_KEY_REDACTED]"),
(r"token[\\s]*[:=][\\s]*[^\\s]+", "token=[TOKEN_REDACTED]"),
(r"secret[\\s]*[:=][\\s]*[^\\s]+", "secret=[SECRET_REDACTED]"),
(r"bearers?\\s+[^\\s]+", "Bearer [TOKEN_REDACTED]"),
(r"\\b(?:\\d{4}[-\\s]?){3}\\d{4}\\b", "[CREDIT_CARD_REDACTED]"), # Basic CC pattern
(r"\\b\\d{3}-?\\d{2}-?\\d{4}\\b", "[SSN_REDACTED]"),
]
def __init__(self, log_dir: str | None = None, max_file_size_mb: int = 100):
"""
Initialize audit logger
Args:
log_dir: Directory for log files (default: .mai/logs)
max_file_size_mb: Maximum file size before rotation
"""
self.log_dir = Path(log_dir or ".mai/logs")
self.max_file_size = max_file_size_mb * 1024 * 1024 # Convert to bytes
self.current_log_file = None
self.previous_hash = None
# Ensure log directory exists with secure permissions
self.log_dir.mkdir(parents=True, exist_ok=True)
os.chmod(self.log_dir, 0o700) # Only owner can access
# Initialize log file
self._initialize_log_file()
def _initialize_log_file(self):
"""Initialize or find current log file"""
timestamp = datetime.now().strftime("%Y%m%d")
self.current_log_file = self.log_dir / f"sandbox_audit_{timestamp}.jsonl"
        # Create the log file if it doesn't exist
if not self.current_log_file.exists():
self.current_log_file.touch()
os.chmod(self.current_log_file, 0o600) # Read/write for owner only
# Load previous hash for integrity chain
self.previous_hash = self._get_last_hash()
def log_execution(
self,
code: str,
execution_result: dict[str, Any],
risk_assessment: dict[str, Any] | None = None,
resource_usage: dict[str, Any] | None = None,
) -> str:
"""
Log code execution with full audit trail
Args:
code: Executed code string
execution_result: Result of execution
risk_assessment: Risk analysis results
resource_usage: Resource usage during execution
Returns:
Execution ID for this log entry
"""
# Generate execution ID and timestamp
execution_id = hashlib.sha256(f"{time.time()}{code[:100]}".encode()).hexdigest()[:16]
timestamp = datetime.now().isoformat()
# Calculate code hash
code_hash = hashlib.sha256(code.encode()).hexdigest()
# Extract risk information
risk_score = 0
patterns_detected = []
if risk_assessment:
risk_score = risk_assessment.get("score", 0)
patterns_detected = [p.get("pattern", "") for p in risk_assessment.get("patterns", [])]
# Mask sensitive data in code
masked_code, masked_info = self.mask_sensitive_data(code)
# Create audit entry
entry = AuditEntry(
timestamp=timestamp,
execution_id=execution_id,
code_hash=code_hash,
risk_score=risk_score,
patterns_detected=patterns_detected,
execution_result=execution_result,
resource_usage=resource_usage,
masked_data=masked_info,
integrity_hash=None, # Will be calculated
)
# Calculate integrity hash with previous hash
entry.integrity_hash = self._calculate_chain_hash(entry)
# Write to log file
self._write_entry(entry)
# Check if rotation needed
if self.current_log_file.stat().st_size > self.max_file_size:
self._rotate_logs()
return execution_id
def mask_sensitive_data(self, text: str) -> tuple[str, dict[str, str]]:
"""
Mask sensitive data patterns in text
Args:
text: Text to mask
Returns:
Tuple of (masked_text, masking_info)
"""
masked_text = text
masking_info = {}
for pattern, replacement in self.SENSITIVE_PATTERNS:
matches = re.findall(pattern, masked_text, re.IGNORECASE)
if matches:
masking_info[pattern] = f"Replaced {len(matches)} instances"
masked_text = re.sub(pattern, replacement, masked_text, flags=re.IGNORECASE)
return masked_text, masking_info
def rotate_logs(self):
"""Rotate current log file with compression"""
if not self.current_log_file.exists():
return
# Compress old log
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
compressed_file = self.log_dir / f"sandbox_audit_{timestamp}.jsonl.gz"
with open(self.current_log_file, "rb") as f_in:
with gzip.open(compressed_file, "wb") as f_out:
f_out.writelines(f_in)
# Remove original
self.current_log_file.unlink()
# Set secure permissions on compressed file
os.chmod(compressed_file, 0o600)
# Reinitialize new log file
self._initialize_log_file()
def verify_integrity(self) -> LogIntegrity:
"""
Verify log file integrity using hash chain
Returns:
LogIntegrity verification result
"""
if not self.current_log_file.exists():
return LogIntegrity(
is_valid=False,
tampered_entries=[],
hash_chain_valid=False,
last_verified=datetime.now().isoformat(),
)
try:
with open(self.current_log_file) as f:
lines = f.readlines()
tampered_entries = []
previous_hash = None
for i, line in enumerate(lines):
try:
entry_data = json.loads(line.strip())
expected_hash = entry_data.get("integrity_hash")
                    # Recalculate the hash the same way it was produced at
                    # write time: drop the integrity field before hashing
                    entry_data.pop("integrity_hash", None)
actual_hash = hashlib.sha256(
json.dumps(entry_data, sort_keys=True).encode()
).hexdigest()
if previous_hash:
# Include previous hash in calculation
combined = f"{previous_hash}{actual_hash}"
actual_hash = hashlib.sha256(combined.encode()).hexdigest()
if expected_hash != actual_hash:
tampered_entries.append(i)
previous_hash = expected_hash
except (json.JSONDecodeError, KeyError):
tampered_entries.append(i)
return LogIntegrity(
is_valid=len(tampered_entries) == 0,
tampered_entries=tampered_entries,
hash_chain_valid=len(tampered_entries) == 0,
last_verified=datetime.now().isoformat(),
)
except Exception:
return LogIntegrity(
is_valid=False,
tampered_entries=[],
hash_chain_valid=False,
last_verified=datetime.now().isoformat(),
)
def query_logs(
self, limit: int = 100, risk_min: int = 0, after: str | None = None
) -> list[dict[str, Any]]:
"""
Query audit logs with filters
Args:
limit: Maximum number of entries to return
risk_min: Minimum risk score to include
after: ISO timestamp to filter after
Returns:
List of matching log entries
"""
if not self.current_log_file.exists():
return []
entries = []
try:
with open(self.current_log_file) as f:
for line in f:
if not line.strip():
continue
try:
entry = json.loads(line.strip())
# Apply filters
if entry.get("risk_score", 0) < risk_min:
continue
if after and entry.get("timestamp", "") <= after:
continue
entries.append(entry)
if len(entries) >= limit:
break
except json.JSONDecodeError:
continue
except Exception:
return []
# Return in reverse chronological order
return list(reversed(entries[-limit:]))
def get_execution_by_id(self, execution_id: str) -> dict[str, Any] | None:
"""
Retrieve specific execution by ID
Args:
execution_id: Unique execution identifier
Returns:
Log entry or None if not found
"""
entries = self.query_logs(limit=1000) # Get more for search
for entry in entries:
if entry.get("execution_id") == execution_id:
return entry
return None
def _write_entry(self, entry: AuditEntry):
"""Write entry to log file"""
try:
with open(self.current_log_file, "a") as f:
# Convert to dict and remove None values
entry_dict = {k: v for k, v in asdict(entry).items() if v is not None}
                f.write(json.dumps(entry_dict) + "\n")
f.flush() # Ensure immediate write
# Update previous hash
self.previous_hash = entry.integrity_hash
except Exception as e:
raise RuntimeError(f"Failed to write audit entry: {e}") from e
def _calculate_chain_hash(self, entry: AuditEntry) -> str:
"""Calculate integrity hash for entry with previous hash"""
        # Hash the same representation that _write_entry persists: strip None
        # values and exclude the integrity field itself from the calculation
        entry_dict = {k: v for k, v in asdict(entry).items() if v is not None}
        entry_dict.pop("integrity_hash", None)
# Create hash of entry data
entry_hash = hashlib.sha256(json.dumps(entry_dict, sort_keys=True).encode()).hexdigest()
# Chain with previous hash if exists
if self.previous_hash:
combined = f"{self.previous_hash}{entry_hash}"
return hashlib.sha256(combined.encode()).hexdigest()
return entry_hash
def _get_last_hash(self) -> str | None:
"""Get hash from last entry in log file"""
if not self.current_log_file.exists():
return None
try:
with open(self.current_log_file) as f:
lines = f.readlines()
if not lines:
return None
last_line = lines[-1].strip()
if not last_line:
return None
entry = json.loads(last_line)
return entry.get("integrity_hash")
except (json.JSONDecodeError, FileNotFoundError):
return None
def _rotate_logs(self):
"""Perform log rotation"""
try:
self.rotate_logs()
except Exception as e:
print(f"Log rotation failed: {e}")
def get_log_stats(self) -> dict[str, Any]:
"""
Get statistics about audit logs
Returns:
Dictionary with log statistics
"""
if not self.current_log_file.exists():
return {
"total_entries": 0,
"file_size_bytes": 0,
"high_risk_executions": 0,
"last_execution": None,
}
try:
with open(self.current_log_file) as f:
lines = f.readlines()
entries = []
high_risk_count = 0
for line in lines:
if not line.strip():
continue
try:
entry = json.loads(line.strip())
entries.append(entry)
if entry.get("risk_score", 0) >= 70:
high_risk_count += 1
except json.JSONDecodeError:
continue
file_size = self.current_log_file.stat().st_size
last_execution = entries[-1].get("timestamp") if entries else None
return {
"total_entries": len(entries),
"file_size_bytes": file_size,
"file_size_mb": file_size / (1024 * 1024),
"high_risk_executions": high_risk_count,
"last_execution": last_execution,
"log_file": str(self.current_log_file),
}
except Exception:
return {
"total_entries": 0,
"file_size_bytes": 0,
"high_risk_executions": 0,
"last_execution": None,
}
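# --- Usage sketch (illustrative, not part of the original file) ---
# Logs one execution, verifies the hash chain, and reports summary stats.
# Log files default to .mai/logs with owner-only permissions.
def _example_audit_round_trip() -> bool:
    audit = AuditLogger()
    audit.log_execution(
        code="print('hello')",
        execution_result={"stdout": "hello", "exit_code": 0},
        risk_assessment={"score": 5, "patterns": []},
    )
    integrity = audit.verify_integrity()
    print(audit.get_log_stats()["total_entries"], "entries, chain valid:", integrity.is_valid)
    return integrity.is_valid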

View File

@@ -1,432 +0,0 @@
"""
Docker Executor for Mai Safe Code Execution
Provides isolated container execution using Docker with comprehensive
resource limits, security restrictions, and audit logging integration.
"""
import logging
import tempfile
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any
try:
import docker
from docker.errors import APIError, ContainerError, DockerException, ImageNotFound
from docker.models.containers import Container
DOCKER_AVAILABLE = True
except ImportError:
docker = None
Container = None
DockerException = Exception
APIError = Exception
ContainerError = Exception
ImageNotFound = Exception
DOCKER_AVAILABLE = False
from .audit_logger import AuditLogger
@dataclass
class ContainerConfig:
"""Configuration for Docker container execution"""
image: str = "python:3.10-slim"
timeout_seconds: int = 30
memory_limit: str = "128m" # Docker memory limit format
cpu_limit: str = "0.5" # CPU quota (0.5 = 50% of one CPU)
network_disabled: bool = True
read_only_filesystem: bool = True
tmpfs_size: str = "64m" # Temporary filesystem size
working_dir: str = "/app"
user: str = "nobody" # Non-root user
@dataclass
class ContainerResult:
"""Result of container execution"""
success: bool
container_id: str
exit_code: int
stdout: str | None = None
stderr: str | None = None
execution_time: float = 0.0
error: str | None = None
resource_usage: dict[str, Any] | None = None
class DockerExecutor:
"""
Docker-based container executor for isolated code execution.
Provides secure sandboxing using Docker containers with resource limits,
network restrictions, and comprehensive audit logging.
"""
def __init__(self, audit_logger: AuditLogger | None = None):
"""
Initialize Docker executor
Args:
audit_logger: Optional audit logger for execution logging
"""
self.audit_logger = audit_logger
self.client = None
self.available = False
# Try to initialize Docker client
self._initialize_docker()
# Setup logging
self.logger = logging.getLogger(__name__)
def _initialize_docker(self) -> None:
"""Initialize Docker client and verify availability"""
if not DOCKER_AVAILABLE:
self.available = False
return
try:
if docker is not None:
self.client = docker.from_env()
# Test Docker connection
self.client.ping()
self.available = True
else:
self.available = False
self.client = None
except Exception as e:
self.logger.warning(f"Docker not available: {e}")
self.available = False
self.client = None
def is_available(self) -> bool:
"""Check if Docker executor is available"""
return self.available and self.client is not None
def execute_code(
self,
code: str,
config: ContainerConfig | None = None,
environment: dict[str, str] | None = None,
files: dict[str, str] | None = None,
) -> ContainerResult:
"""
Execute code in isolated Docker container
Args:
code: Python code to execute
config: Container configuration
environment: Environment variables
files: Additional files to mount in container
Returns:
ContainerResult with execution details
"""
if not self.is_available() or self.client is None:
return ContainerResult(
success=False, container_id="", exit_code=-1, error="Docker executor not available"
)
config = config or ContainerConfig()
container_id = str(uuid.uuid4())[:8]
start_time = time.time()
try:
# Create temporary directory for files
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Write code to file
code_file = temp_path / "execute.py"
code_file.write_text(code)
# Prepare volume mounts
volumes = {
str(code_file): {
"bind": f"{config.working_dir}/execute.py",
"mode": "ro", # read-only
}
}
# Add additional files if provided
if files:
for filename, content in files.items():
file_path = temp_path / filename
file_path.write_text(content)
volumes[str(file_path)] = {
"bind": f"{config.working_dir}/{filename}",
"mode": "ro",
}
# Prepare container configuration
container_config = self._build_container_config(config, environment)
# Create and start container
container = self.client.containers.run(
image=config.image,
command=["python", "execute.py"],
volumes=volumes,
**container_config,
detach=True,
)
# Get container ID safely
container_id = getattr(container, "id", container_id)
try:
# Wait for completion with timeout
result = container.wait(timeout=config.timeout_seconds)
exit_code = result["StatusCode"]
# Get logs
stdout = container.logs(stdout=True, stderr=False).decode("utf-8")
stderr = container.logs(stdout=False, stderr=True).decode("utf-8")
# Get resource usage stats
stats = self._get_container_stats(container)
# Determine success
success = exit_code == 0 and not stderr
execution_result = ContainerResult(
success=success,
container_id=container_id,
exit_code=exit_code,
stdout=stdout,
stderr=stderr,
execution_time=time.time() - start_time,
resource_usage=stats,
)
# Log execution if audit logger available
if self.audit_logger:
self._log_container_execution(code, execution_result, config)
return execution_result
finally:
# Always cleanup container
try:
container.remove(force=True)
except Exception:
pass # Best effort cleanup
except ContainerError as e:
return ContainerResult(
success=False,
container_id=container_id or "unknown",
exit_code=getattr(e, "exit_code", -1),
stderr=str(e),
execution_time=time.time() - start_time,
error=f"Container execution error: {e}",
)
except ImageNotFound as e:
return ContainerResult(
success=False,
container_id=container_id,
exit_code=-1,
error=f"Docker image not found: {e}",
)
except APIError as e:
return ContainerResult(
success=False,
container_id=container_id,
exit_code=-1,
error=f"Docker API error: {e}",
)
except Exception as e:
return ContainerResult(
success=False,
container_id=container_id,
exit_code=-1,
execution_time=time.time() - start_time,
error=f"Unexpected error: {e}",
)
def _build_container_config(
self, config: ContainerConfig, environment: dict[str, str] | None = None
) -> dict[str, Any]:
"""Build Docker container configuration"""
container_config = {
"mem_limit": config.memory_limit,
"cpu_quota": int(float(config.cpu_limit) * 100000), # Convert to microseconds
"cpu_period": 100000, # 100ms period
"network_disabled": config.network_disabled,
"read_only": config.read_only_filesystem,
"tmpfs": {"/tmp": f"size={config.tmpfs_size},noexec,nosuid,nodev"},
"user": config.user,
"working_dir": config.working_dir,
"remove": True, # Auto-remove container
}
# Add environment variables
if environment:
container_config["environment"] = {
**environment,
"PYTHONPATH": config.working_dir,
"PYTHONDONTWRITEBYTECODE": "1",
}
else:
container_config["environment"] = {
"PYTHONPATH": config.working_dir,
"PYTHONDONTWRITEBYTECODE": "1",
}
# Security options
container_config["security_opt"] = [
"no-new-privileges:true",
"seccomp:unconfined", # Python needs some syscalls
]
# Capabilities (drop all capabilities)
container_config["cap_drop"] = ["ALL"]
container_config["cap_add"] = ["CHOWN", "DAC_OVERRIDE"] # Minimal capabilities for Python
return container_config
def _get_container_stats(self, container) -> dict[str, Any]:
"""Get resource usage statistics from container"""
try:
stats = container.stats(stream=False)
# Calculate CPU usage
cpu_stats = stats.get("cpu_stats", {})
precpu_stats = stats.get("precpu_stats", {})
cpu_usage = cpu_stats.get("cpu_usage", {}).get("total_usage", 0)
precpu_usage = precpu_stats.get("cpu_usage", {}).get("total_usage", 0)
system_usage = cpu_stats.get("system_cpu_usage", 0)
presystem_usage = precpu_stats.get("system_cpu_usage", 0)
cpu_count = cpu_stats.get("online_cpus", 1)
cpu_percent = 0.0
if system_usage > presystem_usage:
cpu_delta = cpu_usage - precpu_usage
system_delta = system_usage - presystem_usage
cpu_percent = (cpu_delta / system_delta) * cpu_count * 100.0
# Calculate memory usage
memory_stats = stats.get("memory_stats", {})
memory_usage = memory_stats.get("usage", 0)
memory_limit = memory_stats.get("limit", 1)
memory_percent = (memory_usage / memory_limit) * 100.0
return {
"cpu_percent": round(cpu_percent, 2),
"memory_usage_bytes": memory_usage,
"memory_limit_bytes": memory_limit,
"memory_percent": round(memory_percent, 2),
"memory_usage_mb": round(memory_usage / (1024 * 1024), 2),
}
except Exception:
return {
"cpu_percent": 0.0,
"memory_usage_bytes": 0,
"memory_limit_bytes": 0,
"memory_percent": 0.0,
"memory_usage_mb": 0.0,
}
def _log_container_execution(
self, code: str, result: ContainerResult, config: ContainerConfig
) -> None:
"""Log container execution to audit logger"""
if not self.audit_logger:
return
execution_data = {
"type": "docker_container",
"container_id": result.container_id,
"exit_code": result.exit_code,
"stdout": result.stdout,
"stderr": result.stderr,
"execution_time": result.execution_time,
"config": {
"image": config.image,
"timeout": config.timeout_seconds,
"memory_limit": config.memory_limit,
"cpu_limit": config.cpu_limit,
"network_disabled": config.network_disabled,
"read_only_filesystem": config.read_only_filesystem,
},
"resource_usage": result.resource_usage,
}
# Note: execution_type parameter not available in current AuditLogger
self.audit_logger.log_execution(code=code, execution_result=execution_data)
def get_available_images(self) -> list[str]:
"""Get list of available Docker images"""
if not self.is_available() or self.client is None:
return []
try:
images = self.client.images.list()
return [img.tags[0] for img in images if img.tags]
except Exception:
return []
def pull_image(self, image_name: str) -> bool:
"""Pull Docker image"""
if not self.is_available() or self.client is None:
return False
try:
self.client.images.pull(image_name)
return True
except Exception:
return False
def cleanup_containers(self) -> int:
"""Clean up any dangling containers"""
if not self.is_available() or self.client is None:
return 0
try:
containers = self.client.containers.list(all=True, filters={"status": "exited"})
count = 0
for container in containers:
try:
container.remove(force=True)
count += 1
except Exception:
pass
return count
except Exception:
return 0
def get_system_info(self) -> dict[str, Any]:
"""Get Docker system information"""
if not self.is_available() or self.client is None:
return {"available": False}
try:
info = self.client.info()
version = self.client.version()
return {
"available": True,
"version": version.get("Version", "unknown"),
"api_version": version.get("ApiVersion", "unknown"),
"containers": info.get("Containers", 0),
"containers_running": info.get("ContainersRunning", 0),
"containers_paused": info.get("ContainersPaused", 0),
"containers_stopped": info.get("ContainersStopped", 0),
"images": info.get("Images", 0),
"memory_total": info.get("MemTotal", 0),
"ncpu": info.get("NCPU", 0),
}
except Exception:
return {"available": False, "error": "Failed to get system info"}

View File

@@ -1,439 +0,0 @@
"""
Sandbox Manager for Mai Safe Code Execution
Central orchestrator for sandbox execution, integrating risk analysis,
resource enforcement, and audit logging for safe code execution.
"""
import time
import uuid
from dataclasses import dataclass
from typing import Any
from .audit_logger import AuditLogger
from .docker_executor import ContainerConfig, ContainerResult, DockerExecutor
from .resource_enforcer import ResourceEnforcer, ResourceLimits, ResourceUsage
from .risk_analyzer import RiskAnalyzer, RiskAssessment
@dataclass
class ExecutionRequest:
"""Request for sandbox code execution"""
code: str
environment: dict[str, str] | None = None
timeout_seconds: int = 30
cpu_limit_percent: float = 70.0
memory_limit_percent: float = 70.0
network_allowed: bool = False
filesystem_restricted: bool = True
use_docker: bool = True
docker_image: str = "python:3.10-slim"
additional_files: dict[str, str] | None = None
@dataclass
class ExecutionResult:
"""Result of sandbox execution"""
success: bool
execution_id: str
output: str | None = None
error: str | None = None
risk_assessment: RiskAssessment | None = None
resource_usage: ResourceUsage | None = None
execution_time: float = 0.0
audit_entry_id: str | None = None
execution_method: str = "local" # "local", "docker", "fallback"
container_result: ContainerResult | None = None
class SandboxManager:
"""
Central sandbox orchestrator that coordinates risk analysis,
resource enforcement, and audit logging for safe code execution.
"""
def __init__(self, log_dir: str | None = None):
"""
Initialize sandbox manager
Args:
log_dir: Directory for audit logs
"""
self.risk_analyzer = RiskAnalyzer()
self.resource_enforcer = ResourceEnforcer()
self.audit_logger = AuditLogger(log_dir=log_dir)
self.docker_executor = DockerExecutor(audit_logger=self.audit_logger)
# Execution state
self.active_executions: dict[str, dict[str, Any]] = {}
def execute_code(self, request: ExecutionRequest) -> ExecutionResult:
"""
Execute code in sandbox with full safety checks
Args:
request: ExecutionRequest with code and constraints
Returns:
ExecutionResult with execution details
"""
execution_id = str(uuid.uuid4())[:8]
start_time = time.time()
try:
# Step 1: Risk analysis
risk_assessment = self.risk_analyzer.analyze_ast(request.code)
# Step 2: Check if execution is allowed
if not self._is_execution_allowed(risk_assessment):
result = ExecutionResult(
success=False,
execution_id=execution_id,
error=(
f"Code execution blocked: Risk score {risk_assessment.score} "
"exceeds safe threshold"
),
risk_assessment=risk_assessment,
execution_time=time.time() - start_time,
)
# Log blocked execution
self._log_execution(request, result, risk_assessment)
return result
# Step 3: Set resource limits
resource_limits = ResourceLimits(
cpu_percent=request.cpu_limit_percent,
memory_percent=request.memory_limit_percent,
timeout_seconds=request.timeout_seconds,
)
self.resource_enforcer.set_limits(resource_limits)
self.resource_enforcer.start_monitoring()
# Step 4: Choose execution method and execute code
execution_method = (
"docker" if request.use_docker and self.docker_executor.is_available() else "local"
)
if execution_method == "docker":
execution_result = self._execute_in_docker(request, execution_id)
else:
execution_result = self._execute_in_sandbox(request, execution_id)
execution_method = "local"
# Step 5: Get resource usage (for local execution)
if execution_method == "local":
resource_usage = self.resource_enforcer.stop_monitoring()
else:
resource_usage = None # Docker provides its own resource usage
# Step 6: Create result
result = ExecutionResult(
success=execution_result.get("success", False),
execution_id=execution_id,
output=execution_result.get("output"),
error=execution_result.get("error"),
risk_assessment=risk_assessment,
resource_usage=resource_usage,
execution_time=time.time() - start_time,
execution_method=execution_method,
container_result=execution_result.get("container_result"),
)
# Step 7: Log execution
audit_id = self._log_execution(request, result, risk_assessment, resource_usage)
result.audit_entry_id = audit_id
return result
except Exception as e:
# Handle unexpected errors
result = ExecutionResult(
success=False,
execution_id=execution_id,
error=f"Sandbox execution error: {str(e)}",
execution_time=time.time() - start_time,
)
# Log error
self._log_execution(request, result)
return result
finally:
# Cleanup
self.resource_enforcer.stop_monitoring()
def check_risk(self, code: str) -> RiskAssessment:
"""
Perform risk analysis on code
Args:
code: Code to analyze
Returns:
RiskAssessment with detailed analysis
"""
return self.risk_analyzer.analyze_ast(code)
def enforce_limits(self, limits: ResourceLimits) -> bool:
"""
Set resource limits for execution
Args:
limits: Resource limits to enforce
Returns:
True if limits were set successfully
"""
return self.resource_enforcer.set_limits(limits)
def log_execution(
self,
code: str,
execution_result: dict[str, Any],
risk_assessment: dict[str, Any] | None = None,
resource_usage: dict[str, Any] | None = None,
) -> str:
"""
Log execution details to audit trail
Args:
code: Executed code
execution_result: Result of execution
risk_assessment: Risk analysis results
resource_usage: Resource usage statistics
Returns:
Audit entry ID
"""
return self.audit_logger.log_execution(
code=code,
execution_result=execution_result,
risk_assessment=risk_assessment,
resource_usage=resource_usage,
)
def get_execution_history(
self, limit: int = 50, min_risk_score: int = 0
) -> list[dict[str, Any]]:
"""
Get execution history from audit logs
Args:
limit: Maximum entries to return
min_risk_score: Minimum risk score filter
Returns:
List of execution entries
"""
return self.audit_logger.query_logs(limit=limit, risk_min=min_risk_score)
def verify_log_integrity(self) -> bool:
"""
Verify audit log integrity
Returns:
True if logs are intact
"""
integrity = self.audit_logger.verify_integrity()
return integrity.is_valid
def get_system_status(self) -> dict[str, Any]:
"""
Get current sandbox system status
Returns:
Dictionary with system status
"""
return {
"active_executions": len(self.active_executions),
"resource_monitoring": self.resource_enforcer.monitoring_active,
"current_usage": self.resource_enforcer.monitor_usage(),
"log_stats": self.audit_logger.get_log_stats(),
"log_integrity": self.verify_log_integrity(),
"docker_available": self.docker_executor.is_available(),
"docker_info": self.docker_executor.get_system_info(),
}
def get_docker_status(self) -> dict[str, Any]:
"""
Get Docker executor status and available images
Returns:
Dictionary with Docker status
"""
return {
"available": self.docker_executor.is_available(),
"images": self.docker_executor.get_available_images(),
"system_info": self.docker_executor.get_system_info(),
}
def pull_docker_image(self, image_name: str) -> bool:
"""
Pull a Docker image for execution
Args:
image_name: Name of the Docker image to pull
Returns:
True if image was pulled successfully
"""
return self.docker_executor.pull_image(image_name)
def cleanup_docker_containers(self) -> int:
"""
Clean up any dangling Docker containers
Returns:
Number of containers cleaned up
"""
return self.docker_executor.cleanup_containers()
def _is_execution_allowed(self, risk_assessment: RiskAssessment) -> bool:
"""
Determine if execution is allowed based on risk assessment
Args:
risk_assessment: Risk analysis result
Returns:
True if execution is allowed
"""
# Block if any BLOCKED patterns detected
blocked_patterns = [p for p in risk_assessment.patterns if p.severity == "BLOCKED"]
if blocked_patterns:
return False
# Require approval for HIGH risk
if risk_assessment.score >= 70:
return False # Would require user approval in full implementation
return True
def _execute_in_docker(self, request: ExecutionRequest, execution_id: str) -> dict[str, Any]:
"""
Execute code in Docker container
Args:
request: Execution request
execution_id: Unique execution identifier
Returns:
Dictionary with execution result
"""
# Create container configuration based on request
config = ContainerConfig(
image=request.docker_image,
timeout_seconds=request.timeout_seconds,
memory_limit=f"{int(request.memory_limit_percent * 128 / 100)}m", # Scale to container
cpu_limit=str(request.cpu_limit_percent / 100),
network_disabled=not request.network_allowed,
read_only_filesystem=request.filesystem_restricted,
)
# Execute in Docker container
container_result = self.docker_executor.execute_code(
code=request.code,
config=config,
environment=request.environment,
files=request.additional_files,
)
return {
"success": container_result.success,
"output": container_result.stdout,
"error": container_result.stderr or container_result.error,
"container_result": container_result,
}
def _execute_in_sandbox(self, request: ExecutionRequest, execution_id: str) -> dict[str, Any]:
"""
Execute code in local sandbox environment (fallback)
Args:
request: Execution request
execution_id: Unique execution identifier
Returns:
Dictionary with execution result
"""
        try:
            # Demo-only fallback (NOT PRODUCTION SAFE); a real local sandbox would replace this.
            if request.code.strip().startswith("print"):
                # Simple print statement: eval() returns None for print(), so capture
                # stdout instead of stringifying the return value.
                import contextlib
                import io
                captured = io.StringIO()
                with contextlib.redirect_stdout(captured):
                    eval(request.code)
                return {"success": True, "output": captured.getvalue()}
            else:
                # For safety, don't execute arbitrary code in this demo
                return {"success": False, "error": "Code execution not implemented in demo mode"}
except Exception as e:
return {"success": False, "error": f"Execution error: {str(e)}"}
def _log_execution(
self,
request: ExecutionRequest,
result: ExecutionResult,
risk_assessment: RiskAssessment | None = None,
resource_usage: ResourceUsage | None = None,
) -> str:
"""
Internal method to log execution
Args:
request: Execution request
result: Execution result
risk_assessment: Risk analysis
resource_usage: Resource usage
Returns:
Audit entry ID
"""
# Prepare execution result for logging
execution_result = {
"success": result.success,
"output": result.output,
"error": result.error,
"execution_time": result.execution_time,
}
# Prepare risk assessment for logging
risk_data = None
if risk_assessment:
risk_data = {
"score": risk_assessment.score,
"patterns": [
{
"pattern": p.pattern,
"severity": p.severity,
"score": p.score,
"line_number": p.line_number,
"description": p.description,
}
for p in risk_assessment.patterns
],
"safe_to_execute": risk_assessment.safe_to_execute,
"approval_required": risk_assessment.approval_required,
}
# Prepare resource usage for logging
usage_data = None
if resource_usage:
usage_data = {
"cpu_percent": resource_usage.cpu_percent,
"memory_percent": resource_usage.memory_percent,
"memory_used_gb": resource_usage.memory_used_gb,
"elapsed_seconds": resource_usage.elapsed_seconds,
"approaching_limits": resource_usage.approaching_limits,
}
return self.audit_logger.log_execution(
code=request.code,
execution_result=execution_result,
risk_assessment=risk_data,
resource_usage=usage_data,
)
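# Illustrative usage sketch (not part of the original module): run a low-risk snippet
# through the full pipeline; it falls back to the local demo executor when Docker is
# unavailable.
if __name__ == "__main__":
    manager = SandboxManager()
    demo_request = ExecutionRequest(code="print('2 + 2 =', 2 + 2)", timeout_seconds=10)
    outcome = manager.execute_code(demo_request)
    print(outcome.execution_method, outcome.success, outcome.output or outcome.error)
    if outcome.risk_assessment:
        print("risk score:", outcome.risk_assessment.score)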

View File

@@ -1,337 +0,0 @@
"""
Resource Enforcement for Mai Sandbox System
Provides percentage-based resource limit enforcement
building on existing Phase 1 monitoring infrastructure.
"""
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@dataclass
class ResourceLimits:
"""Resource limit configuration"""
cpu_percent: float
memory_percent: float
timeout_seconds: int
network_bandwidth_mbps: float | None = None
@dataclass
class ResourceUsage:
"""Current resource usage statistics"""
cpu_percent: float
memory_percent: float
memory_used_gb: float
elapsed_seconds: float
approaching_limits: dict[str, bool]
class ResourceEnforcer:
"""
Enforces resource limits for sandbox execution.
Builds on Phase 1 ResourceDetector for percentage-based limits.
"""
def __init__(self):
"""Initialize resource enforcer"""
# Try to import existing resource monitoring from Phase 1
try:
sys.path.append(str(Path(__file__).parent.parent / "model"))
from resource_detector import ResourceDetector
self.resource_detector = ResourceDetector()
except ImportError:
# Fallback implementation
self.resource_detector = None
self.current_limits: ResourceLimits | None = None
self.start_time: float | None = None
self.timeout_timer: threading.Timer | None = None
self.monitoring_active: bool = False
def set_cpu_limit(self, percent: float) -> float:
"""
Calculate CPU limit as percentage of available resources
Args:
percent: Desired CPU limit (0-100)
Returns:
Actual CPU limit percentage
"""
if not 0 <= percent <= 100:
raise ValueError("CPU percent must be between 0 and 100")
# Calculate effective limit
cpu_limit = min(percent, 100.0)
return cpu_limit
def set_memory_limit(self, percent: float) -> float:
"""
Calculate memory limit as percentage of available resources
Args:
percent: Desired memory limit (0-100)
Returns:
Actual memory limit percentage
"""
if not 0 <= percent <= 100:
raise ValueError("Memory percent must be between 0 and 100")
# Calculate effective limit
if self.resource_detector:
try:
resource_info = self.resource_detector.get_current_usage()
memory_limit = min(
percent,
resource_info.memory_percent
+ (resource_info.memory_available_gb / resource_info.memory_total_gb * 100),
)
return memory_limit
except Exception:
pass
# Fallback
memory_limit = min(percent, 100.0)
return memory_limit
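    # Illustrative example (not part of the original file): with percent=70 on a host
    # reporting memory_percent=40 and 9.6 of 16 GB available, the detector-based cap is
    # min(70, 40 + 9.6 / 16 * 100) = min(70, 100) = 70, so the requested limit stands.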
def set_limits(self, limits: ResourceLimits) -> bool:
"""
Set comprehensive resource limits
Args:
limits: ResourceLimits configuration
Returns:
True if limits were successfully set
"""
try:
self.current_limits = limits
return True
except Exception as e:
print(f"Failed to set limits: {e}")
return False
def enforce_timeout(self, seconds: int) -> bool:
"""
        Enforce execution timeout using a background timer thread
Args:
seconds: Timeout in seconds
Returns:
True if timeout was set successfully
"""
try:
if self.timeout_timer:
self.timeout_timer.cancel()
# Create timeout handler
def timeout_handler():
raise TimeoutError(f"Execution exceeded {seconds} second timeout")
            # Set timer (cross-platform alternative to signal.alarm). Note that the
            # TimeoutError is raised in the timer thread, not the executing thread,
            # so this acts as a watchdog rather than a hard interrupt.
self.timeout_timer = threading.Timer(seconds, timeout_handler)
self.timeout_timer.daemon = True
self.timeout_timer.start()
return True
except Exception as e:
print(f"Failed to set timeout: {e}")
return False
def start_monitoring(self) -> bool:
"""
Start resource monitoring for an execution session
Returns:
True if monitoring started successfully
"""
try:
self.start_time = time.time()
self.monitoring_active = True
return True
except Exception as e:
print(f"Failed to start monitoring: {e}")
return False
def stop_monitoring(self) -> ResourceUsage:
"""
Stop monitoring and return usage statistics
Returns:
ResourceUsage with execution statistics
"""
if not self.monitoring_active:
raise RuntimeError("Monitoring not active")
# Stop timeout timer
if self.timeout_timer:
self.timeout_timer.cancel()
self.timeout_timer = None
# Calculate usage
end_time = time.time()
elapsed = end_time - (self.start_time or 0)
# Get current resource info
cpu_percent = 0.0
memory_percent = 0.0
memory_used_gb = 0.0
memory_total_gb = 0.0
if self.resource_detector:
try:
current_info = self.resource_detector.get_current_usage()
cpu_percent = current_info.cpu_percent
memory_percent = current_info.memory_percent
memory_used_gb = current_info.memory_total_gb - current_info.memory_available_gb
except Exception:
pass # Use fallback values
# Check approaching limits
approaching = {}
if self.current_limits:
approaching["cpu"] = cpu_percent > self.current_limits.cpu_percent * 0.8
approaching["memory"] = memory_percent > self.current_limits.memory_percent * 0.8
approaching["timeout"] = elapsed > self.current_limits.timeout_seconds * 0.8
usage = ResourceUsage(
cpu_percent=cpu_percent,
memory_percent=memory_percent,
memory_used_gb=memory_used_gb,
elapsed_seconds=elapsed,
approaching_limits=approaching,
)
self.monitoring_active = False
return usage
def monitor_usage(self) -> dict[str, Any]:
"""
Get current resource usage statistics
Returns:
Dictionary with current usage metrics
"""
# Get current resource info
cpu_percent = 0.0
memory_percent = 0.0
memory_used_gb = 0.0
memory_available_gb = 0.0
memory_total_gb = 0.0
gpu_available = False
gpu_memory_gb = None
gpu_usage_percent = None
if self.resource_detector:
try:
current_info = self.resource_detector.get_current_usage()
cpu_percent = current_info.cpu_percent
memory_percent = current_info.memory_percent
memory_used_gb = current_info.memory_total_gb - current_info.memory_available_gb
memory_available_gb = current_info.memory_available_gb
memory_total_gb = current_info.memory_total_gb
gpu_available = current_info.gpu_available
gpu_memory_gb = current_info.gpu_memory_gb
gpu_usage_percent = current_info.gpu_usage_percent
except Exception:
pass
usage = {
"cpu_percent": cpu_percent,
"memory_percent": memory_percent,
"memory_used_gb": memory_used_gb,
"memory_available_gb": memory_available_gb,
"memory_total_gb": memory_total_gb,
"gpu_available": gpu_available,
"gpu_memory_gb": gpu_memory_gb,
"gpu_usage_percent": gpu_usage_percent,
"monitoring_active": self.monitoring_active,
}
if self.monitoring_active and self.start_time:
usage["elapsed_seconds"] = time.time() - self.start_time
return usage
def check_limits(self) -> dict[str, bool]:
"""
Check if current usage exceeds or approaches limits
Returns:
Dictionary of limit check results
"""
if not self.current_limits:
return {"limits_set": False}
# Get current resource info
cpu_percent = 0.0
memory_percent = 0.0
if self.resource_detector:
try:
current_info = self.resource_detector.get_current_usage()
cpu_percent = current_info.cpu_percent
memory_percent = current_info.memory_percent
except Exception:
pass
checks = {
"limits_set": True,
"cpu_exceeded": cpu_percent > self.current_limits.cpu_percent,
"memory_exceeded": memory_percent > self.current_limits.memory_percent,
"cpu_approaching": cpu_percent > self.current_limits.cpu_percent * 0.8,
"memory_approaching": memory_percent > self.current_limits.memory_percent * 0.8,
}
if self.monitoring_active and self.start_time:
elapsed = time.time() - self.start_time
checks["timeout_exceeded"] = elapsed > self.current_limits.timeout_seconds
checks["timeout_approaching"] = elapsed > self.current_limits.timeout_seconds * 0.8
return checks
def graceful_degradation_warning(self) -> str | None:
"""
Generate warning if approaching resource limits
Returns:
Warning message or None if safe
"""
checks = self.check_limits()
if not checks["limits_set"]:
return None
warnings = []
if checks["cpu_approaching"]:
warnings.append(f"CPU usage approaching limit ({self.current_limits.cpu_percent}%)")
if checks["memory_approaching"]:
warnings.append(
f"Memory usage approaching limit ({self.current_limits.memory_percent}%)"
)
if self.monitoring_active and self.start_time:
elapsed = time.time() - self.start_time
if elapsed > self.current_limits.timeout_seconds * 0.8:
warnings.append(
f"Execution approaching timeout ({self.current_limits.timeout_seconds}s)"
)
if warnings:
return "Warning: " + "; ".join(warnings) + ". Consider reducing execution scope."
return None
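# Illustrative usage sketch (not part of the original module): bracket a workload with
# monitoring and report what was observed. It assumes the Phase 1 ResourceDetector
# import above succeeded; otherwise the usage figures fall back to zero.
if __name__ == "__main__":
    enforcer = ResourceEnforcer()
    enforcer.set_limits(ResourceLimits(cpu_percent=70.0, memory_percent=70.0, timeout_seconds=30))
    enforcer.start_monitoring()
    time.sleep(0.5)  # stand-in for the sandboxed workload
    usage = enforcer.stop_monitoring()
    print(f"elapsed={usage.elapsed_seconds:.2f}s cpu={usage.cpu_percent}% mem={usage.memory_percent}%")
    print("limit checks:", enforcer.check_limits())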

View File

@@ -1,260 +0,0 @@
"""
Risk Analysis for Mai Sandbox System
Provides AST-based code analysis to detect dangerous patterns
and calculate risk scores for code execution decisions.
"""
import ast
import re
from dataclasses import dataclass
@dataclass
class RiskPattern:
"""Represents a detected risky code pattern"""
pattern: str
severity: str # 'BLOCKED', 'HIGH', 'MEDIUM', 'LOW'
score: int
line_number: int
description: str
@dataclass
class RiskAssessment:
"""Result of risk analysis"""
score: int
patterns: list[RiskPattern]
safe_to_execute: bool
approval_required: bool
class RiskAnalyzer:
"""
Analyzes code for dangerous patterns using AST parsing
and static analysis techniques.
"""
# Severity scores and risk thresholds
SEVERITY_SCORES = {"BLOCKED": 100, "HIGH": 80, "MEDIUM": 50, "LOW": 20}
# Known dangerous patterns
DANGEROUS_IMPORTS = {
"os.system": ("BLOCKED", "Direct system command execution"),
"os.popen": ("BLOCKED", "Direct system command execution"),
"subprocess.run": ("HIGH", "Subprocess execution"),
"subprocess.call": ("HIGH", "Subprocess execution"),
"subprocess.Popen": ("HIGH", "Subprocess execution"),
"eval": ("HIGH", "Dynamic code execution"),
"exec": ("HIGH", "Dynamic code execution"),
"compile": ("MEDIUM", "Code compilation"),
"__import__": ("MEDIUM", "Dynamic import"),
"open": ("LOW", "File access"),
"shutil.rmtree": ("HIGH", "Directory deletion"),
"os.remove": ("HIGH", "File deletion"),
"os.unlink": ("HIGH", "File deletion"),
"os.mkdir": ("LOW", "Directory creation"),
"os.chdir": ("MEDIUM", "Directory change"),
}
# Regex patterns for additional checks
REGEX_PATTERNS = [
(r"/dev/[^\\s]+", "BLOCKED", "Device file access"),
(r"rm\\s+-rf\\s+/", "BLOCKED", "Recursive root deletion"),
(r"shell=True", "HIGH", "Shell execution in subprocess"),
(r"password", "MEDIUM", "Potential password handling"),
(r"api[_-]?key", "MEDIUM", "Potential API key handling"),
(r"chmod\\s+777", "HIGH", "Permissive file permissions"),
(r"sudo\\s+", "HIGH", "Privilege escalation"),
]
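    # Illustrative examples (not part of the original file): a line such as
    # "subprocess.run(cmd, shell=True)" trips the shell=True rule (HIGH), while
    # "sudo rm -rf /" trips both the sudo and recursive-root-deletion rules (BLOCKED).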
def __init__(self):
"""Initialize risk analyzer"""
self.reset_analysis()
def reset_analysis(self):
"""Reset analysis state"""
self.detected_patterns: list[RiskPattern] = []
def analyze_ast(self, code: str) -> RiskAssessment:
"""
Analyze Python code using AST parsing
Args:
code: Python source code to analyze
Returns:
RiskAssessment with score, patterns, and execution decision
"""
self.reset_analysis()
try:
tree = ast.parse(code)
self._walk_ast(tree)
except SyntaxError as e:
# Syntax errors are automatically high risk
pattern = RiskPattern(
pattern="syntax_error",
severity="HIGH",
score=90,
line_number=getattr(e, "lineno", 0),
description=f"Syntax error: {e}",
)
self.detected_patterns.append(pattern)
# Additional regex-based checks
self._regex_checks(code)
# Calculate overall assessment
total_score = max([p.score for p in self.detected_patterns] + [0])
assessment = RiskAssessment(
score=total_score,
patterns=self.detected_patterns.copy(),
safe_to_execute=total_score < 50,
approval_required=total_score >= 30,
)
return assessment
def detect_dangerous_patterns(self, code: str) -> list[RiskPattern]:
"""
Detect dangerous patterns using both AST and regex analysis
Args:
code: Python source code
Returns:
List of detected RiskPattern objects
"""
assessment = self.analyze_ast(code)
return assessment.patterns
def calculate_risk_score(self, patterns: list[RiskPattern]) -> int:
"""
Calculate overall risk score from detected patterns
Args:
patterns: List of detected risk patterns
Returns:
Overall risk score (0-100)
"""
if not patterns:
return 0
return max([p.score for p in patterns])
def _walk_ast(self, tree: ast.AST):
"""Walk AST tree and detect dangerous patterns"""
for node in ast.walk(tree):
self._check_imports(node)
self._check_function_calls(node)
self._check_file_operations(node)
def _check_imports(self, node: ast.AST):
"""Check for dangerous imports"""
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.name
if name in self.DANGEROUS_IMPORTS:
severity, desc = self.DANGEROUS_IMPORTS[name]
pattern = RiskPattern(
pattern=f"import_{name}",
severity=severity,
score=self.SEVERITY_SCORES[severity],
line_number=getattr(node, "lineno", 0),
description=f"Import of {desc}",
)
self.detected_patterns.append(pattern)
elif isinstance(node, ast.ImportFrom):
if node.module and node.module in self.DANGEROUS_IMPORTS:
name = node.module
severity, desc = self.DANGEROUS_IMPORTS[name]
pattern = RiskPattern(
pattern=f"from_{name}",
severity=severity,
score=self.SEVERITY_SCORES[severity],
line_number=getattr(node, "lineno", 0),
description=f"Import from {desc}",
)
self.detected_patterns.append(pattern)
def _check_function_calls(self, node: ast.AST):
"""Check for dangerous function calls"""
if isinstance(node, ast.Call):
# Get function name
func_name = self._get_function_name(node.func)
if func_name in self.DANGEROUS_IMPORTS:
severity, desc = self.DANGEROUS_IMPORTS[func_name]
pattern = RiskPattern(
pattern=f"call_{func_name}",
severity=severity,
score=self.SEVERITY_SCORES[severity],
line_number=getattr(node, "lineno", 0),
description=f"Call to {desc}",
)
self.detected_patterns.append(pattern)
# Check for shell=True in subprocess calls
if func_name in ["subprocess.run", "subprocess.call", "subprocess.Popen"]:
for keyword in node.keywords:
if keyword.arg == "shell" and isinstance(keyword.value, ast.Constant):
if keyword.value.value is True:
pattern = RiskPattern(
pattern="shell_true",
severity="HIGH",
score=self.SEVERITY_SCORES["HIGH"],
line_number=getattr(node, "lineno", 0),
description="Shell execution in subprocess",
)
self.detected_patterns.append(pattern)
def _check_file_operations(self, node: ast.AST):
"""Check for dangerous file operations"""
if isinstance(node, ast.Call):
func_name = self._get_function_name(node.func)
dangerous_file_ops = ["shutil.rmtree", "os.remove", "os.unlink", "os.chmod", "os.chown"]
if func_name in dangerous_file_ops:
severity = "HIGH" if "rmtree" in func_name else "MEDIUM"
pattern = RiskPattern(
pattern=f"file_{func_name}",
severity=severity,
score=self.SEVERITY_SCORES[severity],
line_number=getattr(node, "lineno", 0),
description=f"Dangerous file operation: {func_name}",
)
self.detected_patterns.append(pattern)
def _get_function_name(self, node: ast.AST) -> str:
"""Extract function name from AST node"""
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
attr = []
while isinstance(node, ast.Attribute):
attr.append(node.attr)
node = node.value
if isinstance(node, ast.Name):
attr.append(node.id)
return ".".join(reversed(attr))
return ""
def _regex_checks(self, code: str):
"""Perform regex-based pattern detection"""
        lines = code.split("\n")
for pattern_str, severity, description in self.REGEX_PATTERNS:
for line_num, line in enumerate(lines, 1):
if re.search(pattern_str, line, re.IGNORECASE):
pattern = RiskPattern(
pattern=pattern_str,
severity=severity,
score=self.SEVERITY_SCORES[severity],
line_number=line_num,
description=f"Regex detected: {description}",
)
self.detected_patterns.append(pattern)
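# Illustrative usage sketch (not part of the original module): score a risky snippet
# and list the patterns that fired.
if __name__ == "__main__":
    analyzer = RiskAnalyzer()
    assessment = analyzer.analyze_ast("import subprocess\nsubprocess.run('ls', shell=True)\n")
    print("score:", assessment.score)  # 80 (HIGH) from the shell=True subprocess call
    print("safe to execute:", assessment.safe_to_execute)  # False, since 80 >= 50
    print("approval required:", assessment.approval_required)  # True, since 80 >= 30
    for detected in assessment.patterns:
        print(f"  line {detected.line_number}: {detected.severity} {detected.pattern}: {detected.description}")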

View File

@@ -1,378 +0,0 @@
"""
Tests for Docker Executor component
Test suite for Docker-based container execution with isolation,
resource limits, and audit logging integration.
"""
import pytest
import tempfile
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
# Import components under test
from src.mai.sandbox.docker_executor import DockerExecutor, ContainerConfig, ContainerResult
from src.mai.sandbox.audit_logger import AuditLogger
class TestContainerConfig:
"""Test ContainerConfig dataclass"""
def test_default_config(self):
"""Test default configuration values"""
config = ContainerConfig()
assert config.image == "python:3.10-slim"
assert config.timeout_seconds == 30
assert config.memory_limit == "128m"
assert config.cpu_limit == "0.5"
assert config.network_disabled is True
assert config.read_only_filesystem is True
assert config.tmpfs_size == "64m"
assert config.working_dir == "/app"
assert config.user == "nobody"
def test_custom_config(self):
"""Test custom configuration values"""
config = ContainerConfig(
image="python:3.9-alpine",
timeout_seconds=60,
memory_limit="256m",
cpu_limit="0.8",
network_disabled=False,
)
assert config.image == "python:3.9-alpine"
assert config.timeout_seconds == 60
assert config.memory_limit == "256m"
assert config.cpu_limit == "0.8"
assert config.network_disabled is False
class TestDockerExecutor:
"""Test DockerExecutor class"""
@pytest.fixture
def mock_audit_logger(self):
"""Create mock audit logger"""
return Mock(spec=AuditLogger)
@pytest.fixture
def docker_executor(self, mock_audit_logger):
"""Create DockerExecutor instance for testing"""
return DockerExecutor(audit_logger=mock_audit_logger)
def test_init_without_docker(self, mock_audit_logger):
"""Test initialization when Docker is not available"""
with patch("src.mai.sandbox.docker_executor.DOCKER_AVAILABLE", False):
executor = DockerExecutor(audit_logger=mock_audit_logger)
assert executor.is_available() is False
assert executor.client is None
def test_init_with_docker_error(self, mock_audit_logger):
"""Test initialization when Docker fails to connect"""
with patch("src.mai.sandbox.docker_executor.DOCKER_AVAILABLE", True):
with patch("docker.from_env") as mock_from_env:
mock_from_env.side_effect = Exception("Docker daemon not running")
executor = DockerExecutor(audit_logger=mock_audit_logger)
assert executor.is_available() is False
assert executor.client is None
def test_is_available(self, docker_executor):
"""Test is_available method"""
# When client is None, should not be available
docker_executor.client = None
docker_executor.available = False
assert docker_executor.is_available() is False
# When client is available, should reflect available status
docker_executor.client = Mock()
docker_executor.available = True
assert docker_executor.is_available() is True
docker_executor.client = Mock()
docker_executor.available = False
assert docker_executor.is_available() is False
def test_execute_code_unavailable(self, docker_executor):
"""Test execute_code when Docker is not available"""
with patch.object(docker_executor, "is_available", return_value=False):
result = docker_executor.execute_code("print('test')")
assert result.success is False
assert result.container_id == ""
assert result.exit_code == -1
assert "Docker executor not available" in result.error
@patch("src.mai.sandbox.docker_executor.Path")
@patch("src.mai.sandbox.docker_executor.tempfile.TemporaryDirectory")
def test_execute_code_success(self, mock_temp_dir, mock_path, docker_executor):
"""Test successful code execution in container"""
# Mock temporary directory and file creation
mock_temp_file = Mock()
mock_temp_file.write_text = Mock()
mock_temp_path = Mock()
mock_temp_path.__truediv__ = Mock(return_value=mock_temp_file)
mock_temp_path.__str__ = Mock(return_value="/tmp/test")
mock_temp_dir.return_value.__enter__.return_value = mock_temp_path
# Mock Docker client and container
mock_container = Mock()
mock_container.id = "test-container-id"
mock_container.wait.return_value = {"StatusCode": 0}
mock_container.logs.return_value = b"test output"
mock_container.stats.return_value = {
"cpu_stats": {"cpu_usage": {"total_usage": 1000000}, "system_cpu_usage": 2000000},
"precpu_stats": {"cpu_usage": {"total_usage": 500000}, "system_cpu_usage": 1000000},
"memory_stats": {"usage": 50000000, "limit": 100000000},
}
mock_client = Mock()
mock_client.containers.run.return_value = mock_container
docker_executor.client = mock_client
docker_executor.available = True
# Execute code
result = docker_executor.execute_code("print('test')")
assert result.success is True
assert result.container_id == "test-container-id"
assert result.exit_code == 0
assert result.stdout == "test output"
assert result.execution_time > 0
assert result.resource_usage is not None
@patch("src.mai.sandbox.docker_executor.Path")
@patch("src.mai.sandbox.docker_executor.tempfile.TemporaryDirectory")
def test_execute_code_with_files(self, mock_temp_dir, mock_path, docker_executor):
"""Test code execution with additional files"""
# Mock temporary directory and file creation
mock_temp_file = Mock()
mock_temp_file.write_text = Mock()
mock_temp_path = Mock()
mock_temp_path.__truediv__ = Mock(return_value=mock_temp_file)
mock_temp_path.__str__ = Mock(return_value="/tmp/test")
mock_temp_dir.return_value.__enter__.return_value = mock_temp_path
# Mock Docker client and container
mock_container = Mock()
mock_container.id = "test-container-id"
mock_container.wait.return_value = {"StatusCode": 0}
mock_container.logs.return_value = b"test output"
mock_container.stats.return_value = {}
mock_client = Mock()
mock_client.containers.run.return_value = mock_container
docker_executor.client = mock_client
docker_executor.available = True
# Execute code with files
files = {"data.txt": "test data"}
result = docker_executor.execute_code("print('test')", files=files)
# Verify additional files were handled
assert mock_temp_file.write_text.call_count >= 2 # code + data file
assert result.success is True
def test_build_container_config(self, docker_executor):
"""Test building Docker container configuration"""
config = ContainerConfig(memory_limit="256m", cpu_limit="0.8", network_disabled=False)
environment = {"TEST_VAR": "test_value"}
container_config = docker_executor._build_container_config(config, environment)
assert container_config["mem_limit"] == "256m"
assert container_config["cpu_quota"] == 80000 # 0.8 * 100000
assert container_config["cpu_period"] == 100000
assert container_config["network_disabled"] is False
assert container_config["read_only"] is True
assert container_config["user"] == "nobody"
assert container_config["working_dir"] == "/app"
assert "TEST_VAR" in container_config["environment"]
assert "security_opt" in container_config
assert "cap_drop" in container_config
assert "cap_add" in container_config
def test_get_container_stats(self, docker_executor):
"""Test extracting container resource statistics"""
mock_container = Mock()
mock_container.stats.return_value = {
"cpu_stats": {
"cpu_usage": {"total_usage": 2000000},
"system_cpu_usage": 4000000,
"online_cpus": 2,
},
"precpu_stats": {"cpu_usage": {"total_usage": 1000000}, "system_cpu_usage": 2000000},
"memory_stats": {
"usage": 67108864, # 64MB
"limit": 134217728, # 128MB
},
}
stats = docker_executor._get_container_stats(mock_container)
assert stats["cpu_percent"] == 100.0 # (2000000-1000000)/(4000000-2000000) * 2 * 100
assert stats["memory_usage_bytes"] == 67108864
assert stats["memory_limit_bytes"] == 134217728
assert stats["memory_percent"] == 50.0
assert stats["memory_usage_mb"] == 64.0
def test_get_container_stats_error(self, docker_executor):
"""Test get_container_stats with error"""
mock_container = Mock()
mock_container.stats.side_effect = Exception("Stats error")
stats = docker_executor._get_container_stats(mock_container)
assert stats["cpu_percent"] == 0.0
assert stats["memory_usage_bytes"] == 0
assert stats["memory_percent"] == 0.0
assert stats["memory_usage_mb"] == 0.0
def test_log_container_execution(self, docker_executor, mock_audit_logger):
"""Test logging container execution"""
config = ContainerConfig(image="python:3.10-slim")
result = ContainerResult(
success=True,
container_id="test-id",
exit_code=0,
stdout="test output",
stderr="",
execution_time=1.5,
resource_usage={"cpu_percent": 50.0},
)
docker_executor._log_container_execution("print('test')", result, config)
# Verify audit logger was called
mock_audit_logger.log_execution.assert_called_once()
call_args = mock_audit_logger.log_execution.call_args
assert call_args.kwargs["code"] == "print('test')"
assert call_args.kwargs["execution_type"] == "docker"
assert "docker_container" in call_args.kwargs["execution_result"]["type"]
def test_get_available_images(self, docker_executor):
"""Test getting available Docker images"""
mock_image = Mock()
mock_image.tags = ["python:3.10-slim", "python:3.9-alpine"]
mock_client = Mock()
mock_client.images.list.return_value = [mock_image]
docker_executor.client = mock_client
docker_executor.available = True
images = docker_executor.get_available_images()
assert "python:3.10-slim" in images
assert "python:3.9-alpine" in images
def test_pull_image(self, docker_executor):
"""Test pulling Docker image"""
mock_client = Mock()
mock_client.images.pull.return_value = None
docker_executor.client = mock_client
docker_executor.available = True
result = docker_executor.pull_image("python:3.10-slim")
assert result is True
mock_client.images.pull.assert_called_once_with("python:3.10-slim")
def test_cleanup_containers(self, docker_executor):
"""Test cleaning up containers"""
mock_container = Mock()
mock_client = Mock()
mock_client.containers.list.return_value = [mock_container, mock_container]
docker_executor.client = mock_client
docker_executor.available = True
count = docker_executor.cleanup_containers()
assert count == 2
assert mock_container.remove.call_count == 2
def test_get_system_info(self, docker_executor):
"""Test getting Docker system information"""
mock_client = Mock()
mock_client.info.return_value = {
"Containers": 5,
"ContainersRunning": 2,
"Images": 10,
"MemTotal": 8589934592,
"NCPU": 4,
}
mock_client.version.return_value = {"Version": "20.10.7", "ApiVersion": "1.41"}
docker_executor.client = mock_client
docker_executor.available = True
info = docker_executor.get_system_info()
assert info["available"] is True
assert info["version"] == "20.10.7"
assert info["api_version"] == "1.41"
assert info["containers"] == 5
assert info["images"] == 10
class TestDockerExecutorIntegration:
"""Integration tests for Docker executor with other sandbox components"""
@pytest.fixture
def mock_audit_logger(self):
"""Create mock audit logger"""
return Mock(spec=AuditLogger)
def test_docker_executor_integration(self, mock_audit_logger):
"""Test Docker executor integration with audit logger"""
executor = DockerExecutor(audit_logger=mock_audit_logger)
# Test that audit logger is properly integrated
assert executor.audit_logger is mock_audit_logger
# Mock Docker availability for integration test
with patch.object(executor, "is_available", return_value=False):
result = executor.execute_code("print('test')")
# Should fail gracefully and still attempt logging
assert result.success is False
def test_container_result_serialization(self):
"""Test ContainerResult can be properly serialized"""
result = ContainerResult(
success=True,
container_id="test-id",
exit_code=0,
stdout="test output",
stderr="",
execution_time=1.5,
resource_usage={"cpu_percent": 50.0},
)
# Test that result can be converted to dict for JSON serialization
result_dict = {
"success": result.success,
"container_id": result.container_id,
"exit_code": result.exit_code,
"stdout": result.stdout,
"stderr": result.stderr,
"execution_time": result.execution_time,
"error": result.error,
"resource_usage": result.resource_usage,
}
assert result_dict["success"] is True
assert result_dict["container_id"] == "test-id"
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,341 +0,0 @@
"""
Integration test for complete Docker sandbox execution
Tests the full integration of Docker executor with sandbox manager,
risk analysis, resource enforcement, and audit logging.
"""
import pytest
import tempfile
import json
from pathlib import Path
from unittest.mock import patch, Mock
from src.mai.sandbox.manager import SandboxManager, ExecutionRequest
from src.mai.sandbox.audit_logger import AuditLogger
@pytest.mark.integration
class TestDockerSandboxIntegration:
"""Integration tests for Docker sandbox execution"""
@pytest.fixture
def temp_log_dir(self):
"""Create temporary directory for audit logs"""
with tempfile.TemporaryDirectory() as temp_dir:
yield temp_dir
@pytest.fixture
def sandbox_manager(self, temp_log_dir):
"""Create SandboxManager with temp log directory"""
return SandboxManager(log_dir=temp_log_dir)
def test_full_docker_execution_workflow(self, sandbox_manager):
"""Test complete Docker execution workflow"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(sandbox_manager.docker_executor, "execute_code") as mock_docker:
# Mock Docker container execution
from src.mai.sandbox.docker_executor import ContainerResult
                # DockerExecutor.execute_code returns a ContainerResult; the manager
                # wraps it into the {"success", "output", ...} dict itself.
                mock_docker.return_value = ContainerResult(
                    success=True,
                    container_id="integration-test-container",
                    exit_code=0,
                    stdout="42\n",
                    stderr="",
                    execution_time=2.3,
                    resource_usage={
                        "cpu_percent": 15.2,
                        "memory_usage_mb": 28.5,
                        "memory_percent": 5.5,
                    },
                )
# Create execution request
request = ExecutionRequest(
code="result = 6 * 7\nprint(result)",
use_docker=True,
docker_image="python:3.10-slim",
timeout_seconds=30,
cpu_limit_percent=50.0,
memory_limit_percent=40.0,
network_allowed=False,
filesystem_restricted=True,
)
# Execute code
result = sandbox_manager.execute_code(request)
# Verify execution results
assert result.success is True
assert result.execution_method == "docker"
assert result.output == "42\n"
assert result.container_result is not None
assert result.container_result.container_id == "integration-test-container"
assert result.container_result.exit_code == 0
assert result.container_result.execution_time == 2.3
assert result.container_result.resource_usage["cpu_percent"] == 15.2
assert result.container_result.resource_usage["memory_usage_mb"] == 28.5
# Verify Docker executor was called with correct parameters
mock_docker.assert_called_once()
                call_args = mock_docker.call_args
                # Check code was passed correctly (the manager passes it by keyword)
                assert call_args.kwargs["code"] == "result = 6 * 7\nprint(result)"
# Check container config
config = call_args.kwargs["config"]
assert config.image == "python:3.10-slim"
assert config.timeout_seconds == 30
assert config.memory_limit == "51m" # Scaled from 40% of 128m
assert config.cpu_limit == "0.5" # 50% CPU
assert config.network_disabled is True
assert config.read_only_filesystem is True
# Verify audit logging occurred
assert result.audit_entry_id is not None
# Check audit log contents
logs = sandbox_manager.get_execution_history(limit=1)
assert len(logs) == 1
log_entry = logs[0]
assert log_entry["code"] == "result = 6 * 7\nprint(result)"
assert log_entry["execution_result"]["success"] is True
assert "docker_container" in log_entry["execution_result"]
def test_docker_execution_with_additional_files(self, sandbox_manager):
"""Test Docker execution with additional files"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(sandbox_manager.docker_executor, "execute_code") as mock_docker:
# Mock Docker execution
from src.mai.sandbox.docker_executor import ContainerResult
                mock_docker.return_value = ContainerResult(
                    success=True,
                    container_id="files-test-container",
                    exit_code=0,
                    stdout="Hello, Alice!\n",
                )
# Create execution request with additional files
request = ExecutionRequest(
code="with open('template.txt', 'r') as f: template = f.read()\nprint(template.replace('{name}', 'Alice'))",
use_docker=True,
additional_files={"template.txt": "Hello, {name}!"},
)
# Execute code
result = sandbox_manager.execute_code(request)
# Verify execution
assert result.success is True
assert result.execution_method == "docker"
# Verify Docker executor was called with files
call_args = mock_docker.call_args
assert "files" in call_args.kwargs
assert call_args.kwargs["files"] == {"template.txt": "Hello, {name}!"}
def test_docker_execution_blocked_by_risk_analysis(self, sandbox_manager):
"""Test that high-risk code is blocked before Docker execution"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(sandbox_manager.docker_executor, "execute_code") as mock_docker:
# Risk analysis will automatically detect the dangerous pattern
request = ExecutionRequest(
code="import subprocess; subprocess.run(['rm', '-rf', '/'], shell=True)",
use_docker=True,
)
# Execute code
result = sandbox_manager.execute_code(request)
# Verify execution was blocked
assert result.success is False
assert "blocked" in result.error.lower()
assert result.risk_assessment.score >= 70
assert result.execution_method == "local" # Set before Docker check
# Docker executor should not be called
mock_docker.assert_not_called()
# Should still be logged
assert result.audit_entry_id is not None
def test_docker_execution_fallback_to_local(self, sandbox_manager):
"""Test fallback to local execution when Docker unavailable"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=False):
with patch.object(sandbox_manager, "_execute_in_sandbox") as mock_local:
with patch.object(
sandbox_manager.resource_enforcer, "stop_monitoring"
) as mock_monitoring:
# Mock local execution
mock_local.return_value = {"success": True, "output": "Local fallback result"}
# Mock resource usage
from src.mai.sandbox.resource_enforcer import ResourceUsage
mock_monitoring.return_value = ResourceUsage(
cpu_percent=35.0,
memory_percent=25.0,
memory_used_gb=0.4,
elapsed_seconds=1.8,
                        approaching_limits={},
)
# Create request preferring Docker
request = ExecutionRequest(
code="print('fallback test')",
use_docker=True, # But Docker is unavailable
)
# Execute code
result = sandbox_manager.execute_code(request)
# Verify fallback to local execution
assert result.success is True
assert result.execution_method == "local"
assert result.output == "Local fallback result"
assert result.container_result is None
assert result.resource_usage is not None
assert result.resource_usage.cpu_percent == 35.0
# Verify local execution was used
mock_local.assert_called_once()
def test_audit_logging_docker_execution_details(self, sandbox_manager):
"""Test comprehensive audit logging for Docker execution"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(sandbox_manager.docker_executor, "execute_code") as mock_docker:
# Mock Docker execution with detailed stats
from src.mai.sandbox.docker_executor import ContainerResult
mock_docker.return_value = {
"success": True,
"output": "Calculation complete: 144\n",
"container_result": ContainerResult(
success=True,
container_id="audit-test-container",
exit_code=0,
stdout="Calculation complete: 144\n",
stderr="",
execution_time=3.7,
resource_usage={
"cpu_percent": 22.8,
"memory_usage_mb": 45.2,
"memory_percent": 8.9,
"memory_usage_bytes": 47395648,
"memory_limit_bytes": 536870912,
},
),
}
# Execute request
request = ExecutionRequest(
code="result = 12 * 12\nprint(f'Calculation complete: {result}')",
use_docker=True,
docker_image="python:3.9-alpine",
timeout_seconds=45,
)
result = sandbox_manager.execute_code(request)
# Verify audit log contains Docker execution details
logs = sandbox_manager.get_execution_history(limit=1)
assert len(logs) == 1
log_entry = logs[0]
execution_result = log_entry["execution_result"]
# Check Docker-specific fields
assert execution_result["type"] == "docker_container"
assert execution_result["container_id"] == "audit-test-container"
assert execution_result["exit_code"] == 0
assert execution_result["stdout"] == "Calculation complete: 144\n"
# Check configuration details
config = execution_result["config"]
assert config["image"] == "python:3.9-alpine"
assert config["timeout"] == 45
assert config["network_disabled"] is True
assert config["read_only_filesystem"] is True
# Check resource usage
resource_usage = execution_result["resource_usage"]
assert resource_usage["cpu_percent"] == 22.8
assert resource_usage["memory_usage_mb"] == 45.2
assert resource_usage["memory_percent"] == 8.9
def test_system_status_includes_docker_info(self, sandbox_manager):
"""Test system status includes Docker information"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(
sandbox_manager.docker_executor, "get_system_info"
) as mock_docker_info:
# Mock Docker system info
mock_docker_info.return_value = {
"available": True,
"version": "20.10.12",
"api_version": "1.41",
"containers": 5,
"containers_running": 2,
"images": 8,
"ncpu": 4,
"memory_total": 8589934592,
}
# Get system status
status = sandbox_manager.get_system_status()
# Verify Docker information is included
assert "docker_available" in status
assert "docker_info" in status
assert status["docker_available"] is True
assert status["docker_info"]["available"] is True
assert status["docker_info"]["version"] == "20.10.12"
assert status["docker_info"]["containers"] == 5
assert status["docker_info"]["images"] == 8
def test_docker_status_management(self, sandbox_manager):
"""Test Docker status management functions"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(
sandbox_manager.docker_executor, "get_available_images"
) as mock_images:
with patch.object(sandbox_manager.docker_executor, "pull_image") as mock_pull:
with patch.object(
sandbox_manager.docker_executor, "cleanup_containers"
) as mock_cleanup:
# Mock responses
mock_images.return_value = ["python:3.10-slim", "python:3.9-alpine"]
mock_pull.return_value = True
mock_cleanup.return_value = 3
# Test get Docker status
status = sandbox_manager.get_docker_status()
assert status["available"] is True
assert "python:3.10-slim" in status["images"]
assert "python:3.9-alpine" in status["images"]
# Test pull image
pull_result = sandbox_manager.pull_docker_image("node:16-alpine")
assert pull_result is True
mock_pull.assert_called_once_with("node:16-alpine")
# Test cleanup containers
cleanup_count = sandbox_manager.cleanup_docker_containers()
assert cleanup_count == 3
mock_cleanup.assert_called_once()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,632 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive integration tests for Phase 1 requirements.
This module validates all Phase 1 components work together correctly.
Tests cover model discovery, resource monitoring, model selection,
context compression, git workflow, and end-to-end conversations.
"""
import unittest
import os
import sys
import time
import tempfile
import shutil
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
# Mock missing dependencies first
sys.modules["ollama"] = Mock()
sys.modules["psutil"] = Mock()
sys.modules["tiktoken"] = Mock()
# Test availability of core components
def check_imports():
"""Check if all required imports are available."""
test_results = {}
# Test each import
imports_to_test = [
("mai.core.interface", "MaiInterface"),
("mai.model.resource_detector", "ResourceDetector"),
("mai.model.compression", "ContextCompressor"),
("mai.core.config", "Config"),
("mai.core.exceptions", "MaiError"),
("mai.git.workflow", "StagingWorkflow"),
("mai.git.committer", "AutoCommitter"),
("mai.git.health_check", "HealthChecker"),
]
for module_name, class_name in imports_to_test:
try:
module = __import__(module_name, fromlist=[class_name])
cls = getattr(module, class_name)
test_results[f"{module_name}.{class_name}"] = "OK"
except ImportError as e:
test_results[f"{module_name}.{class_name}"] = f"IMPORT_ERROR: {e}"
except AttributeError as e:
test_results[f"{module_name}.{class_name}"] = f"CLASS_NOT_FOUND: {e}"
return test_results
class TestComponentImports(unittest.TestCase):
"""Test that all Phase 1 components can be imported."""
def test_all_components_import(self):
"""Test that all required components can be imported."""
results = check_imports()
# Print results for debugging
print("\n=== Import Test Results ===")
for component, status in results.items():
print(f"{component}: {status}")
# Check that at least some imports work
successful_imports = sum(1 for status in results.values() if status == "OK")
self.assertGreater(
successful_imports, 0, "At least one component should import successfully"
)
class TestResourceDetectionBasic(unittest.TestCase):
"""Test basic resource detection functionality."""
def test_resource_info_structure(self):
"""Test that ResourceInfo has required structure."""
try:
from mai.model.resource_detector import ResourceInfo
# Create a test ResourceInfo with correct attributes
resources = ResourceInfo(
cpu_percent=50.0,
memory_total_gb=16.0,
memory_available_gb=8.0,
memory_percent=50.0,
gpu_available=False,
)
self.assertEqual(resources.cpu_percent, 50.0)
self.assertEqual(resources.memory_total_gb, 16.0)
self.assertEqual(resources.memory_available_gb, 8.0)
self.assertEqual(resources.memory_percent, 50.0)
self.assertEqual(resources.gpu_available, False)
except ImportError:
self.skipTest("ResourceDetector not available")
def test_resource_detector_basic(self):
"""Test ResourceDetector can be instantiated."""
try:
from mai.model.resource_detector import ResourceDetector
detector = ResourceDetector()
self.assertIsNotNone(detector)
except ImportError:
self.skipTest("ResourceDetector not available")
class TestContextCompressionBasic(unittest.TestCase):
"""Test basic context compression functionality."""
def test_context_compressor_instantiation(self):
"""Test ContextCompressor can be instantiated."""
try:
from mai.model.compression import ContextCompressor
compressor = ContextCompressor()
self.assertIsNotNone(compressor)
except ImportError:
self.skipTest("ContextCompressor not available")
def test_token_counting_basic(self):
"""Test basic token counting functionality."""
try:
from mai.model.compression import ContextCompressor, TokenInfo
compressor = ContextCompressor()
tokens = compressor.count_tokens("Hello, world!")
self.assertIsInstance(tokens, TokenInfo)
self.assertGreater(tokens.count, 0)
self.assertIsInstance(tokens.model_name, str)
self.assertGreater(len(tokens.model_name), 0)
self.assertIsInstance(tokens.accuracy, float)
self.assertGreaterEqual(tokens.accuracy, 0.0)
self.assertLessEqual(tokens.accuracy, 1.0)
except (ImportError, AttributeError):
self.skipTest("ContextCompressor not fully available")
def test_token_info_structure(self):
"""Test TokenInfo object structure and attributes."""
try:
from mai.model.compression import ContextCompressor, TokenInfo
compressor = ContextCompressor()
tokens = compressor.count_tokens("Test string for structure validation")
# Test TokenInfo structure
self.assertIsInstance(tokens, TokenInfo)
self.assertTrue(hasattr(tokens, "count"))
self.assertTrue(hasattr(tokens, "model_name"))
self.assertTrue(hasattr(tokens, "accuracy"))
# Test attribute types
self.assertIsInstance(tokens.count, int)
self.assertIsInstance(tokens.model_name, str)
self.assertIsInstance(tokens.accuracy, float)
# Test attribute values
self.assertGreaterEqual(tokens.count, 0)
self.assertGreater(len(tokens.model_name), 0)
self.assertGreaterEqual(tokens.accuracy, 0.0)
self.assertLessEqual(tokens.accuracy, 1.0)
except (ImportError, AttributeError):
self.skipTest("ContextCompressor not fully available")
def test_token_counting_accuracy(self):
"""Test token counting accuracy for various text lengths."""
try:
from mai.model.compression import ContextCompressor
compressor = ContextCompressor()
# Test with different text lengths
test_cases = [
("", 0, 5), # Empty string
("Hello", 1, 10), # Short text
("Hello, world! This is a test.", 5, 15), # Medium text
(
"This is a longer text to test token counting accuracy across multiple sentences and paragraphs. "
* 3,
50,
200,
), # Long text
]
for text, min_expected, max_expected in test_cases:
with self.subTest(text_length=len(text)):
tokens = compressor.count_tokens(text)
self.assertGreaterEqual(
tokens.count,
min_expected,
f"Token count {tokens.count} below minimum {min_expected} for text: {text[:50]}...",
)
self.assertLessEqual(
tokens.count,
max_expected,
f"Token count {tokens.count} above maximum {max_expected} for text: {text[:50]}...",
)
# Test accuracy is reasonable
self.assertGreaterEqual(tokens.accuracy, 0.7, "Accuracy should be at least 70%")
self.assertLessEqual(tokens.accuracy, 1.0, "Accuracy should not exceed 100%")
except (ImportError, AttributeError):
self.skipTest("ContextCompressor not fully available")
def test_token_fallback_behavior(self):
"""Test token counting fallback behavior when tiktoken unavailable."""
try:
from mai.model.compression import ContextCompressor, TokenInfo
from unittest.mock import patch
compressor = ContextCompressor()
test_text = "Testing fallback behavior with a reasonable text length"
# Test normal behavior first
tokens_normal = compressor.count_tokens(test_text)
self.assertIsInstance(tokens_normal, TokenInfo)
self.assertGreater(tokens_normal.count, 0)
# Test with mocked tiktoken error to trigger fallback
with patch("tiktoken.encoding_for_model") as mock_encoding:
mock_encoding.side_effect = Exception("tiktoken not available")
tokens_fallback = compressor.count_tokens(test_text)
# Both should return TokenInfo objects
self.assertEqual(type(tokens_normal), type(tokens_fallback))
self.assertIsInstance(tokens_fallback, TokenInfo)
self.assertGreater(tokens_fallback.count, 0)
# Fallback might be less accurate but should still be reasonable
self.assertGreaterEqual(tokens_fallback.accuracy, 0.7)
self.assertLessEqual(tokens_fallback.accuracy, 1.0)
except (ImportError, AttributeError):
self.skipTest("ContextCompressor not fully available")
def test_token_edge_cases(self):
"""Test token counting with edge cases."""
try:
from mai.model.compression import ContextCompressor, TokenInfo
compressor = ContextCompressor()
# Edge cases to test
edge_cases = [
("", "Empty string"),
(" ", "Single space"),
("\n", "Single newline"),
("\t", "Single tab"),
(" ", "Multiple spaces"),
("Hello\nworld", "Text with newline"),
("Special chars: !@#$%^&*()", "Special characters"),
("Unicode: ñáéíóú 🤖", "Unicode characters"),
("Numbers: 1234567890", "Numbers"),
("Mixed: Hello123!@#world", "Mixed content"),
]
for text, description in edge_cases:
with self.subTest(case=description):
tokens = compressor.count_tokens(text)
# All should return TokenInfo
self.assertIsInstance(tokens, TokenInfo)
self.assertGreaterEqual(
tokens.count, 0, f"Token count should be >= 0 for {description}"
)
# Model name and accuracy should be set
self.assertGreater(
len(tokens.model_name),
0,
f"Model name should not be empty for {description}",
)
self.assertGreaterEqual(
tokens.accuracy, 0.7, f"Accuracy should be reasonable for {description}"
)
self.assertLessEqual(
tokens.accuracy, 1.0, f"Accuracy should not exceed 100% for {description}"
)
except (ImportError, AttributeError):
self.skipTest("ContextCompressor not fully available")
class TestConfigSystem(unittest.TestCase):
"""Test configuration system functionality."""
def test_config_instantiation(self):
"""Test Config can be instantiated."""
try:
from mai.core.config import Config
config = Config()
self.assertIsNotNone(config)
except ImportError:
self.skipTest("Config not available")
def test_config_validation(self):
"""Test configuration validation."""
try:
from mai.core.config import Config
config = Config()
# Test basic validation
self.assertIsNotNone(config)
except ImportError:
self.skipTest("Config not available")
class TestGitWorkflowBasic(unittest.TestCase):
"""Test basic git workflow functionality."""
def test_staging_workflow_instantiation(self):
"""Test StagingWorkflow can be instantiated."""
try:
from mai.git.workflow import StagingWorkflow
workflow = StagingWorkflow()
self.assertIsNotNone(workflow)
except ImportError:
self.skipTest("StagingWorkflow not available")
def test_auto_committer_instantiation(self):
"""Test AutoCommitter can be instantiated."""
try:
from mai.git.committer import AutoCommitter
committer = AutoCommitter()
self.assertIsNotNone(committer)
except ImportError:
self.skipTest("AutoCommitter not available")
def test_health_checker_instantiation(self):
"""Test HealthChecker can be instantiated."""
try:
from mai.git.health_check import HealthChecker
checker = HealthChecker()
self.assertIsNotNone(checker)
except ImportError:
self.skipTest("HealthChecker not available")
class TestExceptionHandling(unittest.TestCase):
"""Test exception handling system."""
def test_exception_hierarchy(self):
"""Test exception hierarchy exists."""
try:
from mai.core.exceptions import (
MaiError,
ModelError,
ConfigurationError,
ModelConnectionError,
)
# Test exception inheritance
self.assertTrue(issubclass(ModelError, MaiError))
self.assertTrue(issubclass(ConfigurationError, MaiError))
self.assertTrue(issubclass(ModelConnectionError, ModelError))
# Test instantiation
error = MaiError("Test error")
self.assertEqual(str(error), "Test error")
except ImportError:
self.skipTest("Exception hierarchy not available")
class TestFileStructure(unittest.TestCase):
"""Test that all required files exist with proper structure."""
def test_core_files_exist(self):
"""Test that all core files exist."""
required_files = [
"src/mai/core/interface.py",
"src/mai/model/ollama_client.py",
"src/mai/model/resource_detector.py",
"src/mai/model/compression.py",
"src/mai/core/config.py",
"src/mai/core/exceptions.py",
"src/mai/git/workflow.py",
"src/mai/git/committer.py",
"src/mai/git/health_check.py",
]
project_root = os.path.dirname(os.path.dirname(__file__))
for file_path in required_files:
full_path = os.path.join(project_root, file_path)
self.assertTrue(os.path.exists(full_path), f"Required file {file_path} does not exist")
def test_minimum_file_sizes(self):
"""Test that files meet minimum size requirements."""
min_lines = 40 # From plan requirements
test_file = os.path.join(os.path.dirname(__file__), "test_integration.py")
with open(test_file, "r") as f:
lines = f.readlines()
self.assertGreaterEqual(
len(lines), min_lines, f"Integration test file must have at least {min_lines} lines"
)
class TestPhase1Requirements(unittest.TestCase):
"""Test that Phase 1 requirements are satisfied."""
def test_requirement_1_model_discovery(self):
"""Requirement 1: Model discovery and capability detection."""
try:
from mai.core.interface import MaiInterface
# Test interface has list_models method
interface = MaiInterface()
self.assertTrue(hasattr(interface, "list_models"))
except ImportError:
self.skipTest("MaiInterface not available")
def test_requirement_2_resource_monitoring(self):
"""Requirement 2: Resource monitoring and constraint detection."""
try:
from mai.model.resource_detector import ResourceDetector
detector = ResourceDetector()
self.assertTrue(hasattr(detector, "detect_resources"))
except ImportError:
self.skipTest("ResourceDetector not available")
def test_requirement_3_model_selection(self):
"""Requirement 3: Intelligent model selection."""
try:
from mai.core.interface import MaiInterface
interface = MaiInterface()
# Should have model selection capability
self.assertIsNotNone(interface)
except ImportError:
self.skipTest("MaiInterface not available")
def test_requirement_4_context_compression(self):
"""Requirement 4: Context compression for model switching."""
try:
from mai.model.compression import ContextCompressor
compressor = ContextCompressor()
self.assertTrue(hasattr(compressor, "count_tokens"))
except ImportError:
self.skipTest("ContextCompressor not available")
def test_requirement_5_git_integration(self):
"""Requirement 5: Git workflow automation."""
# Check if GitPython is available
try:
import git
except ImportError:
self.skipTest("GitPython not available - git integration tests skipped")
git_components = [
("mai.git.workflow", "StagingWorkflow"),
("mai.git.committer", "AutoCommitter"),
("mai.git.health_check", "HealthChecker"),
]
available_count = 0
for module_name, class_name in git_components:
try:
module = __import__(module_name, fromlist=[class_name])
cls = getattr(module, class_name)
available_count += 1
except (ImportError, AttributeError):
pass
# At least one git component should be available if GitPython is installed
# If GitPython is installed but no components are available, that's a problem
if available_count == 0:
# Check if the source files actually exist
import os
from pathlib import Path
src_path = Path(__file__).parent.parent / "src" / "mai" / "git"
if src_path.exists():
git_files = list(src_path.glob("*.py"))
if git_files:
self.fail(
f"Git files exist but no git components importable. Files: {[f.name for f in git_files]}"
)
return
# If we get here, either components are available or they don't exist yet
# Both are acceptable states for Phase 1 validation
self.assertTrue(True, "Git integration validation completed")
class TestErrorHandlingGracefulDegradation(unittest.TestCase):
"""Test error handling and graceful degradation."""
def test_missing_dependency_handling(self):
"""Test handling of missing dependencies."""
# Mock missing ollama dependency
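# (a None entry in sys.modules makes any subsequent "import ollama" raise ImportError)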
with patch.dict("sys.modules", {"ollama": None}):
try:
from mai.model.ollama_client import OllamaClient
# If import succeeds, test that it handles missing dependency
client = OllamaClient()
self.assertIsNotNone(client)
except ImportError:
# Expected behavior - import should fail gracefully
pass
def test_resource_exhaustion_simulation(self):
"""Test behavior with simulated resource exhaustion."""
try:
from mai.model.resource_detector import ResourceInfo
# Create exhausted resource scenario with correct attributes
exhausted = ResourceInfo(
cpu_percent=95.0,
memory_total_gb=16.0,
memory_available_gb=0.1, # Very low (100MB)
memory_percent=99.4, # Almost all memory used
gpu_available=False,
)
# ResourceInfo should handle extreme values
self.assertEqual(exhausted.cpu_percent, 95.0)
self.assertEqual(exhausted.memory_available_gb, 0.1)
self.assertEqual(exhausted.memory_percent, 99.4)
except ImportError:
self.skipTest("ResourceInfo not available")
class TestPerformanceRegression(unittest.TestCase):
"""Test performance regression detection."""
def test_import_time_performance(self):
"""Test that import time is reasonable."""
import_time_start = time.time()
# Try to import main components
try:
from mai.core.config import Config
from mai.core.exceptions import MaiError
config = Config()
except ImportError:
pass
import_time = time.time() - import_time_start
# Imports should complete within reasonable time (< 5 seconds)
self.assertLess(import_time, 5.0, "Import time should be reasonable")
def test_instantiation_performance(self):
"""Test that component instantiation is performant."""
times = []
# Test multiple instantiations
for _ in range(5):
start_time = time.time()
try:
from mai.core.config import Config
config = Config()
except ImportError:
pass
times.append(time.time() - start_time)
avg_time = sum(times) / len(times)
# Average instantiation should be fast (< 1 second)
self.assertLess(avg_time, 1.0, "Component instantiation should be fast")
def run_phase1_validation():
"""Run comprehensive Phase 1 validation."""
print("\n" + "=" * 60)
print("PHASE 1 INTEGRATION TEST VALIDATION")
print("=" * 60)
# Run import checks
import_results = check_imports()
print("\n1. COMPONENT IMPORT VALIDATION:")
for component, status in import_results.items():
status_symbol = "" if status == "OK" else ""
print(f" {status_symbol} {component}: {status}")
# Count successful imports
successful = sum(1 for s in import_results.values() if s == "OK")
total = len(import_results)
print(f"\n Import Success Rate: {successful}/{total} ({successful / total * 100:.1f}%)")
# Run unit tests
print("\n2. FUNCTIONAL TESTS:")
loader = unittest.TestLoader()
suite = loader.loadTestsFromModule(sys.modules[__name__])
runner = unittest.TextTestRunner(verbosity=1)
result = runner.run(suite)
# Summary
print("\n" + "=" * 60)
print("PHASE 1 VALIDATION SUMMARY")
print("=" * 60)
print(f"Tests run: {result.testsRun}")
print(f"Failures: {len(result.failures)}")
print(f"Errors: {len(result.errors)}")
print(f"Skipped: {len(result.skipped)}")
success_rate = (
(result.testsRun - len(result.failures) - len(result.errors)) / result.testsRun * 100
)
print(f"Success Rate: {success_rate:.1f}%")
if success_rate >= 80:
print("✓ PHASE 1 VALIDATION: PASSED")
else:
print("✗ PHASE 1 VALIDATION: FAILED")
return result.wasSuccessful()
if __name__ == "__main__":
# Run Phase 1 validation
success = run_phase1_validation()
sys.exit(0 if success else 1)

View File

@@ -1,351 +0,0 @@
"""
Comprehensive test suite for Mai Memory System
Tests all memory components including storage, compression, retrieval, and CLI integration.
"""
import pytest
import tempfile
import shutil
import os
import sys
import time
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
# Import CLI interface - this should work
from mai.core.interface import show_memory_status, search_memory, manage_memory
# Try to import memory components - they might not work due to dependencies
try:
from mai.memory.storage import MemoryStorage, MemoryStorageError
from mai.memory.compression import MemoryCompressor, CompressionResult
from mai.memory.retrieval import ContextRetriever, SearchQuery, MemoryContext
from mai.memory.manager import MemoryManager, MemoryStats
from mai.models.conversation import Conversation, Message
from mai.models.memory import MemoryContext as ModelMemoryContext
MEMORY_COMPONENTS_AVAILABLE = True
except ImportError as e:
print(f"Memory components not available: {e}")
MEMORY_COMPONENTS_AVAILABLE = False
class TestCLIInterface:
"""Test CLI interface functions - these should always work."""
def test_show_memory_status(self):
"""Test show_memory_status CLI function."""
result = show_memory_status()
assert result is not None
assert isinstance(result, dict)
# Should contain memory status information
if "memory_enabled" in result:
assert isinstance(result["memory_enabled"], bool)
if "error" in result:
# Memory system might not be initialized, that's okay for test
assert isinstance(result["error"], str)
def test_search_memory(self):
"""Test search_memory CLI function."""
result = search_memory("test query")
assert result is not None
assert isinstance(result, dict)
if "success" in result:
assert isinstance(result["success"], bool)
if "results" in result:
assert isinstance(result["results"], list)
if "error" in result:
# Memory system might not be initialized, that's okay for test
assert isinstance(result["error"], str)
def test_manage_memory(self):
"""Test manage_memory CLI function."""
# Test stats action (should work even without memory system)
result = manage_memory("stats")
assert result is not None
assert isinstance(result, dict)
assert result.get("action") == "stats"
if "success" in result:
assert isinstance(result["success"], bool)
if "error" in result:
# Memory system might not be initialized, that's okay for test
assert isinstance(result["error"], str)
def test_manage_memory_unknown_action(self):
"""Test manage_memory with unknown action."""
result = manage_memory("unknown_action")
assert result is not None
assert isinstance(result, dict)
assert result.get("success") is False
# Check if error mentions unknown action or memory system not available
error_msg = result.get("error", "").lower()
assert "unknown" in error_msg or "memory system not available" in error_msg
@pytest.mark.skipif(not MEMORY_COMPONENTS_AVAILABLE, reason="Memory components not available")
class TestMemoryStorage:
"""Test memory storage functionality."""
@pytest.fixture
def temp_db(self):
"""Create temporary database for testing."""
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_memory.db")
yield db_path
shutil.rmtree(temp_dir, ignore_errors=True)
def test_storage_initialization(self, temp_db):
"""Test that storage initializes correctly."""
try:
storage = MemoryStorage(database_path=temp_db)
assert storage is not None
except Exception as e:
# Storage might fail due to missing dependencies
pytest.skip(f"Storage initialization failed: {e}")
def test_conversation_storage(self, temp_db):
"""Test storing and retrieving conversations."""
try:
storage = MemoryStorage(database_path=temp_db)
# Create test conversation with minimal required fields
conversation = Conversation(
title="Test Conversation",
messages=[
Message(role="user", content="Hello", timestamp=datetime.now()),
Message(role="assistant", content="Hi there!", timestamp=datetime.now()),
],
created_at=datetime.now(),
updated_at=datetime.now(),
)
# Store conversation
conv_id = storage.store_conversation(conversation)
assert conv_id is not None
except Exception as e:
pytest.skip(f"Conversation storage test failed: {e}")
def test_conversation_search(self, temp_db):
"""Test searching conversations."""
try:
storage = MemoryStorage(database_path=temp_db)
# Store test conversations
conv1 = Conversation(
title="Python Programming",
messages=[
Message(role="user", content="How to use Python?", timestamp=datetime.now())
],
created_at=datetime.now(),
updated_at=datetime.now(),
)
conv2 = Conversation(
title="Machine Learning",
messages=[Message(role="user", content="What is ML?", timestamp=datetime.now())],
created_at=datetime.now(),
updated_at=datetime.now(),
)
storage.store_conversation(conv1)
storage.store_conversation(conv2)
# Search for Python
results = storage.search_conversations("Python", limit=10)
assert isinstance(results, list)
except Exception as e:
pytest.skip(f"Conversation search test failed: {e}")
@pytest.mark.skipif(not MEMORY_COMPONENTS_AVAILABLE, reason="Memory components not available")
class TestMemoryCompression:
"""Test memory compression functionality."""
@pytest.fixture
def compressor(self):
"""Create compressor instance."""
try:
return MemoryCompressor()
except Exception as e:
pytest.skip(f"Compressor initialization failed: {e}")
def test_conversation_compression(self, compressor):
"""Test conversation compression."""
try:
# Create test conversation
conversation = Conversation(
title="Long Conversation",
messages=[
Message(role="user", content=f"Message {i}", timestamp=datetime.now())
for i in range(10) # Smaller for testing
],
created_at=datetime.now(),
updated_at=datetime.now(),
)
# Compress
result = compressor.compress_conversation(conversation)
assert result is not None
except Exception as e:
pytest.skip(f"Conversation compression test failed: {e}")
@pytest.mark.skipif(not MEMORY_COMPONENTS_AVAILABLE, reason="Memory components not available")
class TestMemoryManager:
"""Test memory manager orchestration."""
@pytest.fixture
def temp_manager(self):
"""Create memory manager with temporary storage."""
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_manager.db")
try:
# Mock the storage path
with patch("mai.memory.manager.MemoryStorage") as mock_storage:
mock_storage.return_value = MemoryStorage(database_path=db_path)
manager = MemoryManager()
yield manager
except Exception as e:
# If manager fails, create a mock
mock_manager = Mock(spec=MemoryManager)
mock_manager.get_memory_stats.return_value = MemoryStats()
mock_manager.store_conversation.return_value = "test-conv-id"
mock_manager.get_context.return_value = ModelMemoryContext(
relevant_conversations=[], total_conversations=0, estimated_tokens=0, metadata={}
)
mock_manager.search_conversations.return_value = []
yield mock_manager
shutil.rmtree(temp_dir, ignore_errors=True)
def test_conversation_storage(self, temp_manager):
"""Test conversation storage through manager."""
try:
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
conv_id = temp_manager.store_conversation(messages=messages, metadata={"test": True})
assert conv_id is not None
assert isinstance(conv_id, str)
except Exception as e:
pytest.skip(f"Manager conversation storage test failed: {e}")
def test_memory_stats(self, temp_manager):
"""Test memory statistics through manager."""
try:
stats = temp_manager.get_memory_stats()
assert stats is not None
assert isinstance(stats, MemoryStats)
except Exception as e:
pytest.skip(f"Manager memory stats test failed: {e}")
@pytest.mark.skipif(not MEMORY_COMPONENTS_AVAILABLE, reason="Memory components not available")
class TestContextRetrieval:
"""Test context retrieval functionality."""
@pytest.fixture
def retriever(self):
"""Create retriever instance."""
try:
return ContextRetriever()
except Exception as e:
pytest.skip(f"Retriever initialization failed: {e}")
def test_context_retrieval(self, retriever):
"""Test context retrieval for query."""
try:
query = SearchQuery(text="Python programming", max_results=5)
context = retriever.get_context(query)
assert context is not None
assert isinstance(context, ModelMemoryContext)
except Exception as e:
pytest.skip(f"Context retrieval test failed: {e}")
class TestIntegration:
"""Integration tests for memory system."""
def test_end_to_end_workflow(self):
"""Test complete workflow: store -> search -> compress."""
# This is a smoke test to verify the basic workflow doesn't crash
# Individual components are tested in their respective test classes
# Test CLI functions don't crash
status = show_memory_status()
assert isinstance(status, dict)
search_result = search_memory("test")
assert isinstance(search_result, dict)
manage_result = manage_memory("stats")
assert isinstance(manage_result, dict)
# Performance and stress tests
class TestPerformance:
"""Performance tests for memory system."""
def test_search_performance(self):
"""Test search performance with larger datasets."""
try:
# This would require setting up a larger test dataset
# For now, just verify the function exists and returns reasonable timing
start_time = time.time()
result = search_memory("performance test")
end_time = time.time()
search_time = end_time - start_time
assert search_time < 5.0 # Should complete within 5 seconds
assert isinstance(result, dict)
except ImportError:
pytest.skip("Memory system dependencies not available")
def test_memory_stats_performance(self):
"""Test memory stats calculation performance."""
try:
start_time = time.time()
result = show_memory_status()
end_time = time.time()
stats_time = end_time - start_time
assert stats_time < 2.0 # Should complete within 2 seconds
assert isinstance(result, dict)
except ImportError:
pytest.skip("Memory system dependencies not available")
if __name__ == "__main__":
# Run tests if script is executed directly
pytest.main([__file__, "-v"])

View File

@@ -1,409 +0,0 @@
"""
Test suite for ApprovalSystem
This module provides comprehensive testing for the risk-based approval system
including user interaction, trust management, and edge cases.
"""
import pytest
import time
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime
from typing import Dict, Any
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from mai.sandbox.approval_system import (
ApprovalSystem,
RiskLevel,
ApprovalResult,
RiskAnalysis,
ApprovalRequest,
ApprovalDecision,
)
class TestApprovalSystem:
"""Test cases for ApprovalSystem."""
@pytest.fixture
def approval_system(self):
"""Create fresh ApprovalSystem for each test."""
with patch("mai.sandbox.approval_system.get_config") as mock_config:
mock_config.return_value = Mock()
mock_config.return_value.get.return_value = {
"low_threshold": 0.3,
"medium_threshold": 0.6,
"high_threshold": 0.8,
}
return ApprovalSystem()
@pytest.fixture
def mock_low_risk_code(self):
"""Sample low-risk code."""
return 'print("hello world")'
@pytest.fixture
def mock_medium_risk_code(self):
"""Sample medium-risk code."""
return "import os\nprint(os.getcwd())"
@pytest.fixture
def mock_high_risk_code(self):
"""Sample high-risk code."""
return 'import subprocess\nsubprocess.call(["ls", "-la"])'
@pytest.fixture
def mock_blocked_code(self):
"""Sample blocked code."""
return 'os.system("rm -rf /")'
def test_initialization(self, approval_system):
"""Test ApprovalSystem initialization."""
assert approval_system.approval_history == []
assert approval_system.user_preferences == {}
assert approval_system.trust_patterns == {}
assert approval_system.risk_thresholds["low_threshold"] == 0.3
def test_risk_analysis_low_risk(self, approval_system, mock_low_risk_code):
"""Test risk analysis for low-risk code."""
context = {}
risk_analysis = approval_system._analyze_code_risk(mock_low_risk_code, context)
assert risk_analysis.risk_level == RiskLevel.LOW
assert risk_analysis.severity_score < 0.3
assert len(risk_analysis.reasons) == 0
assert risk_analysis.confidence > 0.5
def test_risk_analysis_medium_risk(self, approval_system, mock_medium_risk_code):
"""Test risk analysis for medium-risk code."""
context = {}
risk_analysis = approval_system._analyze_code_risk(mock_medium_risk_code, context)
assert risk_analysis.risk_level == RiskLevel.MEDIUM
assert risk_analysis.severity_score >= 0.3
assert len(risk_analysis.reasons) > 0
assert "file_system" in risk_analysis.affected_resources
def test_risk_analysis_high_risk(self, approval_system, mock_high_risk_code):
"""Test risk analysis for high-risk code."""
context = {}
risk_analysis = approval_system._analyze_code_risk(mock_high_risk_code, context)
assert risk_analysis.risk_level == RiskLevel.HIGH
assert risk_analysis.severity_score >= 0.6
assert len(risk_analysis.reasons) > 0
assert "system_operations" in risk_analysis.affected_resources
def test_risk_analysis_blocked(self, approval_system, mock_blocked_code):
"""Test risk analysis for blocked code."""
context = {}
risk_analysis = approval_system._analyze_code_risk(mock_blocked_code, context)
assert risk_analysis.risk_level == RiskLevel.BLOCKED
assert any("blocked operation" in reason.lower() for reason in risk_analysis.reasons)
def test_operation_type_detection(self, approval_system):
"""Test operation type detection."""
assert approval_system._get_operation_type('print("hello")') == "output_operation"
assert approval_system._get_operation_type("import os") == "module_import"
assert approval_system._get_operation_type('os.system("ls")') == "system_command"
assert approval_system._get_operation_type('open("file.txt")') == "file_operation"
assert approval_system._get_operation_type("x = 5") == "code_execution"
def test_request_id_generation(self, approval_system):
"""Test unique request ID generation."""
code1 = 'print("test")'
code2 = 'print("test")'
id1 = approval_system._generate_request_id(code1)
time.sleep(0.01) # Small delay to ensure different timestamps
id2 = approval_system._generate_request_id(code2)
assert id1 != id2 # Should be different due to timestamp
assert len(id1) == 12 # MD5 hash truncated to 12 chars
assert len(id2) == 12
@patch("builtins.input")
def test_low_risk_approval_allow(self, mock_input, approval_system, mock_low_risk_code):
"""Test low-risk approval with user allowing."""
mock_input.return_value = "y"
result, decision = approval_system.request_approval(mock_low_risk_code)
assert result == ApprovalResult.APPROVED
assert decision.user_input == "allowed"
assert decision.request.risk_analysis.risk_level == RiskLevel.LOW
@patch("builtins.input")
def test_low_risk_approval_deny(self, mock_input, approval_system, mock_low_risk_code):
"""Test low-risk approval with user denying."""
mock_input.return_value = "n"
result, decision = approval_system.request_approval(mock_low_risk_code)
assert result == ApprovalResult.DENIED
assert decision.user_input == "denied"
@patch("builtins.input")
def test_low_risk_approval_always(self, mock_input, approval_system, mock_low_risk_code):
"""Test low-risk approval with 'always allow' preference."""
mock_input.return_value = "a"
result, decision = approval_system.request_approval(mock_low_risk_code)
assert result == ApprovalResult.APPROVED
assert decision.user_input == "allowed_always"
assert decision.trust_updated == True
assert "output_operation" in approval_system.user_preferences
@patch("builtins.input")
def test_medium_risk_approval_details(self, mock_input, approval_system, mock_medium_risk_code):
"""Test medium-risk approval requesting details."""
mock_input.return_value = "d" # Request details first
with patch.object(approval_system, "_present_detailed_view") as mock_detailed:
mock_detailed.return_value = "allowed"
result, decision = approval_system.request_approval(mock_medium_risk_code)
assert result == ApprovalResult.APPROVED
mock_detailed.assert_called_once()
@patch("builtins.input")
def test_high_risk_approval_confirm(self, mock_input, approval_system, mock_high_risk_code):
"""Test high-risk approval with confirmation."""
mock_input.return_value = "confirm"
result, decision = approval_system.request_approval(mock_high_risk_code)
assert result == ApprovalResult.APPROVED
assert decision.request.risk_analysis.risk_level == RiskLevel.HIGH
@patch("builtins.input")
def test_high_risk_approval_cancel(self, mock_input, approval_system, mock_high_risk_code):
"""Test high-risk approval with cancellation."""
mock_input.return_value = "cancel"
result, decision = approval_system.request_approval(mock_high_risk_code)
assert result == ApprovalResult.DENIED
@patch("builtins.print")
def test_blocked_operation(self, mock_print, approval_system, mock_blocked_code):
"""Test blocked operation handling."""
result, decision = approval_system.request_approval(mock_blocked_code)
assert result == ApprovalResult.BLOCKED
assert decision.request.risk_analysis.risk_level == RiskLevel.BLOCKED
def test_auto_approval_for_trusted_operation(self, approval_system, mock_low_risk_code):
"""Test auto-approval for trusted operations."""
# Set up user preference
approval_system.user_preferences["output_operation"] = "auto_allow"
result, decision = approval_system.request_approval(mock_low_risk_code)
assert result == ApprovalResult.ALLOWED
assert decision.user_input == "auto_allowed"
def test_approval_history(self, approval_system, mock_low_risk_code):
"""Test approval history tracking."""
# Add some decisions
with patch("builtins.input", return_value="y"):
approval_system.request_approval(mock_low_risk_code)
approval_system.request_approval(mock_low_risk_code)
history = approval_system.get_approval_history(5)
assert len(history) == 2
assert all(isinstance(decision, ApprovalDecision) for decision in history)
def test_trust_patterns_learning(self, approval_system, mock_low_risk_code):
"""Test trust pattern learning."""
# Add approved decisions
with patch("builtins.input", return_value="y"):
for _ in range(3):
approval_system.request_approval(mock_low_risk_code)
patterns = approval_system.get_trust_patterns()
assert "output_operation" in patterns
assert patterns["output_operation"] == 3
def test_preferences_reset(self, approval_system):
"""Test preferences reset."""
# Add some preferences
approval_system.user_preferences = {"test": "value"}
approval_system.reset_preferences()
assert approval_system.user_preferences == {}
def test_is_code_safe(self, approval_system, mock_low_risk_code, mock_high_risk_code):
"""Test quick safety check."""
assert approval_system.is_code_safe(mock_low_risk_code) == True
assert approval_system.is_code_safe(mock_high_risk_code) == False
def test_context_awareness(self, approval_system, mock_low_risk_code):
"""Test context-aware risk analysis."""
# New user context should increase risk
context_new_user = {"user_level": "new"}
risk_new = approval_system._analyze_code_risk(mock_low_risk_code, context_new_user)
context_known_user = {"user_level": "known"}
risk_known = approval_system._analyze_code_risk(mock_low_risk_code, context_known_user)
assert risk_new.severity_score > risk_known.severity_score
assert "New user profile" in risk_new.reasons
def test_request_id_uniqueness(self, approval_system):
"""Test that request IDs are unique even for same code."""
code = 'print("test")'
ids = []
for _ in range(10):
rid = approval_system._generate_request_id(code)
assert rid not in ids, f"Duplicate ID: {rid}"
ids.append(rid)
def test_risk_score_accumulation(self, approval_system):
"""Test that multiple risk factors accumulate."""
# Code with multiple risk factors
risky_code = """
import os
import subprocess
os.system("ls")
subprocess.call(["pwd"])
"""
risk_analysis = approval_system._analyze_code_risk(risky_code, {})
assert risk_analysis.severity_score > 0.5
assert len(risk_analysis.reasons) >= 2
assert "system_operations" in risk_analysis.affected_resources
@patch("builtins.input")
def test_detailed_view_presentation(self, mock_input, approval_system, mock_medium_risk_code):
"""Test detailed view presentation."""
mock_input.return_value = "y"
# Create a request
risk_analysis = approval_system._analyze_code_risk(mock_medium_risk_code, {})
request = ApprovalRequest(
code=mock_medium_risk_code,
risk_analysis=risk_analysis,
context={"test": "value"},
timestamp=datetime.now(),
request_id="test123",
)
result = approval_system._present_detailed_view(request)
assert result == "allowed"
@patch("builtins.input")
def test_detailed_analysis_presentation(self, mock_input, approval_system, mock_high_risk_code):
"""Test detailed analysis presentation."""
mock_input.return_value = "confirm"
# Create a request
risk_analysis = approval_system._analyze_code_risk(mock_high_risk_code, {})
request = ApprovalRequest(
code=mock_high_risk_code,
risk_analysis=risk_analysis,
context={},
timestamp=datetime.now(),
request_id="test456",
)
result = approval_system._present_detailed_analysis(request)
assert result == "allowed"
def test_error_handling_in_risk_analysis(self, approval_system):
"""Test error handling in risk analysis."""
# Test with None code (should not crash)
try:
risk_analysis = approval_system._analyze_code_risk(None, {})
# Should still return a valid RiskAnalysis object
assert isinstance(risk_analysis, RiskAnalysis)
except Exception:
# If it raises an exception, that's also acceptable behavior
pass
def test_preferences_persistence(self, approval_system):
"""Test preferences persistence simulation."""
# Simulate loading preferences with error
with patch.object(approval_system, "_load_preferences") as mock_load:
mock_load.side_effect = Exception("Load error")
# Should not crash during initialization
try:
approval_system._load_preferences()
except Exception:
pass # Expected
# Simulate saving preferences with error
with patch.object(approval_system, "_save_preferences") as mock_save:
mock_save.side_effect = Exception("Save error")
# Should not crash when saving
try:
approval_system._save_preferences()
except Exception:
pass # Expected
@pytest.mark.parametrize(
"code_pattern,expected_risk",
[
('print("hello")', RiskLevel.LOW),
("import os", RiskLevel.MEDIUM),
('os.system("ls")', RiskLevel.HIGH),
("rm -rf /", RiskLevel.BLOCKED),
('eval("x + 1")', RiskLevel.HIGH),
('exec("print(1)")', RiskLevel.HIGH),
('__import__("os")', RiskLevel.HIGH),
],
)
def test_risk_patterns(self, approval_system, code_pattern, expected_risk):
"""Test various code patterns for risk classification."""
risk_analysis = approval_system._analyze_code_risk(code_pattern, {})
# Allow some flexibility in risk assessment
if expected_risk == RiskLevel.HIGH:
assert risk_analysis.risk_level in [RiskLevel.HIGH, RiskLevel.BLOCKED]
else:
assert risk_analysis.risk_level == expected_risk
def test_approval_decision_dataclass(self):
"""Test ApprovalDecision dataclass."""
now = datetime.now()
request = ApprovalRequest(
code='print("test")',
risk_analysis=RiskAnalysis(
risk_level=RiskLevel.LOW,
confidence=0.8,
reasons=[],
affected_resources=[],
severity_score=0.1,
),
context={},
timestamp=now,
request_id="test123",
)
decision = ApprovalDecision(
request=request,
result=ApprovalResult.APPROVED,
user_input="y",
timestamp=now,
trust_updated=False,
)
assert decision.request == request
assert decision.result == ApprovalResult.APPROVED
assert decision.user_input == "y"
assert decision.timestamp == now
assert decision.trust_updated == False
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,403 +0,0 @@
"""
Tests for SandboxManager with Docker integration
Test suite for enhanced SandboxManager that includes Docker-based
container execution with fallback to local execution.
"""
import pytest
from unittest.mock import Mock, patch, call
from src.mai.sandbox.manager import SandboxManager, ExecutionRequest, ExecutionResult
from src.mai.sandbox.risk_analyzer import RiskAssessment, RiskPattern
from src.mai.sandbox.resource_enforcer import ResourceUsage, ResourceLimits
from src.mai.sandbox.docker_executor import ContainerResult, ContainerConfig
class TestSandboxManagerDockerIntegration:
"""Test SandboxManager Docker integration features"""
@pytest.fixture
def sandbox_manager(self):
"""Create SandboxManager instance for testing"""
return SandboxManager()
@pytest.fixture
def mock_docker_executor(self):
"""Create mock Docker executor"""
mock_executor = Mock()
mock_executor.is_available.return_value = True
mock_executor.execute_code.return_value = ContainerResult(
success=True,
container_id="test-container-id",
exit_code=0,
stdout="Hello from Docker!",
stderr="",
execution_time=1.2,
resource_usage={"cpu_percent": 45.0, "memory_usage_mb": 32.0},
)
mock_executor.get_system_info.return_value = {
"available": True,
"version": "20.10.7",
"containers": 3,
}
return mock_executor
def test_execution_request_with_docker_options(self):
"""Test ExecutionRequest with Docker-specific options"""
request = ExecutionRequest(
code="print('test')",
use_docker=True,
docker_image="python:3.9-alpine",
timeout_seconds=45,
network_allowed=True,
additional_files={"data.txt": "test content"},
)
assert request.use_docker is True
assert request.docker_image == "python:3.9-alpine"
assert request.timeout_seconds == 45
assert request.network_allowed is True
assert request.additional_files == {"data.txt": "test content"}
def test_execution_result_with_docker_info(self):
"""Test ExecutionResult includes Docker execution info"""
container_result = ContainerResult(
success=True,
container_id="test-id",
exit_code=0,
stdout="Docker output",
execution_time=1.5,
)
result = ExecutionResult(
success=True,
execution_id="test-exec",
output="Docker output",
execution_method="docker",
container_result=container_result,
)
assert result.execution_method == "docker"
assert result.container_result == container_result
assert result.container_result.container_id == "test-id"
def test_execute_code_with_docker_available(self, sandbox_manager):
"""Test code execution when Docker is available"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(sandbox_manager.risk_analyzer, "analyze_ast") as mock_risk:
with patch.object(sandbox_manager.docker_executor, "execute_code") as mock_docker:
with patch.object(sandbox_manager.audit_logger, "log_execution") as mock_log:
# Mock risk analysis (allow execution)
mock_risk.return_value = RiskAssessment(
score=20, patterns=[], safe_to_execute=True, approval_required=False
)
# Mock Docker execution
mock_docker.return_value = {
"success": True,
"output": "Hello from Docker!",
"container_result": ContainerResult(
success=True,
container_id="test-container",
exit_code=0,
stdout="Hello from Docker!",
),
}
# Execute request with Docker
request = ExecutionRequest(
code="print('Hello from Docker!')", use_docker=True
)
result = sandbox_manager.execute_code(request)
# Verify Docker was used
assert result.execution_method == "docker"
assert result.success is True
assert result.output == "Hello from Docker!"
assert result.container_result is not None
# Verify Docker executor was called
mock_docker.assert_called_once()
def test_execute_code_fallback_to_local(self, sandbox_manager):
"""Test fallback to local execution when Docker unavailable"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=False):
with patch.object(sandbox_manager.risk_analyzer, "analyze_ast") as mock_risk:
with patch.object(sandbox_manager, "_execute_in_sandbox") as mock_local:
with patch.object(
sandbox_manager.resource_enforcer, "stop_monitoring"
) as mock_monitoring:
# Mock risk analysis (allow execution)
mock_risk.return_value = RiskAssessment(
score=20, patterns=[], safe_to_execute=True, approval_required=False
)
# Mock local execution
mock_local.return_value = {"success": True, "output": "Hello from local!"}
# Mock resource monitoring
mock_monitoring.return_value = ResourceUsage(
cpu_percent=25.0,
memory_percent=30.0,
memory_used_gb=0.5,
elapsed_seconds=1.0,
approaching_limits=False,
)
# Execute request preferring Docker
request = ExecutionRequest(
code="print('Hello')",
use_docker=True, # But Docker is unavailable
)
result = sandbox_manager.execute_code(request)
# Verify fallback to local execution
assert result.execution_method == "local"
assert result.success is True
assert result.output == "Hello from local!"
assert result.container_result is None
# Verify local execution was used
mock_local.assert_called_once()
def test_execute_code_local_preference(self, sandbox_manager):
"""Test explicit preference for local execution"""
with patch.object(sandbox_manager.risk_analyzer, "analyze_ast") as mock_risk:
with patch.object(sandbox_manager, "_execute_in_sandbox") as mock_local:
# Mock risk analysis (allow execution)
mock_risk.return_value = RiskAssessment(
score=20, patterns=[], safe_to_execute=True, approval_required=False
)
# Mock local execution
mock_local.return_value = {"success": True, "output": "Local execution"}
# Execute request explicitly preferring local
request = ExecutionRequest(
code="print('Local')",
use_docker=False, # Explicitly prefer local
)
result = sandbox_manager.execute_code(request)
# Verify local execution was used
assert result.execution_method == "local"
assert result.success is True
# Docker executor should not be called
sandbox_manager.docker_executor.execute_code.assert_not_called()
def test_build_docker_config_from_request(self, sandbox_manager):
"""Test building Docker config from execution request"""
from src.mai.sandbox.docker_executor import ContainerConfig
# Use the actual method from DockerExecutor
config = sandbox_manager.docker_executor._build_container_config(
ContainerConfig(
memory_limit="256m", cpu_limit="0.8", network_disabled=False, timeout_seconds=60
),
{"TEST_VAR": "value"},
)
assert config["mem_limit"] == "256m"
assert config["cpu_quota"] == 80000
assert config["network_disabled"] is False
assert config["security_opt"] is not None
assert "TEST_VAR" in config["environment"]
def test_get_docker_status(self, sandbox_manager, mock_docker_executor):
"""Test getting Docker status information"""
sandbox_manager.docker_executor = mock_docker_executor
status = sandbox_manager.get_docker_status()
assert "available" in status
assert "images" in status
assert "system_info" in status
assert status["available"] is True
assert status["system_info"]["available"] is True
def test_pull_docker_image(self, sandbox_manager, mock_docker_executor):
"""Test pulling Docker image"""
sandbox_manager.docker_executor = mock_docker_executor
mock_docker_executor.pull_image.return_value = True
result = sandbox_manager.pull_docker_image("python:3.9-slim")
assert result is True
mock_docker_executor.pull_image.assert_called_once_with("python:3.9-slim")
def test_cleanup_docker_containers(self, sandbox_manager, mock_docker_executor):
"""Test cleaning up Docker containers"""
sandbox_manager.docker_executor = mock_docker_executor
mock_docker_executor.cleanup_containers.return_value = 3
result = sandbox_manager.cleanup_docker_containers()
assert result == 3
mock_docker_executor.cleanup_containers.assert_called_once()
def test_get_system_status_includes_docker(self, sandbox_manager, mock_docker_executor):
"""Test system status includes Docker information"""
sandbox_manager.docker_executor = mock_docker_executor
with patch.object(sandbox_manager, "verify_log_integrity", return_value=True):
status = sandbox_manager.get_system_status()
assert "docker_available" in status
assert "docker_info" in status
assert status["docker_available"] is True
assert status["docker_info"]["available"] is True
def test_execute_code_with_additional_files(self, sandbox_manager):
"""Test code execution with additional files in Docker"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(sandbox_manager.risk_analyzer, "analyze_ast") as mock_risk:
with patch.object(sandbox_manager.docker_executor, "execute_code") as mock_docker:
# Mock risk analysis (allow execution)
mock_risk.return_value = RiskAssessment(
score=20, patterns=[], safe_to_execute=True, approval_required=False
)
# Mock Docker execution
mock_docker.return_value = {
"success": True,
"output": "Processed files",
"container_result": ContainerResult(
success=True,
container_id="test-container",
exit_code=0,
stdout="Processed files",
),
}
# Execute request with additional files
request = ExecutionRequest(
code="with open('data.txt', 'r') as f: print(f.read())",
use_docker=True,
additional_files={"data.txt": "test data content"},
)
result = sandbox_manager.execute_code(request)
# Verify Docker executor was called with files
mock_docker.assert_called_once()
call_args = mock_docker.call_args
assert "files" in call_args.kwargs
assert call_args.kwargs["files"] == {"data.txt": "test data content"}
assert result.success is True
assert result.execution_method == "docker"
def test_risk_analysis_blocks_docker_execution(self, sandbox_manager):
"""Test that high-risk code is blocked even with Docker"""
with patch.object(sandbox_manager.risk_analyzer, "analyze_ast") as mock_risk:
# Mock high-risk analysis (block execution)
mock_risk.return_value = RiskAssessment(
score=85,
patterns=[
RiskPattern(
pattern="os.system",
severity="BLOCKED",
score=50,
line_number=1,
description="System command execution",
)
],
safe_to_execute=False,
approval_required=True,
)
# Execute risky code with Docker preference
request = ExecutionRequest(code="os.system('rm -rf /')", use_docker=True)
result = sandbox_manager.execute_code(request)
# Verify execution was blocked
assert result.success is False
assert "blocked" in result.error.lower()
assert result.risk_assessment.score == 85
assert result.execution_method == "local" # Default before Docker check
# Docker should not be called for blocked code
sandbox_manager.docker_executor.execute_code.assert_not_called()
class TestSandboxManagerDockerEdgeCases:
"""Test edge cases and error handling in Docker integration"""
@pytest.fixture
def sandbox_manager(self):
"""Create SandboxManager instance for testing"""
return SandboxManager()
def test_docker_executor_error_handling(self, sandbox_manager):
"""Test handling of Docker executor errors"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(sandbox_manager.risk_analyzer, "analyze_ast") as mock_risk:
with patch.object(sandbox_manager.docker_executor, "execute_code") as mock_docker:
# Mock risk analysis (allow execution)
mock_risk.return_value = RiskAssessment(
score=20, patterns=[], safe_to_execute=True, approval_required=False
)
# Mock Docker executor error
mock_docker.return_value = {
"success": False,
"error": "Docker daemon not available",
"container_result": None,
}
request = ExecutionRequest(code="print('test')", use_docker=True)
result = sandbox_manager.execute_code(request)
# Verify error handling
assert result.success is False
assert result.execution_method == "docker"
assert "Docker daemon not available" in result.error
def test_container_resource_usage_integration(self, sandbox_manager):
"""Test integration of container resource usage"""
with patch.object(sandbox_manager.docker_executor, "is_available", return_value=True):
with patch.object(sandbox_manager.risk_analyzer, "analyze_ast") as mock_risk:
with patch.object(sandbox_manager.docker_executor, "execute_code") as mock_docker:
# Mock risk analysis (allow execution)
mock_risk.return_value = RiskAssessment(
score=20, patterns=[], safe_to_execute=True, approval_required=False
)
# Mock Docker execution with resource usage
container_result = ContainerResult(
success=True,
container_id="test-container",
exit_code=0,
stdout="test output",
resource_usage={
"cpu_percent": 35.5,
"memory_usage_mb": 64.2,
"memory_percent": 12.5,
},
)
mock_docker.return_value = {
"success": True,
"output": "test output",
"container_result": container_result,
}
request = ExecutionRequest(code="print('test')", use_docker=True)
result = sandbox_manager.execute_code(request)
# Verify resource usage is preserved
assert result.container_result.resource_usage["cpu_percent"] == 35.5
assert result.container_result.resource_usage["memory_usage_mb"] == 64.2
assert result.container_result.resource_usage["memory_percent"] == 12.5
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,2 +0,0 @@
def test_smoke() -> None:
assert True