diff --git a/commands/moments.py b/commands/moments.py index 3fbe433..216df2f 100644 --- a/commands/moments.py +++ b/commands/moments.py @@ -1,33 +1,15 @@ import discord from discord import app_commands from typing import Optional, List -from datetime import datetime +from datetime import datetime, timedelta +import sqlite3 import logging from utils.database import Database +from discord.app_commands import MissingPermissions db = Database() -import discord -from discord import app_commands -from typing import Optional, List -from datetime import datetime -import logging -import uuid -from utils.database import Database - -db = Database() - -import discord -from discord import app_commands -from typing import Optional, List -from datetime import datetime -import logging -import uuid -from utils.database import Database - -db = Database() - class IncidentModal(discord.ui.Modal): def __init__(self): super().__init__(title="Log Incident") @@ -73,10 +55,10 @@ class IncidentModal(discord.ui.Modal): start_time = datetime.strptime(self.start_time.value, "%Y-%m-%d %H:%M") end_time = datetime.strptime(self.end_time.value, "%Y-%m-%d %H:%M") - + if start_time >= end_time: raise ValueError("End time must be after start time") - + if (end_time - start_time).total_seconds() > 86400: raise ValueError("Maximum timeframe duration is 24 hours") @@ -86,19 +68,19 @@ class IncidentModal(discord.ui.Modal): before=end_time ): messages.append(msg) - + messages = messages[::-1] # Oldest first capture_mode = "timeframe" capture_param = f"{start_time.strftime('%Y-%m-%d %H:%M')} to {end_time.strftime('%Y-%m-%d %H:%M')}" - + else: if not self.message_count.value: raise ValueError("Please provide either message count or timeframe") - + count = int(self.message_count.value) if not 1 <= count <= 50: raise ValueError("Message count must be between 1-50") - + messages = [ msg async for msg in interaction.channel.history(limit=count) ][::-1] @@ -136,7 +118,7 @@ class IncidentModal(discord.ui.Modal): color=0x00ff00 ) embed.add_field(name="Reason", value=self.reason.value[:500], inline=False) - + if messages: preview = f"{messages[0].content[:100]}..." if len(messages[0].content) > 100 else messages[0].content embed.add_field(name="First Message", value=preview, inline=False) @@ -155,6 +137,7 @@ class IncidentModal(discord.ui.Modal): ephemeral=True ) + class FollowupModal(discord.ui.Modal): def __init__(self, incident_id: str): super().__init__(title=f"Follow-up: {incident_id}") @@ -173,7 +156,7 @@ class FollowupModal(discord.ui.Modal): moderator_id=interaction.user.id, notes=self.notes.value ) - + if success: await interaction.response.send_message( f"✅ Follow-up added to **{self.incident_id}**", @@ -191,6 +174,7 @@ class FollowupModal(discord.ui.Modal): ephemeral=True ) + async def setup(client): # Context menu command @client.tree.context_menu(name="Mark as Funny Moment") @@ -259,13 +243,14 @@ async def setup(client): await interaction.response.send_message("❌ Incident not found", ephemeral=True) return - # Format messages + # Format messages with proper timestamps messages = "\n\n".join( - f"**** <@{msg['author_id']}>:\n{msg['content']}" + f"**** <@{msg['author_id']}>:\n" + f"{msg['content']}" for msg in incident['messages'] ) - # Format follow-ups + # Get follow-ups with proper timestamps followups = db.get_followups(incident_id) followup_text = "\n\n".join( f"**** <@{f['moderator_id']}>:\n{f['notes'][:200]}" @@ -283,8 +268,18 @@ async def setup(client): ), color=0xff0000 ) - embed.add_field(name="Messages", value=messages[:1020] + "..." if len(messages) > 1024 else messages, inline=False) - embed.add_field(name=f"Follow-ups ({len(followups)})", value=followup_text[:1020] + "..." if len(followup_text) > 1024 else followup_text, inline=False) + + embed.add_field( + name="Messages", + value=messages[:1020] + "..." if len(messages) > 1024 else messages, + inline=False + ) + + embed.add_field( + name=f"Follow-ups ({len(followups)})", + value=followup_text[:1020] + "..." if len(followup_text) > 1024 else followup_text, + inline=False + ) await interaction.response.send_message(embed=embed, ephemeral=True) @@ -361,4 +356,71 @@ async def setup(client): choices.append(app_commands.Choice(name=name, value=inc['id'])) return choices[:25] - client.tree.add_command(moments_group) \ No newline at end of file + # Global error handler for commands + @client.tree.error + async def on_command_error(interaction: discord.Interaction, error): + """Handle unauthorized access attempts""" + if isinstance(error, MissingPermissions): + # Log unauthorized access + db.log_unauthorized_access( + user_id=interaction.user.id, + command_used=interaction.command.name if interaction.command else "unknown", + details=f"Attempted params: {interaction.data}" + ) + + await interaction.response.send_message( + "⛔ You don't have permission to use this command.", + ephemeral=True + ) + else: + logging.error(f"Command error: {error}", exc_info=True) + await interaction.response.send_message( + "⚠️ An unexpected error occurred. This has been logged.", + ephemeral=True + ) + + @moments_group.command( + name="audit", + description="View unauthorized access attempts (Admin only)" + ) + @app_commands.describe(days="Lookback period in days (max 30)") + @app_commands.checks.has_permissions(administrator=True) + async def view_audit_log(interaction: discord.Interaction, days: int = 7): + try: + if days > 30 or days < 1: + raise ValueError("Lookback period must be 1-30 days") + + cutoff = datetime.now() - timedelta(days=days) + + with db._get_connection() as conn: + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + cursor.execute(""" + SELECT * FROM unauthorized_access + WHERE timestamp > ? + ORDER BY timestamp DESC + """, (cutoff.isoformat(),)) # Store and compare ISO format + + results = [dict(row) for row in cursor.fetchall()] + + if not results: + await interaction.response.send_message("✅ No unauthorized access attempts found", ephemeral=True) + return + + log_text = "\n".join( + f" | " + f"<@{row['user_id']}> tried `{row['command_used']}`" + for row in results + ) + + embed = discord.Embed( + title=f"Unauthorized Access Logs ({days} days)", + description=log_text[:4000], + color=0xff0000 + ) + await interaction.response.send_message(embed=embed, ephemeral=True) + + except Exception as e: + await interaction.response.send_message(f"❌ Error: {str(e)}", ephemeral=True) + + client.tree.add_command(moments_group) diff --git a/utils/database.py b/utils/database.py index 08710a0..43b332b 100644 --- a/utils/database.py +++ b/utils/database.py @@ -4,11 +4,6 @@ from datetime import datetime from typing import List, Dict, Optional -import sqlite3 -import logging -from datetime import datetime -from typing import List, Dict, Optional - class Database: def __init__(self, db_path: str = "data/moments.db"): self.db_path = db_path @@ -21,6 +16,16 @@ class Database: conn.execute("DROP TABLE IF EXISTS incidents") conn.execute("DROP TABLE IF EXISTS incident_messages") + conn.execute(""" + CREATE TABLE IF NOT EXISTS unauthorized_access ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + command_used TEXT NOT NULL, + timestamp DATETIME NOT NULL, + details TEXT + ) + """) + conn.execute(""" CREATE TABLE IF NOT EXISTS funny_moments ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -131,16 +136,17 @@ class Database: conn.row_factory = sqlite3.Row cursor = conn.cursor() - # Get incident details and parse timestamp + # Get incident details cursor.execute("SELECT * FROM incidents WHERE id = ?", (incident_id,)) incident = cursor.fetchone() if not incident: return None + # Convert timestamp string to datetime object incident_details = dict(incident) - incident_details['timestamp'] = datetime.fromisoformat(incident_details['timestamp']) # Convert string to datetime + incident_details['timestamp'] = datetime.fromisoformat(incident_details['timestamp']) - # Get messages with parsed timestamps + # Get related messages cursor.execute("SELECT * FROM incident_messages WHERE incident_id = ?", (incident_id,)) messages = [ {**dict(msg), 'timestamp': datetime.fromisoformat(msg['timestamp'])} @@ -180,7 +186,7 @@ class Database: return False def get_followups(self, incident_id: str) -> List[Dict]: - """Get follow-ups with proper timestamps""" + """Retrieve follow-ups with proper timestamps""" with self._get_connection() as conn: conn.row_factory = sqlite3.Row cursor = conn.cursor() @@ -189,3 +195,18 @@ class Database: {**dict(row), 'timestamp': datetime.fromisoformat(row['timestamp'])} for row in cursor.fetchall() ] + + def log_unauthorized_access(self, user_id: int, command_used: str, details: str = ""): + """Log unauthorized command attempts""" + try: + with self._get_connection() as conn: + conn.execute(""" + INSERT INTO unauthorized_access + (user_id, command_used, timestamp, details) + VALUES (?, ?, ?, ?) + """, (user_id, command_used, datetime.now(), details)) + conn.commit() + return True + except Exception as e: + logging.error(f"Failed to log unauthorized access: {e}") + return False