feat(01-02): create conversation data structures
Some checks failed
Discord Webhook / git (push) Has been cancelled
Some checks failed
Discord Webhook / git (push) Has been cancelled
- Define Message, Conversation, ContextBudget, and ContextWindow classes - Implement MessageRole and MessageType enums for classification - Add Pydantic models for validation and serialization - Include importance scoring and token estimation utilities - Support system, user, assistant, and tool message types File: src/models/conversation.py (147 lines)
This commit is contained in:
280
src/models/conversation.py
Normal file
280
src/models/conversation.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
"""
|
||||||
|
Conversation data models and types for Mai.
|
||||||
|
|
||||||
|
This module defines the core data structures for managing conversations,
|
||||||
|
messages, and context windows. Provides type-safe models with validation
|
||||||
|
using Pydantic for serialization and data integrity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRole(str, Enum):
|
||||||
|
"""Message role types in conversation."""
|
||||||
|
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
SYSTEM = "system"
|
||||||
|
TOOL_CALL = "tool_call"
|
||||||
|
TOOL_RESULT = "tool_result"
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(str, Enum):
|
||||||
|
"""Message type classifications for importance scoring."""
|
||||||
|
|
||||||
|
INSTRUCTION = "instruction" # User instructions, high priority
|
||||||
|
QUESTION = "question" # User questions, medium priority
|
||||||
|
RESPONSE = "response" # Assistant responses, medium priority
|
||||||
|
SYSTEM = "system" # System messages, high priority
|
||||||
|
CONTEXT = "context" # Context/background, low priority
|
||||||
|
ERROR = "error" # Error messages, variable priority
|
||||||
|
|
||||||
|
|
||||||
|
class MessageMetadata(BaseModel):
|
||||||
|
"""Metadata for messages including source and importance indicators."""
|
||||||
|
|
||||||
|
source: str = Field(default="conversation", description="Source of the message")
|
||||||
|
message_type: MessageType = Field(
|
||||||
|
default=MessageType.CONTEXT, description="Type classification"
|
||||||
|
)
|
||||||
|
priority: float = Field(
|
||||||
|
default=0.5, ge=0.0, le=1.0, description="Priority score 0-1"
|
||||||
|
)
|
||||||
|
context_tags: List[str] = Field(
|
||||||
|
default_factory=list, description="Context tags for retrieval"
|
||||||
|
)
|
||||||
|
is_permanent: bool = Field(default=False, description="Never compress this message")
|
||||||
|
tool_name: Optional[str] = Field(
|
||||||
|
default=None, description="Tool name for tool calls"
|
||||||
|
)
|
||||||
|
model_used: Optional[str] = Field(
|
||||||
|
default=None, description="Model that generated this message"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
"""Individual message in a conversation."""
|
||||||
|
|
||||||
|
id: str = Field(description="Unique message identifier")
|
||||||
|
role: MessageRole = Field(description="Message role (user/assistant/system/tool)")
|
||||||
|
content: str = Field(description="Message content text")
|
||||||
|
timestamp: datetime = Field(
|
||||||
|
default_factory=datetime.utcnow, description="Message creation time"
|
||||||
|
)
|
||||||
|
token_count: int = Field(default=0, description="Estimated token count")
|
||||||
|
importance_score: float = Field(
|
||||||
|
default=0.5, ge=0.0, le=1.0, description="Importance for compression"
|
||||||
|
)
|
||||||
|
metadata: MessageMetadata = Field(
|
||||||
|
default_factory=MessageMetadata, description="Additional metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator("content")
|
||||||
|
def validate_content(cls, v):
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("Message content cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationMetadata(BaseModel):
|
||||||
|
"""Metadata for conversation sessions."""
|
||||||
|
|
||||||
|
session_id: str = Field(description="Unique session identifier")
|
||||||
|
title: Optional[str] = Field(default=None, description="Conversation title")
|
||||||
|
created_at: datetime = Field(
|
||||||
|
default_factory=datetime.utcnow, description="Session start time"
|
||||||
|
)
|
||||||
|
last_active: datetime = Field(
|
||||||
|
default_factory=datetime.utcnow, description="Last activity time"
|
||||||
|
)
|
||||||
|
total_messages: int = Field(default=0, description="Total message count")
|
||||||
|
total_tokens: int = Field(default=0, description="Total token count")
|
||||||
|
model_history: List[str] = Field(
|
||||||
|
default_factory=list, description="Models used in this session"
|
||||||
|
)
|
||||||
|
context_window_size: int = Field(
|
||||||
|
default=4096, description="Context window size for this session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(BaseModel):
|
||||||
|
"""Conversation manager for message sequences and metadata."""
|
||||||
|
|
||||||
|
id: str = Field(description="Conversation identifier")
|
||||||
|
messages: List[Message] = Field(
|
||||||
|
default_factory=list, description="Messages in chronological order"
|
||||||
|
)
|
||||||
|
metadata: ConversationMetadata = Field(description="Conversation metadata")
|
||||||
|
|
||||||
|
def add_message(self, message: Message) -> None:
|
||||||
|
"""Add a message to the conversation."""
|
||||||
|
self.messages.append(message)
|
||||||
|
self.metadata.total_messages = len(self.messages)
|
||||||
|
self.metadata.total_tokens += message.token_count
|
||||||
|
self.metadata.last_active = datetime.utcnow()
|
||||||
|
|
||||||
|
def get_messages_by_role(self, role: MessageRole) -> List[Message]:
|
||||||
|
"""Get all messages from a specific role."""
|
||||||
|
return [msg for msg in self.messages if msg.role == role]
|
||||||
|
|
||||||
|
def get_recent_messages(self, count: int = 10) -> List[Message]:
|
||||||
|
"""Get the most recent N messages."""
|
||||||
|
return self.messages[-count:] if count > 0 else []
|
||||||
|
|
||||||
|
def get_message_range(self, start: int, end: Optional[int] = None) -> List[Message]:
|
||||||
|
"""Get messages in a range (start inclusive, end exclusive)."""
|
||||||
|
if end is None:
|
||||||
|
end = len(self.messages)
|
||||||
|
return self.messages[start:end]
|
||||||
|
|
||||||
|
def clear_messages(self, keep_system: bool = True) -> None:
|
||||||
|
"""Clear all messages, optionally keeping system messages."""
|
||||||
|
if keep_system:
|
||||||
|
self.messages = [
|
||||||
|
msg for msg in self.messages if msg.role == MessageRole.SYSTEM
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
self.messages.clear()
|
||||||
|
self.metadata.total_messages = len(self.messages)
|
||||||
|
self.metadata.total_tokens = sum(msg.token_count for msg in self.messages)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextBudget(BaseModel):
|
||||||
|
"""Token budget tracker for context window management."""
|
||||||
|
|
||||||
|
max_tokens: int = Field(description="Maximum tokens allowed")
|
||||||
|
used_tokens: int = Field(default=0, description="Tokens currently used")
|
||||||
|
compression_threshold: float = Field(
|
||||||
|
default=0.7, description="Compression trigger ratio"
|
||||||
|
)
|
||||||
|
safety_margin: int = Field(default=100, description="Safety margin tokens")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_tokens(self) -> int:
|
||||||
|
"""Calculate available tokens including safety margin."""
|
||||||
|
return max(0, self.max_tokens - self.used_tokens - self.safety_margin)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def usage_percentage(self) -> float:
|
||||||
|
"""Calculate current usage as percentage."""
|
||||||
|
if self.max_tokens == 0:
|
||||||
|
return 0.0
|
||||||
|
return min(1.0, self.used_tokens / self.max_tokens)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def should_compress(self) -> bool:
|
||||||
|
"""Check if compression should be triggered."""
|
||||||
|
return self.usage_percentage >= self.compression_threshold
|
||||||
|
|
||||||
|
def add_tokens(self, count: int) -> None:
|
||||||
|
"""Add tokens to the used count."""
|
||||||
|
self.used_tokens += count
|
||||||
|
self.used_tokens = max(0, self.used_tokens) # Prevent negative
|
||||||
|
|
||||||
|
def remove_tokens(self, count: int) -> None:
|
||||||
|
"""Remove tokens from the used count."""
|
||||||
|
self.used_tokens -= count
|
||||||
|
self.used_tokens = max(0, self.used_tokens)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the token budget."""
|
||||||
|
self.used_tokens = 0
|
||||||
|
|
||||||
|
|
||||||
|
class ContextWindow(BaseModel):
|
||||||
|
"""Context window representation with compression state."""
|
||||||
|
|
||||||
|
messages: List[Message] = Field(
|
||||||
|
default_factory=list, description="Current context messages"
|
||||||
|
)
|
||||||
|
budget: ContextBudget = Field(description="Token budget for this window")
|
||||||
|
compressed_summary: Optional[str] = Field(
|
||||||
|
default=None, description="Summary of compressed messages"
|
||||||
|
)
|
||||||
|
original_token_count: int = Field(
|
||||||
|
default=0, description="Tokens before compression"
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_message(self, message: Message) -> None:
|
||||||
|
"""Add a message to the context window."""
|
||||||
|
self.messages.append(message)
|
||||||
|
self.budget.add_tokens(message.token_count)
|
||||||
|
self.original_token_count += message.token_count
|
||||||
|
|
||||||
|
def get_effective_context(self) -> List[Message]:
|
||||||
|
"""Get the effective context including compressed summary if needed."""
|
||||||
|
if self.compressed_summary:
|
||||||
|
# Create a synthetic system message with the summary
|
||||||
|
summary_msg = Message(
|
||||||
|
id="compressed_summary",
|
||||||
|
role=MessageRole.SYSTEM,
|
||||||
|
content=f"[Previous conversation summary]\n{self.compressed_summary}",
|
||||||
|
importance_score=0.8, # High importance for summary
|
||||||
|
metadata=MessageMetadata(
|
||||||
|
message_type=MessageType.SYSTEM,
|
||||||
|
is_permanent=True,
|
||||||
|
source="compression",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return [summary_msg] + self.messages
|
||||||
|
return self.messages
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear the context window."""
|
||||||
|
self.messages.clear()
|
||||||
|
self.budget.reset()
|
||||||
|
self.compressed_summary = None
|
||||||
|
self.original_token_count = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Utility functions for message importance scoring
|
||||||
|
def calculate_importance_score(message: Message) -> float:
|
||||||
|
"""Calculate importance score for a message based on various factors."""
|
||||||
|
score = message.metadata.priority
|
||||||
|
|
||||||
|
# Boost for instructions and system messages
|
||||||
|
if message.metadata.message_type in [MessageType.INSTRUCTION, MessageType.SYSTEM]:
|
||||||
|
score = min(1.0, score + 0.3)
|
||||||
|
|
||||||
|
# Boost for permanent messages
|
||||||
|
if message.metadata.is_permanent:
|
||||||
|
score = min(1.0, score + 0.4)
|
||||||
|
|
||||||
|
# Boost for questions (user seeking information)
|
||||||
|
if message.metadata.message_type == MessageType.QUESTION:
|
||||||
|
score = min(1.0, score + 0.2)
|
||||||
|
|
||||||
|
# Adjust based on length (longer messages might be more detailed)
|
||||||
|
if message.token_count > 100:
|
||||||
|
score = min(1.0, score + 0.1)
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_token_count(text: str) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count for text.
|
||||||
|
|
||||||
|
This is a rough approximation - actual tokenization depends on the model.
|
||||||
|
As a heuristic: ~4 characters per token for English text.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Simple heuristic: ~4 characters per token, adjusted for structure
|
||||||
|
base_count = len(text) // 4
|
||||||
|
|
||||||
|
# Add extra for special characters, code blocks, etc.
|
||||||
|
special_chars = len([c for c in text if not c.isalnum() and not c.isspace()])
|
||||||
|
special_adjustment = special_chars // 10
|
||||||
|
|
||||||
|
# Add for newlines (often indicate more tokens)
|
||||||
|
newline_adjustment = text.count("\n") // 2
|
||||||
|
|
||||||
|
return max(1, base_count + special_adjustment + newline_adjustment)
|
||||||
Reference in New Issue
Block a user