""" Advanced training pipeline for Lyra with sliding context window and adaptive learning. Implements sophisticated training strategies including: - Sliding context window for long conversations - Dynamic curriculum based on Lyra's emotional and personality state - Memory consolidation and replay - Human-like learning patterns """ import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts import numpy as np import logging from pathlib import Path from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass from datetime import datetime import json import asyncio from collections import deque import random from ..config import config from ..core.lyra_model import LyraModel from ..database.manager import DatabaseManager from ..emotions.system import EmotionalState logger = logging.getLogger(__name__) @dataclass class TrainingBatch: """Represents a training batch with context.""" input_ids: torch.Tensor attention_mask: torch.Tensor target_ids: torch.Tensor emotional_context: torch.Tensor personality_context: torch.Tensor conversation_id: str turn_index: int metadata: Dict[str, Any] @dataclass class LearningMemory: """Represents a significant learning memory.""" conversation_embedding: torch.Tensor emotional_state: EmotionalState user_feedback: float learning_outcome: str timestamp: datetime replay_count: int = 0 class ConversationDataset(Dataset): """Dataset for conversation training with sliding windows.""" def __init__( self, conversations: List[Dict[str, Any]], tokenizer, max_length: int = 512, sliding_window: int = 256, overlap: int = 64 ): self.conversations = conversations self.tokenizer = tokenizer self.max_length = max_length self.sliding_window = sliding_window self.overlap = overlap self.samples = self._prepare_samples() def _prepare_samples(self) -> List[Dict[str, Any]]: """Prepare training samples with sliding windows.""" samples = [] for conv in self.conversations: # Extract conversation turns turns = conv.get('turns', []) full_text = "" # Build conversation context for i, turn in enumerate(turns): if turn['role'] == 'user': full_text += f"User: {turn['content']}\n" elif turn['role'] == 'assistant': full_text += f"Lyra: {turn['content']}\n" # Create sliding windows tokens = self.tokenizer.encode(full_text) for start_idx in range(0, len(tokens) - self.sliding_window, self.sliding_window - self.overlap): end_idx = min(start_idx + self.sliding_window, len(tokens)) window_tokens = tokens[start_idx:end_idx] if len(window_tokens) < 32: # Skip very short windows continue # Target is the next token sequence input_tokens = window_tokens[:-1] target_tokens = window_tokens[1:] samples.append({ 'input_ids': input_tokens, 'target_ids': target_tokens, 'conversation_id': conv.get('id', ''), 'emotional_context': conv.get('emotional_state', {}), 'personality_context': conv.get('personality_state', {}), 'metadata': conv.get('metadata', {}) }) return samples def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int) -> Dict[str, Any]: return self.samples[idx] class AdaptiveLearningScheduler: """Adaptive learning rate based on emotional and personality state.""" def __init__(self, base_lr: float = 1e-4): self.base_lr = base_lr self.emotional_multipliers = { 'joy': 1.2, # Learn faster when happy 'curiosity': 1.5, # Learn much faster when curious 'frustration': 0.7, # Learn slower when frustrated 'confusion': 0.5, # Learn slower when 
class AdaptiveLearningScheduler:
    """Adaptive learning rate based on emotional and personality state."""

    def __init__(self, base_lr: float = 1e-4):
        self.base_lr = base_lr
        self.emotional_multipliers = {
            'joy': 1.2,           # Learn faster when happy
            'curiosity': 1.5,     # Learn much faster when curious
            'frustration': 0.7,   # Learn slower when frustrated
            'confusion': 0.5,     # Learn slower when confused
            'confidence': 1.1     # Learn slightly faster when confident
        }

    def get_learning_rate(
        self,
        emotional_state: EmotionalState,
        personality_openness: float,
        recent_performance: float
    ) -> float:
        """Calculate adaptive learning rate."""
        # Base rate adjustment
        lr = self.base_lr

        # Emotional adjustment
        dominant_emotion, intensity = emotional_state.get_dominant_emotion()
        if dominant_emotion in self.emotional_multipliers:
            lr *= self.emotional_multipliers[dominant_emotion] * intensity

        # Personality adjustment (openness to experience)
        lr *= (1.0 + personality_openness * 0.3)

        # Performance adjustment
        if recent_performance > 0.8:
            lr *= 1.1  # Increase when performing well
        elif recent_performance < 0.4:
            lr *= 0.8  # Decrease when struggling

        return max(lr, self.base_lr * 0.1)  # Don't go too low
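
# Worked example (illustration only, not used by the pipeline): with the
# default base_lr of 1e-4, a dominant emotion of 'curiosity' at intensity 0.8,
# openness 0.7, and recent performance 0.85, the schedule above yields
#   1e-4 * (1.5 * 0.8) * (1.0 + 0.7 * 0.3) * 1.1 ≈ 1.6e-4,
# i.e. roughly a 60% higher learning rate than the base. The floor of
# base_lr * 0.1 only kicks in when several dampening factors stack up.
def _example_adaptive_lr() -> float:
    """Reproduce the arithmetic above with plain floats (sketch)."""
    base_lr = 1e-4
    lr = base_lr
    lr *= 1.5 * 0.8           # curiosity multiplier * intensity
    lr *= 1.0 + 0.7 * 0.3     # openness adjustment
    lr *= 1.1                 # recent performance > 0.8
    return max(lr, base_lr * 0.1)
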
class LyraTrainingPipeline:
    """Complete training pipeline for Lyra with human-like learning patterns."""

    def __init__(
        self,
        model: LyraModel,
        tokenizer,
        device: torch.device,
        database_manager: Optional[DatabaseManager] = None
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.database_manager = database_manager

        # Training components
        self.optimizer = AdamW(model.parameters(), lr=config.learning_rate)
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer, T_0=1000, eta_min=1e-6
        )
        self.adaptive_scheduler = AdaptiveLearningScheduler()

        # Memory systems
        self.learning_memories = deque(maxlen=1000)
        self.replay_buffer = deque(maxlen=5000)

        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_performance = 0.0
        self.training_history = []

        # Human-like learning patterns
        self.forgetting_curve = self._initialize_forgetting_curve()
        self.consolidation_schedule = self._create_consolidation_schedule()

    def _initialize_forgetting_curve(self) -> Dict[str, float]:
        """Initialize forgetting curve parameters."""
        return {
            'initial_strength': 1.0,
            'decay_rate': 0.05,
            'consolidation_boost': 1.3,
            'interference_factor': 0.1
        }

    def _create_consolidation_schedule(self) -> List[int]:
        """Create memory consolidation schedule (like sleep cycles)."""
        # Consolidate at increasing intervals: 1h, 6h, 24h, 72h, 168h
        return [100, 600, 2400, 7200, 16800]  # In training steps

    async def train_epoch(
        self,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None
    ) -> Dict[str, float]:
        """Train for one epoch with adaptive learning."""
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0
        emotional_adjustments = 0

        for batch_idx, batch in enumerate(train_dataloader):
            # Move batch to device
            batch = self._prepare_batch(batch)

            # Get current emotional and personality state
            emotional_state = self._get_current_emotional_state()
            personality_state = self._get_current_personality_state()

            # Adaptive learning rate
            current_performance = self._calculate_recent_performance()
            adaptive_lr = self.adaptive_scheduler.get_learning_rate(
                emotional_state,
                personality_state.get('openness', 0.5),
                current_performance
            )

            # Adjust optimizer learning rate if significantly different
            current_lr = self.optimizer.param_groups[0]['lr']
            if abs(adaptive_lr - current_lr) > current_lr * 0.1:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = adaptive_lr
                emotional_adjustments += 1

            # Forward pass
            self.optimizer.zero_grad()

            outputs, lyra_info = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                user_id=batch.get('user_id'),
                conversation_context=batch.get('context')
            )

            # Calculate loss
            loss = self._calculate_adaptive_loss(
                outputs, batch['target_ids'], emotional_state
            )

            # Backward pass
            loss.backward()

            # Gradient clipping (human-like learning stability)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            # Optimizer step
            self.optimizer.step()
            self.scheduler.step()

            # Update training state
            epoch_loss += loss.item()
            num_batches += 1
            self.global_step += 1

            # Memory consolidation
            if self.global_step in self.consolidation_schedule:
                await self._consolidate_memories()

            # Experience replay (20% chance)
            if random.random() < 0.2 and len(self.replay_buffer) > 10:
                await self._experience_replay()

            # Log progress
            if batch_idx % 100 == 0:
                logger.info(
                    f"Epoch {self.epoch}, Batch {batch_idx}, "
                    f"Loss: {loss.item():.4f}, "
                    f"LR: {adaptive_lr:.2e}, "
                    f"Emotional adjustments: {emotional_adjustments}"
                )

        # Validation
        val_metrics = {}
        if val_dataloader:
            val_metrics = await self._validate(val_dataloader)

        # Record training history
        epoch_metrics = {
            'epoch': self.epoch,
            'train_loss': epoch_loss / num_batches,
            'learning_rate': self.optimizer.param_groups[0]['lr'],
            'emotional_adjustments': emotional_adjustments,
            'global_step': self.global_step,
            **val_metrics
        }

        self.training_history.append(epoch_metrics)
        self.epoch += 1

        return epoch_metrics

    def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Prepare batch for training."""
        prepared = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                prepared[key] = value.to(self.device)
            elif isinstance(value, list):
                # Convert list to tensor if numeric
                try:
                    prepared[key] = torch.tensor(value).to(self.device)
                except (TypeError, ValueError):
                    prepared[key] = value
            else:
                prepared[key] = value
        return prepared

    def _get_current_emotional_state(self) -> EmotionalState:
        """Get Lyra's current emotional state."""
        # This would normally come from the emotional system
        # For now, create a default state
        emotions = torch.rand(19)  # 19 emotion dimensions
        return EmotionalState.from_tensor(emotions, self.device)

    def _get_current_personality_state(self) -> Dict[str, float]:
        """Get current personality traits."""
        return {
            'openness': 0.7,
            'conscientiousness': 0.8,
            'extraversion': 0.6,
            'agreeableness': 0.9,
            'neuroticism': 0.3
        }

    def _calculate_recent_performance(self) -> float:
        """Calculate recent performance score."""
        if not self.training_history:
            return 0.5

        recent_epochs = self.training_history[-5:]  # Last 5 epochs

        # Simple performance metric based on loss improvement
        losses = [epoch['train_loss'] for epoch in recent_epochs]
        if len(losses) < 2:
            return 0.5

        improvement = (losses[0] - losses[-1]) / losses[0]
        return min(max(0.5 + improvement, 0.0), 1.0)

    def _calculate_adaptive_loss(
        self,
        outputs: torch.Tensor,
        targets: torch.Tensor,
        emotional_state: EmotionalState
    ) -> torch.Tensor:
        """Calculate loss adjusted for emotional state."""
        # Base cross-entropy loss
        base_loss = nn.CrossEntropyLoss()(
            outputs.view(-1, outputs.size(-1)),
            targets.view(-1)
        )

        # Emotional adjustment
        dominant_emotion, intensity = emotional_state.get_dominant_emotion()

        if dominant_emotion == 'frustration' and intensity > 0.7:
            # Reduce learning when frustrated (like humans)
            base_loss *= 0.8
        elif dominant_emotion == 'curiosity' and intensity > 0.6:
            # Increase learning when curious
            base_loss *= 1.2

        return base_loss
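
    # --- Memory consolidation and experience replay -------------------------
    # The two methods below mimic spaced, sleep-like consolidation: at the
    # steps listed in `consolidation_schedule` (roughly 1h, 6h, 24h, 72h and
    # 168h of training if ~100 steps correspond to an hour, per the comment in
    # `_create_consolidation_schedule`), the highest-feedback memories are
    # copied into `replay_buffer`; separately, each training step has a 20%
    # chance of triggering a lightweight replay of buffered experiences.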
    async def _consolidate_memories(self):
        """Consolidate important memories (like sleep-based learning)."""
        if not self.learning_memories:
            return

        logger.info(f"Consolidating {len(self.learning_memories)} memories...")

        # Sort memories by importance (feedback score, discounted by replay count)
        important_memories = sorted(
            self.learning_memories,
            key=lambda m: m.user_feedback * (1.0 - m.replay_count * 0.1),
            reverse=True
        )[:50]  # Top 50 memories

        # Replay important memories
        for memory in important_memories[:10]:
            # Convert memory to training sample
            self.replay_buffer.append({
                'conversation_embedding': memory.conversation_embedding,
                'emotional_state': memory.emotional_state,
                'feedback': memory.user_feedback,
                'outcome': memory.learning_outcome
            })
            memory.replay_count += 1

        logger.info("Memory consolidation complete")

    async def _experience_replay(self):
        """Replay past experiences for better learning."""
        if len(self.replay_buffer) < 5:
            return

        # Sample random memories
        replay_samples = random.sample(
            list(self.replay_buffer),
            min(5, len(self.replay_buffer))
        )

        # Process replay samples (simplified)
        for sample in replay_samples:
            # This would normally involve re-training on the sample
            # For now, just log the replay
            logger.debug(f"Replaying memory with feedback: {sample['feedback']}")

    async def _validate(self, val_dataloader: DataLoader) -> Dict[str, float]:
        """Validate model performance."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in val_dataloader:
                batch = self._prepare_batch(batch)

                outputs, _ = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask']
                )

                loss = nn.CrossEntropyLoss()(
                    outputs.view(-1, outputs.size(-1)),
                    batch['target_ids'].view(-1)
                )

                total_loss += loss.item()
                num_batches += 1

        self.model.train()

        avg_val_loss = total_loss / num_batches if num_batches > 0 else 0.0
        return {
            'val_loss': avg_val_loss,
            'perplexity': torch.exp(torch.tensor(avg_val_loss)).item()
        }

    async def save_checkpoint(self, filepath: Path, metadata: Optional[Dict] = None):
        """Save training checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'epoch': self.epoch,
            'training_history': self.training_history,
            'best_performance': self.best_performance,
            'learning_memories': list(self.learning_memories),
            'forgetting_curve': self.forgetting_curve,
            'metadata': metadata or {}
        }

        filepath.parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, filepath)

        logger.info(f"Checkpoint saved to {filepath}")

    async def load_checkpoint(self, filepath: Path):
        """Load training checkpoint."""
        # weights_only=False because the checkpoint stores Python objects
        # (training history, learning memories) alongside the state dicts
        checkpoint = torch.load(filepath, map_location=self.device, weights_only=False)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        self.global_step = checkpoint.get('global_step', 0)
        self.epoch = checkpoint.get('epoch', 0)
        self.training_history = checkpoint.get('training_history', [])
        self.best_performance = checkpoint.get('best_performance', 0.0)

        self.learning_memories = deque(
            checkpoint.get('learning_memories', []), maxlen=1000
        )
        self.forgetting_curve = checkpoint.get('forgetting_curve', self.forgetting_curve)

        logger.info(f"Checkpoint loaded from {filepath}")

    def add_learning_memory(
        self,
        conversation_embedding: torch.Tensor,
        emotional_state: EmotionalState,
        user_feedback: float,
        learning_outcome: str
    ):
        """Add a significant learning memory."""
        memory = LearningMemory(
            conversation_embedding=conversation_embedding,
            emotional_state=emotional_state,
            user_feedback=user_feedback,
            learning_outcome=learning_outcome,
            timestamp=datetime.now()
        )
        self.learning_memories.append(memory)

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get comprehensive training statistics."""
        if not self.training_history:
            return {'status': 'no_training_data'}

        recent_performance = self._calculate_recent_performance()

        return {
            'global_step': self.global_step,
            'current_epoch': self.epoch,
            'total_epochs_trained': len(self.training_history),
            'recent_performance': recent_performance,
            'best_performance': self.best_performance,
            'learning_memories_count': len(self.learning_memories),
            'replay_buffer_size': len(self.replay_buffer),
            'current_learning_rate': self.optimizer.param_groups[0]['lr'],
            'last_consolidation': max(
                [step for step in self.consolidation_schedule if step <= self.global_step],
                default=0
            ),
            'training_history_summary': {
                'best_train_loss': min(h['train_loss'] for h in self.training_history),
                'latest_train_loss': self.training_history[-1]['train_loss'],
                'average_emotional_adjustments': np.mean([
                    h['emotional_adjustments'] for h in self.training_history
                ])
            } if self.training_history else {}
        }
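
# If batches are padded (as in the collate sketch above), the plain
# nn.CrossEntropyLoss() in `_calculate_adaptive_loss` and `_validate` also
# trains on padding positions. A possible refinement, shown only as a sketch
# and not wired into the pipeline, is to ignore the pad id:
def masked_cross_entropy(
    logits: torch.Tensor,
    targets: torch.Tensor,
    pad_id: int = 0
) -> torch.Tensor:
    """Cross-entropy that skips padded target positions (sketch)."""
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
    return loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))
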
async def create_training_pipeline(
    model: LyraModel,
    tokenizer,
    device: torch.device,
    database_manager: Optional[DatabaseManager] = None
) -> LyraTrainingPipeline:
    """Create and initialize training pipeline."""
    pipeline = LyraTrainingPipeline(model, tokenizer, device, database_manager)

    # Load existing checkpoint if available
    checkpoint_path = Path(config.models_dir) / "checkpoints" / "latest_training.pt"
    if checkpoint_path.exists():
        try:
            await pipeline.load_checkpoint(checkpoint_path)
            logger.info("Loaded existing training checkpoint")
        except Exception as e:
            logger.warning(f"Could not load checkpoint: {e}")

    return pipeline
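
# Example wiring (illustrative sketch only): how the pieces above might be
# combined for a single training run. `model`, `tokenizer`, `device`, and
# `conversations` are assumed to be supplied by the caller; the collate helper
# is the hypothetical `collate_conversation_windows` defined earlier in this
# module. Run with e.g. asyncio.run(run_training_example(...)).
async def run_training_example(model, tokenizer, device, conversations):
    """Sketch of a single-epoch training run using this pipeline."""
    dataset = ConversationDataset(conversations, tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        collate_fn=collate_conversation_windows
    )

    pipeline = await create_training_pipeline(model, tokenizer, device)
    metrics = await pipeline.train_epoch(dataloader)
    logger.info(f"Example epoch metrics: {metrics}")

    # Persist progress where create_training_pipeline looks for it
    checkpoint_path = Path(config.models_dir) / "checkpoints" / "latest_training.pt"
    await pipeline.save_checkpoint(checkpoint_path)
    return metrics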