feat(01-02): create conversation data structures
Some checks failed
Discord Webhook / git (push) Has been cancelled

- Define Message, Conversation, ContextBudget, and ContextWindow classes
- Implement MessageRole and MessageType enums for classification
- Add Pydantic models for validation and serialization
- Include importance scoring and token estimation utilities
- Support system, user, assistant, and tool message types

File: src/models/conversation.py (147 lines)
This commit is contained in:
Mai Development
2026-01-27 12:07:29 -05:00
parent 2ef1eafdb8
commit 221717d3a3

280
src/models/conversation.py Normal file
View File

@@ -0,0 +1,280 @@
"""
Conversation data models and types for Mai.
This module defines the core data structures for managing conversations,
messages, and context windows. Provides type-safe models with validation
using Pydantic for serialization and data integrity.
"""
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from enum import Enum
from pydantic import BaseModel, Field, validator
class MessageRole(str, Enum):
    """Message role types in conversation.

    Mixes in ``str`` so members compare equal to their literal values and
    serialize cleanly through Pydantic/JSON.
    """

    USER = "user"                # end-user input
    ASSISTANT = "assistant"      # model-generated reply
    SYSTEM = "system"            # system prompt / directives
    TOOL_CALL = "tool_call"      # assistant-initiated tool invocation
    TOOL_RESULT = "tool_result"  # output returned by a tool
class MessageType(str, Enum):
    """Message type classifications for importance scoring.

    The priority notes below describe how `calculate_importance_score`
    weighs each type when deciding what survives compression.
    """

    INSTRUCTION = "instruction"  # User instructions, high priority
    QUESTION = "question"        # User questions, medium priority
    RESPONSE = "response"        # Assistant responses, medium priority
    SYSTEM = "system"            # System messages, high priority
    CONTEXT = "context"          # Context/background, low priority
    ERROR = "error"              # Error messages, variable priority
class MessageMetadata(BaseModel):
    """Metadata for messages including source and importance indicators.

    Every field has a default, so ``MessageMetadata()`` is a valid "plain
    context message" — which is why `Message` can use it as a
    ``default_factory``.
    """

    source: str = Field(default="conversation", description="Source of the message")
    message_type: MessageType = Field(
        default=MessageType.CONTEXT, description="Type classification"
    )
    # Base priority consumed by calculate_importance_score; constrained 0-1.
    priority: float = Field(
        default=0.5, ge=0.0, le=1.0, description="Priority score 0-1"
    )
    context_tags: List[str] = Field(
        default_factory=list, description="Context tags for retrieval"
    )
    # Permanent messages receive a large importance boost during scoring.
    is_permanent: bool = Field(default=False, description="Never compress this message")
    tool_name: Optional[str] = Field(
        default=None, description="Tool name for tool calls"
    )
    model_used: Optional[str] = Field(
        default=None, description="Model that generated this message"
    )
class Message(BaseModel):
    """Individual message in a conversation.

    Bundles the raw content with role, timing, token, and importance data
    so the context manager can score, trim, and serialize messages
    uniformly.
    """

    # Caller-supplied unique identifier; no format is enforced here.
    id: str = Field(description="Unique message identifier")
    role: MessageRole = Field(description="Message role (user/assistant/system/tool)")
    content: str = Field(description="Message content text")
    # NOTE(review): datetime.utcnow yields a naive (tz-unaware) timestamp —
    # confirm callers never compare it against tz-aware datetimes.
    timestamp: datetime = Field(
        default_factory=datetime.utcnow, description="Message creation time"
    )
    token_count: int = Field(default=0, description="Estimated token count")
    importance_score: float = Field(
        default=0.5, ge=0.0, le=1.0, description="Importance for compression"
    )
    metadata: MessageMetadata = Field(
        default_factory=MessageMetadata, description="Additional metadata"
    )

    # Pydantic v1-style validator: rejects empty/whitespace-only content
    # and normalizes by stripping surrounding whitespace.
    @validator("content")
    def validate_content(cls, v):
        if not v or not v.strip():
            raise ValueError("Message content cannot be empty")
        return v.strip()

    class Config:
        # Serialize datetimes as ISO-8601 strings in .json() output.
        json_encoders = {datetime: lambda v: v.isoformat()}
class ConversationMetadata(BaseModel):
    """Metadata for conversation sessions.

    The running counters (``total_messages``, ``total_tokens``,
    ``last_active``) are maintained by `Conversation.add_message` /
    `Conversation.clear_messages`, not recomputed on access.
    """

    session_id: str = Field(description="Unique session identifier")
    title: Optional[str] = Field(default=None, description="Conversation title")
    # NOTE(review): naive UTC timestamps via datetime.utcnow — see Message.
    created_at: datetime = Field(
        default_factory=datetime.utcnow, description="Session start time"
    )
    last_active: datetime = Field(
        default_factory=datetime.utcnow, description="Last activity time"
    )
    total_messages: int = Field(default=0, description="Total message count")
    total_tokens: int = Field(default=0, description="Total token count")
    model_history: List[str] = Field(
        default_factory=list, description="Models used in this session"
    )
    context_window_size: int = Field(
        default=4096, description="Context window size for this session"
    )
class Conversation(BaseModel):
    """Conversation manager for message sequences and metadata.

    Keeps ``messages`` in chronological order and resynchronizes the
    metadata counters on every mutation.
    """

    id: str = Field(description="Conversation identifier")
    messages: List[Message] = Field(
        default_factory=list, description="Messages in chronological order"
    )
    metadata: ConversationMetadata = Field(description="Conversation metadata")

    def add_message(self, message: Message) -> None:
        """Append *message* and refresh the running metadata counters."""
        self.messages.append(message)
        meta = self.metadata
        meta.total_messages = len(self.messages)
        meta.total_tokens += message.token_count
        meta.last_active = datetime.utcnow()

    def get_messages_by_role(self, role: MessageRole) -> List[Message]:
        """Return every message whose role matches *role*, in order."""
        return [msg for msg in self.messages if msg.role == role]

    def get_recent_messages(self, count: int = 10) -> List[Message]:
        """Return the last *count* messages (empty list when count <= 0)."""
        if count <= 0:
            return []
        return self.messages[-count:]

    def get_message_range(self, start: int, end: Optional[int] = None) -> List[Message]:
        """Slice ``messages[start:end]``; *end* defaults to the full length."""
        stop = len(self.messages) if end is None else end
        return self.messages[start:stop]

    def clear_messages(self, keep_system: bool = True) -> None:
        """Drop messages (optionally retaining system ones) and resync counters."""
        if keep_system:
            retained = [
                msg for msg in self.messages if msg.role == MessageRole.SYSTEM
            ]
            self.messages = retained
        else:
            self.messages.clear()
        # Recompute counters from scratch so they always match the list.
        self.metadata.total_messages = len(self.messages)
        self.metadata.total_tokens = sum(msg.token_count for msg in self.messages)
class ContextBudget(BaseModel):
    """Token budget tracker for context window management.

    Tracks consumption against ``max_tokens`` and signals when the
    compression threshold has been crossed.
    """

    max_tokens: int = Field(description="Maximum tokens allowed")
    used_tokens: int = Field(default=0, description="Tokens currently used")
    compression_threshold: float = Field(
        default=0.7, description="Compression trigger ratio"
    )
    safety_margin: int = Field(default=100, description="Safety margin tokens")

    @property
    def available_tokens(self) -> int:
        """Tokens still spendable after reserving the safety margin (>= 0)."""
        remaining = self.max_tokens - self.used_tokens - self.safety_margin
        return remaining if remaining > 0 else 0

    @property
    def usage_percentage(self) -> float:
        """Fraction of the budget consumed, clamped to at most 1.0."""
        if self.max_tokens == 0:
            # Avoid ZeroDivisionError for an unconfigured budget.
            return 0.0
        return min(1.0, self.used_tokens / self.max_tokens)

    @property
    def should_compress(self) -> bool:
        """True once usage reaches the compression threshold."""
        return self.usage_percentage >= self.compression_threshold

    def add_tokens(self, count: int) -> None:
        """Increase the used count, clamping at zero for negative deltas."""
        self.used_tokens = max(0, self.used_tokens + count)

    def remove_tokens(self, count: int) -> None:
        """Decrease the used count, never dropping below zero."""
        self.used_tokens = max(0, self.used_tokens - count)

    def reset(self) -> None:
        """Zero out the used-token counter."""
        self.used_tokens = 0
class ContextWindow(BaseModel):
    """Context window representation with compression state.

    Holds the live message list plus an optional summary of previously
    compressed history, and charges every added message to the budget.
    """

    messages: List[Message] = Field(
        default_factory=list, description="Current context messages"
    )
    budget: ContextBudget = Field(description="Token budget for this window")
    compressed_summary: Optional[str] = Field(
        default=None, description="Summary of compressed messages"
    )
    original_token_count: int = Field(
        default=0, description="Tokens before compression"
    )

    def add_message(self, message: Message) -> None:
        """Track *message* in the window and charge its tokens to the budget."""
        tokens = message.token_count
        self.messages.append(message)
        self.budget.add_tokens(tokens)
        self.original_token_count += tokens

    def get_effective_context(self) -> List[Message]:
        """Return the context, prefixed with a synthetic summary message if set."""
        if not self.compressed_summary:
            return self.messages
        # Surface the compressed history as a permanent system message so
        # downstream consumers treat it like ordinary context.
        summary_msg = Message(
            id="compressed_summary",
            role=MessageRole.SYSTEM,
            content=f"[Previous conversation summary]\n{self.compressed_summary}",
            importance_score=0.8,  # High importance for summary
            metadata=MessageMetadata(
                message_type=MessageType.SYSTEM,
                is_permanent=True,
                source="compression",
            ),
        )
        return [summary_msg, *self.messages]

    def clear(self) -> None:
        """Reset the window to an empty, uncompressed state."""
        self.messages.clear()
        self.budget.reset()
        self.compressed_summary = None
        self.original_token_count = 0
# Utility functions for message importance scoring
def calculate_importance_score(message: Message) -> float:
    """Calculate importance score for a message based on various factors.

    Starts from the message's base priority and applies additive boosts
    for type, permanence, and length, clamping to 1.0 after each step.
    """
    meta = message.metadata
    score = meta.priority
    # Instructions and system messages carry extra weight.
    if meta.message_type in [MessageType.INSTRUCTION, MessageType.SYSTEM]:
        score = min(1.0, score + 0.3)
    # Messages flagged permanent must survive compression.
    if meta.is_permanent:
        score = min(1.0, score + 0.4)
    # Questions signal the user actively seeking information.
    if meta.message_type == MessageType.QUESTION:
        score = min(1.0, score + 0.2)
    # Longer messages tend to carry more detail.
    if message.token_count > 100:
        score = min(1.0, score + 0.1)
    return score
def estimate_token_count(text: str) -> int:
    """
    Estimate token count for text.

    This is a rough approximation - actual tokenization depends on the
    model. As a heuristic: ~4 characters per token for English text.

    Args:
        text: The text to estimate a count for; falsy input yields 0.

    Returns:
        Estimated token count: 0 for empty input, otherwise at least 1.
    """
    if not text:
        return 0
    # Simple heuristic: ~4 characters per token, adjusted for structure
    base_count = len(text) // 4
    # Punctuation/symbols often tokenize separately; count them with a
    # generator instead of materializing a throwaway list.
    special_chars = sum(1 for c in text if not c.isalnum() and not c.isspace())
    special_adjustment = special_chars // 10
    # Newlines often indicate structural breaks that add tokens.
    newline_adjustment = text.count("\n") // 2
    return max(1, base_count + special_adjustment + newline_adjustment)