410 lines
16 KiB
Python
410 lines
16 KiB
Python
"""
|
|
Test suite for ApprovalSystem
|
|
|
|
This module provides comprehensive testing for the risk-based approval system
|
|
including user interaction, trust management, and edge cases.
|
|
"""
|
|
|
|
import pytest
|
|
import time
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
from datetime import datetime
|
|
from typing import Dict, Any
|
|
|
|
import sys
|
|
import os
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
|
|
|
from mai.sandbox.approval_system import (
|
|
ApprovalSystem,
|
|
RiskLevel,
|
|
ApprovalResult,
|
|
RiskAnalysis,
|
|
ApprovalRequest,
|
|
ApprovalDecision,
|
|
)
|
|
|
|
|
|
class TestApprovalSystem:
    """Test cases for ApprovalSystem.

    The tests exercise risk classification (``_analyze_code_risk``), the
    interactive approval flow (``request_approval`` with ``builtins.input``
    patched), preference/trust bookkeeping, and the request/decision
    dataclasses exported by ``mai.sandbox.approval_system``.
    """

    @pytest.fixture
    def approval_system(self):
        """Create a fresh ApprovalSystem for each test.

        ``get_config`` is patched only while the instance is constructed, so
        the thresholds below are the ones the instance is built with.
        """
        with patch("mai.sandbox.approval_system.get_config") as mock_config:
            mock_config.return_value = Mock()
            mock_config.return_value.get.return_value = {
                "low_threshold": 0.3,
                "medium_threshold": 0.6,
                "high_threshold": 0.8,
            }
            return ApprovalSystem()

    @pytest.fixture
    def mock_low_risk_code(self):
        """Sample low-risk code."""
        return 'print("hello world")'

    @pytest.fixture
    def mock_medium_risk_code(self):
        """Sample medium-risk code."""
        return "import os\nprint(os.getcwd())"

    @pytest.fixture
    def mock_high_risk_code(self):
        """Sample high-risk code."""
        return 'import subprocess\nsubprocess.call(["ls", "-la"])'

    @pytest.fixture
    def mock_blocked_code(self):
        """Sample blocked code."""
        return 'os.system("rm -rf /")'

    def test_initialization(self, approval_system):
        """A new instance starts with empty state and the patched thresholds."""
        assert approval_system.approval_history == []
        assert approval_system.user_preferences == {}
        assert approval_system.trust_patterns == {}
        assert approval_system.risk_thresholds["low_threshold"] == 0.3

    def test_risk_analysis_low_risk(self, approval_system, mock_low_risk_code):
        """Test risk analysis for low-risk code."""
        risk_analysis = approval_system._analyze_code_risk(mock_low_risk_code, {})

        assert risk_analysis.risk_level == RiskLevel.LOW
        assert risk_analysis.severity_score < 0.3
        assert len(risk_analysis.reasons) == 0
        assert risk_analysis.confidence > 0.5

    def test_risk_analysis_medium_risk(self, approval_system, mock_medium_risk_code):
        """Test risk analysis for medium-risk code."""
        risk_analysis = approval_system._analyze_code_risk(mock_medium_risk_code, {})

        assert risk_analysis.risk_level == RiskLevel.MEDIUM
        assert risk_analysis.severity_score >= 0.3
        assert len(risk_analysis.reasons) > 0
        assert "file_system" in risk_analysis.affected_resources

    def test_risk_analysis_high_risk(self, approval_system, mock_high_risk_code):
        """Test risk analysis for high-risk code."""
        risk_analysis = approval_system._analyze_code_risk(mock_high_risk_code, {})

        assert risk_analysis.risk_level == RiskLevel.HIGH
        assert risk_analysis.severity_score >= 0.6
        assert len(risk_analysis.reasons) > 0
        assert "system_operations" in risk_analysis.affected_resources

    def test_risk_analysis_blocked(self, approval_system, mock_blocked_code):
        """Test risk analysis for blocked code."""
        risk_analysis = approval_system._analyze_code_risk(mock_blocked_code, {})

        assert risk_analysis.risk_level == RiskLevel.BLOCKED
        assert any("blocked operation" in reason.lower() for reason in risk_analysis.reasons)

    def test_operation_type_detection(self, approval_system):
        """Each snippet maps to the expected operation-type label."""
        expectations = [
            ('print("hello")', "output_operation"),
            ("import os", "module_import"),
            ('os.system("ls")', "system_command"),
            ('open("file.txt")', "file_operation"),
            ("x = 5", "code_execution"),
        ]
        for snippet, expected_type in expectations:
            # The message identifies which snippet failed inside the loop.
            assert approval_system._get_operation_type(snippet) == expected_type, snippet

    def test_request_id_generation(self, approval_system):
        """Test unique request ID generation."""
        code1 = 'print("test")'
        code2 = 'print("test")'

        id1 = approval_system._generate_request_id(code1)
        time.sleep(0.01)  # Small delay to ensure different timestamps
        id2 = approval_system._generate_request_id(code2)

        assert id1 != id2  # Should be different due to timestamp
        assert len(id1) == 12  # MD5 hash truncated to 12 chars
        assert len(id2) == 12

    @patch("builtins.input")
    def test_low_risk_approval_allow(self, mock_input, approval_system, mock_low_risk_code):
        """Test low-risk approval with user allowing."""
        mock_input.return_value = "y"

        result, decision = approval_system.request_approval(mock_low_risk_code)

        assert result == ApprovalResult.APPROVED
        assert decision.user_input == "allowed"
        assert decision.request.risk_analysis.risk_level == RiskLevel.LOW

    @patch("builtins.input")
    def test_low_risk_approval_deny(self, mock_input, approval_system, mock_low_risk_code):
        """Test low-risk approval with user denying."""
        mock_input.return_value = "n"

        result, decision = approval_system.request_approval(mock_low_risk_code)

        assert result == ApprovalResult.DENIED
        assert decision.user_input == "denied"

    @patch("builtins.input")
    def test_low_risk_approval_always(self, mock_input, approval_system, mock_low_risk_code):
        """Test low-risk approval with 'always allow' preference."""
        mock_input.return_value = "a"

        result, decision = approval_system.request_approval(mock_low_risk_code)

        assert result == ApprovalResult.APPROVED
        assert decision.user_input == "allowed_always"
        assert decision.trust_updated
        assert "output_operation" in approval_system.user_preferences

    @patch("builtins.input")
    def test_medium_risk_approval_details(self, mock_input, approval_system, mock_medium_risk_code):
        """Test medium-risk approval requesting details."""
        mock_input.return_value = "d"  # Request details first

        with patch.object(approval_system, "_present_detailed_view") as mock_detailed:
            mock_detailed.return_value = "allowed"

            result, decision = approval_system.request_approval(mock_medium_risk_code)

            assert result == ApprovalResult.APPROVED
            mock_detailed.assert_called_once()

    @patch("builtins.input")
    def test_high_risk_approval_confirm(self, mock_input, approval_system, mock_high_risk_code):
        """Test high-risk approval with confirmation."""
        mock_input.return_value = "confirm"

        result, decision = approval_system.request_approval(mock_high_risk_code)

        assert result == ApprovalResult.APPROVED
        assert decision.request.risk_analysis.risk_level == RiskLevel.HIGH

    @patch("builtins.input")
    def test_high_risk_approval_cancel(self, mock_input, approval_system, mock_high_risk_code):
        """Test high-risk approval with cancellation."""
        mock_input.return_value = "cancel"

        result, decision = approval_system.request_approval(mock_high_risk_code)

        assert result == ApprovalResult.DENIED

    @patch("builtins.print")
    def test_blocked_operation(self, mock_print, approval_system, mock_blocked_code):
        """Test blocked operation handling (``print`` is patched to keep output quiet)."""
        result, decision = approval_system.request_approval(mock_blocked_code)

        assert result == ApprovalResult.BLOCKED
        assert decision.request.risk_analysis.risk_level == RiskLevel.BLOCKED

    def test_auto_approval_for_trusted_operation(self, approval_system, mock_low_risk_code):
        """Test auto-approval for trusted operations."""
        # Set up user preference
        approval_system.user_preferences["output_operation"] = "auto_allow"

        result, decision = approval_system.request_approval(mock_low_risk_code)

        # NOTE(review): every other test compares against ApprovalResult.APPROVED;
        # confirm that ALLOWED is a distinct member meant for auto-approval.
        assert result == ApprovalResult.ALLOWED
        assert decision.user_input == "auto_allowed"

    def test_approval_history(self, approval_system, mock_low_risk_code):
        """Test approval history tracking."""
        # Add some decisions
        with patch("builtins.input", return_value="y"):
            approval_system.request_approval(mock_low_risk_code)
            approval_system.request_approval(mock_low_risk_code)

        history = approval_system.get_approval_history(5)
        assert len(history) == 2
        assert all(isinstance(decision, ApprovalDecision) for decision in history)

    def test_trust_patterns_learning(self, approval_system, mock_low_risk_code):
        """Test trust pattern learning."""
        # Add approved decisions
        with patch("builtins.input", return_value="y"):
            for _ in range(3):
                approval_system.request_approval(mock_low_risk_code)

        patterns = approval_system.get_trust_patterns()
        assert "output_operation" in patterns
        assert patterns["output_operation"] == 3

    def test_preferences_reset(self, approval_system):
        """Test preferences reset."""
        # Add some preferences
        approval_system.user_preferences = {"test": "value"}
        approval_system.reset_preferences()

        assert approval_system.user_preferences == {}

    def test_is_code_safe(self, approval_system, mock_low_risk_code, mock_high_risk_code):
        """Test quick safety check."""
        assert approval_system.is_code_safe(mock_low_risk_code)
        assert not approval_system.is_code_safe(mock_high_risk_code)

    def test_context_awareness(self, approval_system, mock_low_risk_code):
        """Test context-aware risk analysis."""
        # New user context should increase risk
        risk_new = approval_system._analyze_code_risk(mock_low_risk_code, {"user_level": "new"})
        risk_known = approval_system._analyze_code_risk(
            mock_low_risk_code, {"user_level": "known"}
        )

        assert risk_new.severity_score > risk_known.severity_score
        assert "New user profile" in risk_new.reasons

    def test_request_id_uniqueness(self, approval_system):
        """Test that request IDs are unique even for same code."""
        code = 'print("test")'
        seen_ids = set()

        for _ in range(10):
            rid = approval_system._generate_request_id(code)
            assert rid not in seen_ids, f"Duplicate ID: {rid}"
            seen_ids.add(rid)

    def test_risk_score_accumulation(self, approval_system):
        """Test that multiple risk factors accumulate."""
        # Code with multiple risk factors. Built from explicit lines so the
        # snippet's content does not depend on this file's indentation.
        risky_code = (
            "\n"
            "import os\n"
            "import subprocess\n"
            'os.system("ls")\n'
            'subprocess.call(["pwd"])\n'
        )
        risk_analysis = approval_system._analyze_code_risk(risky_code, {})

        assert risk_analysis.severity_score > 0.5
        assert len(risk_analysis.reasons) >= 2
        assert "system_operations" in risk_analysis.affected_resources

    @patch("builtins.input")
    def test_detailed_view_presentation(self, mock_input, approval_system, mock_medium_risk_code):
        """Test detailed view presentation."""
        mock_input.return_value = "y"

        # Create a request
        risk_analysis = approval_system._analyze_code_risk(mock_medium_risk_code, {})
        request = ApprovalRequest(
            code=mock_medium_risk_code,
            risk_analysis=risk_analysis,
            context={"test": "value"},
            timestamp=datetime.now(),
            request_id="test123",
        )

        result = approval_system._present_detailed_view(request)
        assert result == "allowed"

    @patch("builtins.input")
    def test_detailed_analysis_presentation(self, mock_input, approval_system, mock_high_risk_code):
        """Test detailed analysis presentation."""
        mock_input.return_value = "confirm"

        # Create a request
        risk_analysis = approval_system._analyze_code_risk(mock_high_risk_code, {})
        request = ApprovalRequest(
            code=mock_high_risk_code,
            risk_analysis=risk_analysis,
            context={},
            timestamp=datetime.now(),
            request_id="test456",
        )

        result = approval_system._present_detailed_analysis(request)
        assert result == "allowed"

    def test_error_handling_in_risk_analysis(self, approval_system):
        """Analysing ``None`` must either raise or return a RiskAnalysis.

        Only the call itself is guarded: the type assertion sits outside the
        ``try`` so a wrong return type fails the test instead of being
        swallowed together with the tolerated exception.
        """
        try:
            risk_analysis = approval_system._analyze_code_risk(None, {})
        except Exception:
            # Rejecting None by raising is also acceptable behavior.
            return

        # If the call succeeded it should still return a valid RiskAnalysis object
        assert isinstance(risk_analysis, RiskAnalysis)

    def test_preferences_persistence(self, approval_system):
        """Test preferences persistence simulation.

        NOTE(review): both halves call the patched mock directly, so they only
        verify that the simulated error propagates from the mock; they do not
        exercise the real load/save code paths.
        """
        # Simulate loading preferences with error
        with patch.object(approval_system, "_load_preferences") as mock_load:
            mock_load.side_effect = Exception("Load error")

            with pytest.raises(Exception, match="Load error"):
                approval_system._load_preferences()

        # Simulate saving preferences with error
        with patch.object(approval_system, "_save_preferences") as mock_save:
            mock_save.side_effect = Exception("Save error")

            with pytest.raises(Exception, match="Save error"):
                approval_system._save_preferences()

    @pytest.mark.parametrize(
        "code_pattern,expected_risk",
        [
            ('print("hello")', RiskLevel.LOW),
            ("import os", RiskLevel.MEDIUM),
            ('os.system("ls")', RiskLevel.HIGH),
            ("rm -rf /", RiskLevel.BLOCKED),
            ('eval("x + 1")', RiskLevel.HIGH),
            ('exec("print(1)")', RiskLevel.HIGH),
            ('__import__("os")', RiskLevel.HIGH),
        ],
    )
    def test_risk_patterns(self, approval_system, code_pattern, expected_risk):
        """Test various code patterns for risk classification."""
        risk_analysis = approval_system._analyze_code_risk(code_pattern, {})

        # Allow some flexibility in risk assessment: code expected to be HIGH
        # may also be classified as BLOCKED.
        if expected_risk == RiskLevel.HIGH:
            assert risk_analysis.risk_level in (RiskLevel.HIGH, RiskLevel.BLOCKED)
        else:
            assert risk_analysis.risk_level == expected_risk

    def test_approval_decision_dataclass(self):
        """Test that ApprovalDecision stores the values it was built with."""
        now = datetime.now()
        request = ApprovalRequest(
            code='print("test")',
            risk_analysis=RiskAnalysis(
                risk_level=RiskLevel.LOW,
                confidence=0.8,
                reasons=[],
                affected_resources=[],
                severity_score=0.1,
            ),
            context={},
            timestamp=now,
            request_id="test123",
        )

        decision = ApprovalDecision(
            request=request,
            result=ApprovalResult.APPROVED,
            user_input="y",
            timestamp=now,
            trust_updated=False,
        )

        assert decision.request == request
        assert decision.result == ApprovalResult.APPROVED
        assert decision.user_input == "y"
        assert decision.timestamp == now
        assert decision.trust_updated is False
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session's exit code; propagate it so that
    # running this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|