""" Chat agent for NOVA with persona support """ import torch from typing import Optional, List, Dict from .persona import Persona, PersonaLoader from .memory import ConversationMemory from nova_core import NovaTransformer from nova_tokenizer import NovaTokenizer class ChatAgent: """ Chat agent that combines NOVA model with persona and memory """ def __init__( self, model: NovaTransformer, tokenizer: NovaTokenizer, persona: Optional[Persona] = None, use_memory: bool = True, memory_db_path: Optional[str] = None, ): """ Args: model: NOVA transformer model tokenizer: NOVA tokenizer persona: Persona configuration (defaults to supportive girlfriend) use_memory: Whether to use conversation memory memory_db_path: Path to memory database """ self.model = model self.tokenizer = tokenizer self.persona = persona or PersonaLoader.create_girlfriend_supportive() # Conversation memory self.use_memory = use_memory if use_memory: self.memory = ConversationMemory(db_path=memory_db_path) else: self.memory = None # Current conversation context self.conversation_id = None self.context = [] def start_conversation(self, conversation_id: Optional[str] = None): """Start a new conversation""" if conversation_id and self.memory: # Load existing conversation self.conversation_id = conversation_id self.context = self.memory.load_conversation(conversation_id) else: # Start fresh import uuid self.conversation_id = conversation_id or str(uuid.uuid4()) self.context = [] # Add system prompt if configured system_prompt = self.persona.format_system_prompt() if system_prompt: self.context.append({ 'role': 'system', 'content': system_prompt }) def chat(self, message: str) -> str: """ Send a message and get response Args: message: User message Returns: NOVA's response """ # Add user message to context self.context.append({ 'role': 'user', 'content': message }) # Format prompt from conversation context prompt = self._format_prompt() # Get generation parameters from persona gen_params = self.persona.get_generation_params() # Generate response response = self._generate(prompt, **gen_params) # Add to context self.context.append({ 'role': 'assistant', 'content': response }) # Save to memory if self.memory: self.memory.add_message( conversation_id=self.conversation_id, role='user', content=message ) self.memory.add_message( conversation_id=self.conversation_id, role='assistant', content=response ) return response def _format_prompt(self) -> str: """Format conversation context into prompt string""" parts = [] for msg in self.context: role = msg['role'] content = msg['content'] if role == 'system': parts.append(f"{content}") elif role == 'user': parts.append(f"User: {content}") elif role == 'assistant': parts.append(f"{self.persona.name}: {content}") # Add prefix for assistant response parts.append(f"{self.persona.name}:") return "\n".join(parts) def _generate( self, prompt: str, temperature: float = 0.8, top_p: float = 0.9, top_k: Optional[int] = 50, repetition_penalty: float = 1.1, max_new_tokens: int = 200, ) -> str: """Generate response using model""" # Tokenize prompt input_ids = self.tokenizer.encode(prompt, add_bos=True, add_eos=False) input_ids = torch.tensor([input_ids], dtype=torch.long) # Move to model device device = next(self.model.parameters()).device input_ids = input_ids.to(device) # Generate with torch.no_grad(): output_ids = self.model.generate( input_ids=input_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, eos_token_id=self.tokenizer.eos_id, ) # Decode response (skip the prompt part) response_ids = output_ids[0][input_ids.shape[1]:].tolist() response = self.tokenizer.decode(response_ids, skip_special_tokens=True) # Clean up response response = response.strip() # Remove any accidental continuation of prompt if response.startswith(f"{self.persona.name}:"): response = response[len(f"{self.persona.name}:"):].strip() return response def clear_context(self): """Clear conversation context (but keep system prompt)""" system_messages = [msg for msg in self.context if msg['role'] == 'system'] self.context = system_messages def get_context(self) -> List[Dict[str, str]]: """Get current conversation context""" return self.context.copy() def set_persona(self, persona: Persona): """Change persona mid-conversation""" self.persona = persona