Added an automatic Audit for commands when someone attempts to access them when they shouldn't be.
This commit is contained in:
		| @@ -1,33 +1,15 @@ | |||||||
| import discord | import discord | ||||||
| from discord import app_commands | from discord import app_commands | ||||||
| from typing import Optional, List | from typing import Optional, List | ||||||
| from datetime import datetime | from datetime import datetime, timedelta | ||||||
|  | import sqlite3 | ||||||
| import logging | import logging | ||||||
| from utils.database import Database | from utils.database import Database | ||||||
|  | from discord.app_commands import MissingPermissions | ||||||
|  |  | ||||||
| db = Database() | db = Database() | ||||||
|  |  | ||||||
|  |  | ||||||
| import discord |  | ||||||
| from discord import app_commands |  | ||||||
| from typing import Optional, List |  | ||||||
| from datetime import datetime |  | ||||||
| import logging |  | ||||||
| import uuid |  | ||||||
| from utils.database import Database |  | ||||||
|  |  | ||||||
| db = Database() |  | ||||||
|  |  | ||||||
| import discord |  | ||||||
| from discord import app_commands |  | ||||||
| from typing import Optional, List |  | ||||||
| from datetime import datetime |  | ||||||
| import logging |  | ||||||
| import uuid |  | ||||||
| from utils.database import Database |  | ||||||
|  |  | ||||||
| db = Database() |  | ||||||
|  |  | ||||||
| class IncidentModal(discord.ui.Modal): | class IncidentModal(discord.ui.Modal): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super().__init__(title="Log Incident") |         super().__init__(title="Log Incident") | ||||||
| @@ -73,10 +55,10 @@ class IncidentModal(discord.ui.Modal): | |||||||
|  |  | ||||||
|                 start_time = datetime.strptime(self.start_time.value, "%Y-%m-%d %H:%M") |                 start_time = datetime.strptime(self.start_time.value, "%Y-%m-%d %H:%M") | ||||||
|                 end_time = datetime.strptime(self.end_time.value, "%Y-%m-%d %H:%M") |                 end_time = datetime.strptime(self.end_time.value, "%Y-%m-%d %H:%M") | ||||||
|                  |  | ||||||
|                 if start_time >= end_time: |                 if start_time >= end_time: | ||||||
|                     raise ValueError("End time must be after start time") |                     raise ValueError("End time must be after start time") | ||||||
|                  |  | ||||||
|                 if (end_time - start_time).total_seconds() > 86400: |                 if (end_time - start_time).total_seconds() > 86400: | ||||||
|                     raise ValueError("Maximum timeframe duration is 24 hours") |                     raise ValueError("Maximum timeframe duration is 24 hours") | ||||||
|  |  | ||||||
| @@ -86,19 +68,19 @@ class IncidentModal(discord.ui.Modal): | |||||||
|                     before=end_time |                     before=end_time | ||||||
|                 ): |                 ): | ||||||
|                     messages.append(msg) |                     messages.append(msg) | ||||||
|                  |  | ||||||
|                 messages = messages[::-1]  # Oldest first |                 messages = messages[::-1]  # Oldest first | ||||||
|                 capture_mode = "timeframe" |                 capture_mode = "timeframe" | ||||||
|                 capture_param = f"{start_time.strftime('%Y-%m-%d %H:%M')} to {end_time.strftime('%Y-%m-%d %H:%M')}" |                 capture_param = f"{start_time.strftime('%Y-%m-%d %H:%M')} to {end_time.strftime('%Y-%m-%d %H:%M')}" | ||||||
|                  |  | ||||||
|             else: |             else: | ||||||
|                 if not self.message_count.value: |                 if not self.message_count.value: | ||||||
|                     raise ValueError("Please provide either message count or timeframe") |                     raise ValueError("Please provide either message count or timeframe") | ||||||
|                  |  | ||||||
|                 count = int(self.message_count.value) |                 count = int(self.message_count.value) | ||||||
|                 if not 1 <= count <= 50: |                 if not 1 <= count <= 50: | ||||||
|                     raise ValueError("Message count must be between 1-50") |                     raise ValueError("Message count must be between 1-50") | ||||||
|                  |  | ||||||
|                 messages = [ |                 messages = [ | ||||||
|                     msg async for msg in interaction.channel.history(limit=count) |                     msg async for msg in interaction.channel.history(limit=count) | ||||||
|                 ][::-1] |                 ][::-1] | ||||||
| @@ -136,7 +118,7 @@ class IncidentModal(discord.ui.Modal): | |||||||
|                 color=0x00ff00 |                 color=0x00ff00 | ||||||
|             ) |             ) | ||||||
|             embed.add_field(name="Reason", value=self.reason.value[:500], inline=False) |             embed.add_field(name="Reason", value=self.reason.value[:500], inline=False) | ||||||
|              |  | ||||||
|             if messages: |             if messages: | ||||||
|                 preview = f"{messages[0].content[:100]}..." if len(messages[0].content) > 100 else messages[0].content |                 preview = f"{messages[0].content[:100]}..." if len(messages[0].content) > 100 else messages[0].content | ||||||
|                 embed.add_field(name="First Message", value=preview, inline=False) |                 embed.add_field(name="First Message", value=preview, inline=False) | ||||||
| @@ -155,6 +137,7 @@ class IncidentModal(discord.ui.Modal): | |||||||
|                 ephemeral=True |                 ephemeral=True | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class FollowupModal(discord.ui.Modal): | class FollowupModal(discord.ui.Modal): | ||||||
|     def __init__(self, incident_id: str): |     def __init__(self, incident_id: str): | ||||||
|         super().__init__(title=f"Follow-up: {incident_id}") |         super().__init__(title=f"Follow-up: {incident_id}") | ||||||
| @@ -173,7 +156,7 @@ class FollowupModal(discord.ui.Modal): | |||||||
|                 moderator_id=interaction.user.id, |                 moderator_id=interaction.user.id, | ||||||
|                 notes=self.notes.value |                 notes=self.notes.value | ||||||
|             ) |             ) | ||||||
|              |  | ||||||
|             if success: |             if success: | ||||||
|                 await interaction.response.send_message( |                 await interaction.response.send_message( | ||||||
|                     f"✅ Follow-up added to **{self.incident_id}**", |                     f"✅ Follow-up added to **{self.incident_id}**", | ||||||
| @@ -191,6 +174,7 @@ class FollowupModal(discord.ui.Modal): | |||||||
|                 ephemeral=True |                 ephemeral=True | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def setup(client): | async def setup(client): | ||||||
|     # Context menu command |     # Context menu command | ||||||
|     @client.tree.context_menu(name="Mark as Funny Moment") |     @client.tree.context_menu(name="Mark as Funny Moment") | ||||||
| @@ -259,13 +243,14 @@ async def setup(client): | |||||||
|                 await interaction.response.send_message("❌ Incident not found", ephemeral=True) |                 await interaction.response.send_message("❌ Incident not found", ephemeral=True) | ||||||
|                 return |                 return | ||||||
|  |  | ||||||
|             # Format messages |             # Format messages with proper timestamps | ||||||
|             messages = "\n\n".join( |             messages = "\n\n".join( | ||||||
|                 f"**<t:{int(msg['timestamp'].timestamp())}:F>** <@{msg['author_id']}>:\n{msg['content']}" |                 f"**<t:{int(msg['timestamp'].timestamp())}:F>** <@{msg['author_id']}>:\n" | ||||||
|  |                 f"{msg['content']}" | ||||||
|                 for msg in incident['messages'] |                 for msg in incident['messages'] | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             # Format follow-ups |             # Get follow-ups with proper timestamps | ||||||
|             followups = db.get_followups(incident_id) |             followups = db.get_followups(incident_id) | ||||||
|             followup_text = "\n\n".join( |             followup_text = "\n\n".join( | ||||||
|                 f"**<t:{int(f['timestamp'].timestamp())}:f>** <@{f['moderator_id']}>:\n{f['notes'][:200]}" |                 f"**<t:{int(f['timestamp'].timestamp())}:f>** <@{f['moderator_id']}>:\n{f['notes'][:200]}" | ||||||
| @@ -283,8 +268,18 @@ async def setup(client): | |||||||
|                 ), |                 ), | ||||||
|                 color=0xff0000 |                 color=0xff0000 | ||||||
|             ) |             ) | ||||||
|             embed.add_field(name="Messages", value=messages[:1020] + "..." if len(messages) > 1024 else messages, inline=False) |  | ||||||
|             embed.add_field(name=f"Follow-ups ({len(followups)})", value=followup_text[:1020] + "..." if len(followup_text) > 1024 else followup_text, inline=False) |             embed.add_field( | ||||||
|  |                 name="Messages", | ||||||
|  |                 value=messages[:1020] + "..." if len(messages) > 1024 else messages, | ||||||
|  |                 inline=False | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             embed.add_field( | ||||||
|  |                 name=f"Follow-ups ({len(followups)})", | ||||||
|  |                 value=followup_text[:1020] + "..." if len(followup_text) > 1024 else followup_text, | ||||||
|  |                 inline=False | ||||||
|  |             ) | ||||||
|  |  | ||||||
|             await interaction.response.send_message(embed=embed, ephemeral=True) |             await interaction.response.send_message(embed=embed, ephemeral=True) | ||||||
|  |  | ||||||
| @@ -361,4 +356,71 @@ async def setup(client): | |||||||
|                 choices.append(app_commands.Choice(name=name, value=inc['id'])) |                 choices.append(app_commands.Choice(name=name, value=inc['id'])) | ||||||
|         return choices[:25] |         return choices[:25] | ||||||
|  |  | ||||||
|     client.tree.add_command(moments_group) |     # Global error handler for commands | ||||||
|  |     @client.tree.error | ||||||
|  |     async def on_command_error(interaction: discord.Interaction, error): | ||||||
|  |         """Handle unauthorized access attempts""" | ||||||
|  |         if isinstance(error, MissingPermissions): | ||||||
|  |             # Log unauthorized access | ||||||
|  |             db.log_unauthorized_access( | ||||||
|  |                 user_id=interaction.user.id, | ||||||
|  |                 command_used=interaction.command.name if interaction.command else "unknown", | ||||||
|  |                 details=f"Attempted params: {interaction.data}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             await interaction.response.send_message( | ||||||
|  |                 "⛔ You don't have permission to use this command.", | ||||||
|  |                 ephemeral=True | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             logging.error(f"Command error: {error}", exc_info=True) | ||||||
|  |             await interaction.response.send_message( | ||||||
|  |                 "⚠️ An unexpected error occurred. This has been logged.", | ||||||
|  |                 ephemeral=True | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     @moments_group.command( | ||||||
|  |         name="audit", | ||||||
|  |         description="View unauthorized access attempts (Admin only)" | ||||||
|  |     ) | ||||||
|  |     @app_commands.describe(days="Lookback period in days (max 30)") | ||||||
|  |     @app_commands.checks.has_permissions(administrator=True) | ||||||
|  |     async def view_audit_log(interaction: discord.Interaction, days: int = 7): | ||||||
|  |         try: | ||||||
|  |             if days > 30 or days < 1: | ||||||
|  |                 raise ValueError("Lookback period must be 1-30 days") | ||||||
|  |  | ||||||
|  |             cutoff = datetime.now() - timedelta(days=days) | ||||||
|  |  | ||||||
|  |             with db._get_connection() as conn: | ||||||
|  |                 conn.row_factory = sqlite3.Row | ||||||
|  |                 cursor = conn.cursor() | ||||||
|  |                 cursor.execute(""" | ||||||
|  |                     SELECT * FROM unauthorized_access | ||||||
|  |                     WHERE timestamp > ? | ||||||
|  |                     ORDER BY timestamp DESC | ||||||
|  |                 """, (cutoff.isoformat(),))  # Store and compare ISO format | ||||||
|  |  | ||||||
|  |             results = [dict(row) for row in cursor.fetchall()] | ||||||
|  |  | ||||||
|  |             if not results: | ||||||
|  |                 await interaction.response.send_message("✅ No unauthorized access attempts found", ephemeral=True) | ||||||
|  |                 return | ||||||
|  |  | ||||||
|  |             log_text = "\n".join( | ||||||
|  |                 f"<t:{int(datetime.fromisoformat(row['timestamp']).timestamp())}:f> | " | ||||||
|  |                 f"<@{row['user_id']}> tried `{row['command_used']}`" | ||||||
|  |                 for row in results | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             embed = discord.Embed( | ||||||
|  |                 title=f"Unauthorized Access Logs ({days} days)", | ||||||
|  |                 description=log_text[:4000], | ||||||
|  |                 color=0xff0000 | ||||||
|  |             ) | ||||||
|  |             await interaction.response.send_message(embed=embed, ephemeral=True) | ||||||
|  |  | ||||||
|  |         except Exception as e: | ||||||
|  |             await interaction.response.send_message(f"❌ Error: {str(e)}", ephemeral=True) | ||||||
|  |  | ||||||
|  |     client.tree.add_command(moments_group) | ||||||
|   | |||||||
| @@ -4,11 +4,6 @@ from datetime import datetime | |||||||
| from typing import List, Dict, Optional | from typing import List, Dict, Optional | ||||||
|  |  | ||||||
|  |  | ||||||
| import sqlite3 |  | ||||||
| import logging |  | ||||||
| from datetime import datetime |  | ||||||
| from typing import List, Dict, Optional |  | ||||||
|  |  | ||||||
| class Database: | class Database: | ||||||
|     def __init__(self, db_path: str = "data/moments.db"): |     def __init__(self, db_path: str = "data/moments.db"): | ||||||
|         self.db_path = db_path |         self.db_path = db_path | ||||||
| @@ -21,6 +16,16 @@ class Database: | |||||||
|             conn.execute("DROP TABLE IF EXISTS incidents") |             conn.execute("DROP TABLE IF EXISTS incidents") | ||||||
|             conn.execute("DROP TABLE IF EXISTS incident_messages") |             conn.execute("DROP TABLE IF EXISTS incident_messages") | ||||||
|  |  | ||||||
|  |             conn.execute(""" | ||||||
|  |                 CREATE TABLE IF NOT EXISTS unauthorized_access ( | ||||||
|  |                     id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
|  |                     user_id INTEGER NOT NULL, | ||||||
|  |                     command_used TEXT NOT NULL, | ||||||
|  |                     timestamp DATETIME NOT NULL, | ||||||
|  |                     details TEXT | ||||||
|  |                 ) | ||||||
|  |             """) | ||||||
|  |  | ||||||
|             conn.execute(""" |             conn.execute(""" | ||||||
|                 CREATE TABLE IF NOT EXISTS funny_moments ( |                 CREATE TABLE IF NOT EXISTS funny_moments ( | ||||||
|                     id INTEGER PRIMARY KEY AUTOINCREMENT, |                     id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
| @@ -131,16 +136,17 @@ class Database: | |||||||
|             conn.row_factory = sqlite3.Row |             conn.row_factory = sqlite3.Row | ||||||
|             cursor = conn.cursor() |             cursor = conn.cursor() | ||||||
|  |  | ||||||
|             # Get incident details and parse timestamp |             # Get incident details | ||||||
|             cursor.execute("SELECT * FROM incidents WHERE id = ?", (incident_id,)) |             cursor.execute("SELECT * FROM incidents WHERE id = ?", (incident_id,)) | ||||||
|             incident = cursor.fetchone() |             incident = cursor.fetchone() | ||||||
|             if not incident: |             if not incident: | ||||||
|                 return None |                 return None | ||||||
|  |  | ||||||
|  |             # Convert timestamp string to datetime object | ||||||
|             incident_details = dict(incident) |             incident_details = dict(incident) | ||||||
|             incident_details['timestamp'] = datetime.fromisoformat(incident_details['timestamp'])  # Convert string to datetime |             incident_details['timestamp'] = datetime.fromisoformat(incident_details['timestamp']) | ||||||
|  |  | ||||||
|             # Get messages with parsed timestamps |             # Get related messages | ||||||
|             cursor.execute("SELECT * FROM incident_messages WHERE incident_id = ?", (incident_id,)) |             cursor.execute("SELECT * FROM incident_messages WHERE incident_id = ?", (incident_id,)) | ||||||
|             messages = [ |             messages = [ | ||||||
|                 {**dict(msg), 'timestamp': datetime.fromisoformat(msg['timestamp'])}  |                 {**dict(msg), 'timestamp': datetime.fromisoformat(msg['timestamp'])}  | ||||||
| @@ -180,7 +186,7 @@ class Database: | |||||||
|             return False |             return False | ||||||
|  |  | ||||||
|     def get_followups(self, incident_id: str) -> List[Dict]: |     def get_followups(self, incident_id: str) -> List[Dict]: | ||||||
|         """Get follow-ups with proper timestamps""" |         """Retrieve follow-ups with proper timestamps""" | ||||||
|         with self._get_connection() as conn: |         with self._get_connection() as conn: | ||||||
|             conn.row_factory = sqlite3.Row |             conn.row_factory = sqlite3.Row | ||||||
|             cursor = conn.cursor() |             cursor = conn.cursor() | ||||||
| @@ -189,3 +195,18 @@ class Database: | |||||||
|                 {**dict(row), 'timestamp': datetime.fromisoformat(row['timestamp'])} |                 {**dict(row), 'timestamp': datetime.fromisoformat(row['timestamp'])} | ||||||
|                 for row in cursor.fetchall() |                 for row in cursor.fetchall() | ||||||
|             ] |             ] | ||||||
|  |  | ||||||
|  |     def log_unauthorized_access(self, user_id: int, command_used: str, details: str = ""): | ||||||
|  |         """Log unauthorized command attempts""" | ||||||
|  |         try: | ||||||
|  |             with self._get_connection() as conn: | ||||||
|  |                 conn.execute(""" | ||||||
|  |                     INSERT INTO unauthorized_access  | ||||||
|  |                     (user_id, command_used, timestamp, details) | ||||||
|  |                     VALUES (?, ?, ?, ?) | ||||||
|  |                 """, (user_id, command_used, datetime.now(), details)) | ||||||
|  |                 conn.commit() | ||||||
|  |                 return True | ||||||
|  |         except Exception as e: | ||||||
|  |             logging.error(f"Failed to log unauthorized access: {e}") | ||||||
|  |             return False | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user