First Commit:
Merged two projects into one
This commit is contained in:
commit
567c2d5f84
15
.github/workflows/discord_sync.yml
vendored
Normal file
15
.github/workflows/discord_sync.yml
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# Sends every push on this repository to a Discord channel via webhook.
name: Discord Webhook

on: [push]

jobs:
  git:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2

      - name: Run Discord Webhook
        uses: johnnyhuy/actions-discord-git-webhook@main
        with:
          # Webhook URL is kept in repository secrets, never in the workflow file.
          webhook_url: ${{ secrets.YOUR_DISCORD_WEBHOOK_URL }}
|
164
.gitignore
vendored
Normal file
164
.gitignore
vendored
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||||
|
.pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
/client_secret.json
|
||||||
|
/token.json
|
278
main.py
Normal file
278
main.py
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import discord
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from model import JadeModel
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from collections import deque
|
||||||
|
import uuid as uuid_lib
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Constants
# OAuth scope: read-only access to the YouTube Data API.
SCOPES = ['https://www.googleapis.com/auth/youtube.readonly']
DATABASE_FILE = 'global_user_data.db'  # Updated database file name
# NOTE(review): despite the name, this value looks like a channel ID
# (UC... prefix), not an @handle — confirm before renaming.
CHANNEL_HANDLE = 'UCsVJcf4KbO8Vz308EKpSYxw'
STREAM_KEYWORD = "Live"  # title keyword used to pick the right live broadcast

# Load environment variables
load_dotenv()

# Discord client; message-content intent is required for the bot to read text.
intents = discord.Intents.default()
intents.messages = True
intents.message_content = True
client = discord.Client(intents=intents)

# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JadeModel().to(device)

# Context management for conversation continuity
conversation_history = deque(maxlen=5)  # Store the last 5 messages for context
training_data = []  # Store live messages for training purposes
|
||||||
|
|
||||||
|
# Profile Manager
|
||||||
|
class ProfileManager:
    """SQLite-backed store linking Discord and YouTube identities to one profile.

    Each profile row is keyed by an internal uuid4 and may carry a Discord
    user id, a YouTube channel id, or both. All methods open a short-lived
    connection to DATABASE_FILE and close it in a ``finally`` block so a
    failing query cannot leak the connection.
    """

    def __init__(self):
        self._create_profiles_table()

    def _create_profiles_table(self):
        """Create the global_profiles table if it does not exist (idempotent)."""
        conn = sqlite3.connect(DATABASE_FILE)
        try:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS global_profiles (
                    uuid TEXT PRIMARY KEY,
                    discord_user_id TEXT UNIQUE,
                    youtube_channel_id TEXT UNIQUE,
                    points INTEGER DEFAULT 0,
                    last_interaction TIMESTAMP,
                    subscription_status TEXT,
                    first_seen_as_member TIMESTAMP,
                    has_opted_in INTEGER DEFAULT 0
                )
            ''')
            conn.commit()
        finally:
            conn.close()

    def get_or_create_uuid(self, discord_id=None, youtube_id=None):
        """Return the profile uuid for the given identity, creating one if needed.

        Lookup order: discord_id first, then youtube_id. When neither matches
        an existing row, a new uuid4-keyed row is inserted holding whichever
        ids were supplied.
        """
        conn = sqlite3.connect(DATABASE_FILE)
        try:
            cursor = conn.cursor()
            uuid = None

            if discord_id:
                cursor.execute("SELECT uuid FROM global_profiles WHERE discord_user_id = ?", (discord_id,))
                row = cursor.fetchone()
                if row:
                    uuid = row[0]

            if not uuid and youtube_id:
                cursor.execute("SELECT uuid FROM global_profiles WHERE youtube_channel_id = ?", (youtube_id,))
                row = cursor.fetchone()
                if row:
                    uuid = row[0]

            if not uuid:
                uuid = str(uuid_lib.uuid4())
                cursor.execute('''
                    INSERT INTO global_profiles (uuid, discord_user_id, youtube_channel_id)
                    VALUES (?, ?, ?)
                ''', (uuid, discord_id, youtube_id))
                conn.commit()

            return uuid
        finally:
            conn.close()

    def update_subscription_status(self, youtube_id, status):
        """Set subscription_status and bump last_interaction for a YouTube id."""
        conn = sqlite3.connect(DATABASE_FILE)
        try:
            cursor = conn.cursor()
            # timezone-aware timestamp: datetime.utcnow() is deprecated and naive.
            cursor.execute('''
                UPDATE global_profiles
                SET subscription_status = ?, last_interaction = ?
                WHERE youtube_channel_id = ?
            ''', (status, datetime.now(timezone.utc), youtube_id))
            conn.commit()
        finally:
            conn.close()

    def delete_user_data(self, uuid):
        # Delete user data to comply with GDPR; a JSON snapshot of the row is
        # written to disk before deletion as an audit trail.
        conn = sqlite3.connect(DATABASE_FILE)
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT * FROM global_profiles WHERE uuid = ?', (uuid,))
            user_data = cursor.fetchone()
            if user_data:
                with open(f'deleted_user_data_{uuid}.json', 'w') as f:
                    json.dump({
                        'uuid': user_data[0],
                        'discord_user_id': user_data[1],
                        'youtube_channel_id': user_data[2],
                        'points': user_data[3],
                        'last_interaction': user_data[4],
                        'subscription_status': user_data[5],
                        'first_seen_as_member': user_data[6],
                        'has_opted_in': user_data[7]
                    }, f)
                cursor.execute('DELETE FROM global_profiles WHERE uuid = ?', (uuid,))
                conn.commit()
        finally:
            conn.close()

    def has_opted_in(self, uuid):
        """Return True when the profile exists and has_opted_in == 1."""
        conn = sqlite3.connect(DATABASE_FILE)
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT has_opted_in FROM global_profiles WHERE uuid = ?', (uuid,))
            result = cursor.fetchone()
            return result and result[0] == 1
        finally:
            conn.close()

    def set_opt_in(self, uuid, opted_in=True):
        """Record the user's data-usage consent flag (1 = opted in)."""
        conn = sqlite3.connect(DATABASE_FILE)
        try:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE global_profiles
                SET has_opted_in = ?
                WHERE uuid = ?
            ''', (1 if opted_in else 0, uuid))
            conn.commit()
        finally:
            conn.close()
|
||||||
|
|
||||||
|
# Shared profile store used by both the YouTube monitor and the Discord handlers.
profile_manager = ProfileManager()
|
||||||
|
|
||||||
|
# YouTube API Functions
|
||||||
|
def get_authenticated_service():
    """Build an authenticated YouTube Data API client.

    Reuses cached credentials from token.json when present (refreshing
    expired ones via the refresh token) and only falls back to the
    interactive OAuth browser flow when no usable credentials exist.
    Fresh credentials are persisted back to token.json.

    Returns:
        A googleapiclient ``youtube`` v3 service resource.
    """
    creds = None
    if os.path.exists('token.json'):
        # Credentials/Request were imported for exactly this cache path; the
        # original code wrote token.json but never read it back, forcing an
        # interactive OAuth consent on every startup.
        creds = Credentials.from_authorized_user_file('token.json', SCOPES)
    if not creds or not creds.valid:
        if creds and creds.expired and creds.refresh_token:
            creds.refresh(Request())
        else:
            flow = InstalledAppFlow.from_client_secrets_file(
                'client_secret.json', SCOPES)
            creds = flow.run_local_server(port=63355)
        with open('token.json', 'w') as token:
            token.write(creds.to_json())
    return build('youtube', 'v3', credentials=creds)
|
||||||
|
|
||||||
|
def find_correct_live_video(youtube, channel_id, keyword):
    """Return the video id of the channel's live broadcast whose title
    contains *keyword* (case-insensitive), or None when no match is live."""
    search_response = youtube.search().list(
        part="snippet",
        channelId=channel_id,
        eventType="live",
        type="video"
    ).execute()

    wanted = keyword.lower()
    matching_ids = (
        entry['id']['videoId']
        for entry in search_response.get('items', [])
        if wanted in entry['snippet']['title'].lower()
    )
    return next(matching_ids, None)
|
||||||
|
|
||||||
|
def get_live_chat_id(youtube, video_id):
    """Return the active live-chat id attached to *video_id*, or None when
    the video is unknown or has no active chat."""
    details_response = youtube.videos().list(
        part="liveStreamingDetails",
        id=video_id
    ).execute()

    found = details_response.get('items', [])
    if not found:
        return None
    return found[0]['liveStreamingDetails'].get('activeLiveChatId')
|
||||||
|
|
||||||
|
def monitor_youtube_chat(youtube, live_chat_id):
    """Poll the YouTube live chat and mirror activity into the profile store.

    Each iteration fetches up to 200 messages, ensures every author has a
    profile row, records membership ("sponsor") status, and appends messages
    from opted-in users to the module-level ``training_data`` list.

    Args:
        youtube: authenticated YouTube Data API service resource.
        live_chat_id: active live chat id; when falsy the function returns
            False immediately.

    Returns:
        False when no live_chat_id was supplied; otherwise loops until the
        process is interrupted (it never returns normally).
    """
    if not live_chat_id:
        print("No valid live chat ID found.")
        return False

    next_page_token = None
    while True:
        try:
            request = youtube.liveChatMessages().list(
                liveChatId=live_chat_id,
                part="snippet,authorDetails",
                maxResults=200,
                pageToken=next_page_token
            )
            response = request.execute()

            if 'items' in response and response['items']:
                for item in response['items']:
                    user_id = item['authorDetails']['channelId']
                    display_name = item['authorDetails']['displayName']
                    # isChatSponsor flags channel members.
                    is_member = item['authorDetails']['isChatSponsor']
                    message = item['snippet']['displayMessage']

                    uuid = profile_manager.get_or_create_uuid(youtube_id=user_id)
                    if is_member:
                        profile_manager.update_subscription_status(user_id, "subscribed")
                    else:
                        profile_manager.update_subscription_status(user_id, "none")

                    # timezone-aware now(): datetime.utcnow() is deprecated.
                    print(f"[{datetime.now(timezone.utc)}] {display_name}: {message} (UUID: {uuid})")

                    # Add live chat message to training data only if the user
                    # has explicitly opted in.
                    if profile_manager.has_opted_in(uuid):
                        training_data.append((display_name, message))

                next_page_token = response.get('nextPageToken')
            else:
                print("No new messages detected; continuing to poll...")

        except Exception as e:
            # Deliberate best-effort: keep the monitor alive across transient
            # API errors rather than crashing the whole service.
            print(f"Error while monitoring chat: {e}")
            time.sleep(30)  # Wait before retrying in case of an error

        time.sleep(10)  # Adjust this delay as needed
|
||||||
|
|
||||||
|
# Discord Event Handlers
|
||||||
|
@client.event
async def on_ready():
    # Fired once the bot has connected to the Discord gateway and logged in.
    print(f'We have logged in as {client.user}')
|
||||||
|
|
||||||
|
@client.event
async def on_message(message):
    """Handle an incoming Discord message: enforce opt-in consent, then
    generate a model reply using recent conversation context."""
    if message.author == client.user:
        return

    # Link the Discord user to the correct global profile UUID
    uuid = profile_manager.get_or_create_uuid(discord_id=str(message.author.id))

    # Handle the opt-in command BEFORE the consent gate. In the original
    # order the gate returned early for anyone not yet opted in, so the
    # '!optin' command below it was unreachable and nobody could ever consent.
    if message.content.lower() == '!optin':
        profile_manager.set_opt_in(uuid, True)
        await message.channel.send("You have successfully opted in to data usage.")
        return

    # Ensure user has opted in before interacting
    if not profile_manager.has_opted_in(uuid):
        await message.channel.send("Please type '!optin' to confirm that you agree to data usage and interaction with this bot.")
        return

    # Add the message to conversation history for context
    conversation_history.append(message.content)

    # Generate a response using Jade with context
    context = "\n".join(conversation_history)
    response = model.generate_response(context)
    if response:
        await message.channel.send(response)

    print(f"Discord Interaction: User {message.author} (UUID: {uuid})")
|
||||||
|
|
||||||
|
# Main Function to Start Both Services
|
||||||
|
def main():
    """Start both services: the YouTube live-chat monitor, then the Discord bot.

    NOTE(review): monitor_youtube_chat loops forever while a stream is live,
    so client.run() is only reached when no live stream is found — confirm
    whether the monitor should run in a background thread instead.
    """
    youtube = get_authenticated_service()
    # Register a profile row for the channel, but query the API with the real
    # channel id: get_or_create_uuid returns an internal uuid4 string, which
    # the original code passed as channelId — the search could never match.
    profile_manager.get_or_create_uuid(youtube_id=CHANNEL_HANDLE)
    video_id = find_correct_live_video(youtube, CHANNEL_HANDLE, STREAM_KEYWORD)
    if video_id:
        live_chat_id = get_live_chat_id(youtube, video_id)
        if live_chat_id:
            print("Monitoring YouTube live chat...")
            monitor_youtube_chat(youtube, live_chat_id)
        else:
            print("No live chat ID available.")
    else:
        print("Could not find the correct live stream or it is not live.")

    client.run(os.getenv('DISCORD_TOKEN'))
|
||||||
|
|
||||||
|
# Script entry point.
if __name__ == "__main__":
    main()
|
118
model.py
Normal file
118
model.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import os
|
||||||
|
from torch.cuda.amp import GradScaler, autocast
|
||||||
|
|
||||||
|
class JadeModel(nn.Module):
    """Character-level GPT-style transformer used to generate chat replies.

    Tokenization is a placeholder: one token per character, by code point,
    clamped into a ``vocab_size``-entry vocabulary.
    """

    def __init__(self, load_model_path=None):
        """Build the network and optionally load saved weights.

        Args:
            load_model_path: optional path to a saved state_dict; loaded only
                when the file exists.
        """
        super(JadeModel, self).__init__()
        # GPT-like Transformer architecture
        self.vocab_size = 512  # character-level vocabulary (code points clamped to this range)
        self.embedding_dim = 768  # GPT-like embedding dimension
        self.num_heads = 12  # Number of attention heads
        self.num_layers = 12  # Number of transformer layers
        self.max_position_embeddings = 512  # Maximum sequence length

        # Embedding layers
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.position_embedding = nn.Embedding(self.max_position_embeddings, self.embedding_dim)

        # Transformer layers. batch_first=True is required: every tensor in
        # this class is laid out (batch, seq, dim), while the layer's default
        # expects (seq, batch, dim) — without the flag the batch axis would be
        # treated as the sequence and no attention would span the actual text.
        # Parameter shapes are unchanged, so existing checkpoints still load.
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=self.embedding_dim,
                                       nhead=self.num_heads,
                                       batch_first=True)
            for _ in range(self.num_layers)
        ])

        # Output layer (forward returns raw logits; softmax is used only
        # during sampling).
        self.fc = nn.Linear(self.embedding_dim, self.vocab_size)
        self.softmax = nn.Softmax(dim=-1)

        # Device setup
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

        # Load model state if path is provided
        if load_model_path and os.path.exists(load_model_path):
            self.load_model(load_model_path)
            print(f"Model loaded from {load_model_path}")

    def forward(self, input_ids):
        """Return next-token logits of shape (batch, seq, vocab_size)."""
        # Truncate input_ids if longer than max_position_embeddings
        if input_ids.size(1) > self.max_position_embeddings:
            input_ids = input_ids[:, -self.max_position_embeddings:]

        # Create position ids for input sequence
        seq_length = input_ids.size(1)
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=self.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Token + learned absolute position embeddings
        x = self.embedding(input_ids) + self.position_embedding(position_ids)

        # Pass through transformer layers
        for layer in self.transformer_layers:
            x = layer(x)

        # Project back to vocabulary logits
        x = self.fc(x)
        return x

    def generate_response(self, input_text, initial_temperature=0.85, top_p=0.8, repetition_penalty=1.4, max_token_frequency=2, max_length=50, min_response_length=5):
        """Autoregressively sample a reply to *input_text*.

        Args:
            input_text: prompt; only its last max_position_embeddings
                characters are used.
            initial_temperature: starting sampling temperature; decays toward
                0.75 as generation proceeds.
            top_p: nucleus-sampling cumulative-probability threshold.
            repetition_penalty: base penalty applied to already-seen tokens.
            max_token_frequency: currently unused (kept for interface compatibility).
            max_length: maximum number of tokens to generate.
            min_response_length: currently unused (kept for interface compatibility).

        Returns:
            The generated text, or None when only whitespace was produced.
        """
        # Convert input_text to token ids
        input_ids = self.tokenize(input_text)
        if len(input_ids) > self.max_position_embeddings:
            input_ids = input_ids[-self.max_position_embeddings:]  # Truncate if too long
        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)
        generated_tokens = input_ids.copy()  # Start with input tokens to use as context
        temperature = initial_temperature
        recent_tokens = list(input_ids[-10:])  # Sliding window of the last 10 tokens

        with torch.no_grad(), autocast():
            for _ in range(max_length):  # Generate up to max_length more tokens
                output = self.forward(input_tensor)
                logits = output[:, -1, :]  # Only the last position predicts the next token
                logits = logits / (temperature + 1e-2)  # Temperature for sampling diversity

                # Repetition penalty, sign-aware (CTRL-style): positive logits
                # are divided and negative ones multiplied. Unconditionally
                # dividing — as the original did — makes a negative logit LESS
                # negative, i.e. rewards repetition instead of punishing it.
                for token in set(generated_tokens):
                    count = generated_tokens.count(token)
                    if count > 1:
                        penalty = repetition_penalty + count * 0.02  # frequency-scaled
                        if logits[0, token] > 0:
                            logits[0, token] /= penalty
                        else:
                            logits[0, token] *= penalty

                # Nucleus (top-p) sampling
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(self.softmax(sorted_logits), dim=-1)
                top_p_mask = cumulative_probs < top_p
                top_p_logits = sorted_logits[top_p_mask]
                top_p_indices = sorted_indices[top_p_mask]

                if len(top_p_logits) > 1:
                    top_p_probs = self.softmax(top_p_logits)
                    sampled_token = top_p_indices[torch.multinomial(top_p_probs, num_samples=1).item()].item()
                else:
                    # Fallback to the most probable token if none pass the top-p threshold
                    sampled_token = sorted_indices[0, 0].item()

                # Add token and update state. The original popped from
                # recent_tokens without ever appending, so the window drained
                # and then did nothing.
                generated_tokens.append(sampled_token)
                recent_tokens.append(sampled_token)
                if len(recent_tokens) > 10:
                    recent_tokens.pop(0)
                # NOTE(review): recent_tokens is maintained but not yet applied
                # to the logits; wire it into the penalty loop or drop it.

                # Feed the extended sequence back in for the next step
                input_tensor = torch.tensor(generated_tokens).unsqueeze(0).to(self.device)

                # Gradually decrease temperature to reduce randomness more smoothly
                temperature = max(0.75, temperature * 0.98)

        response = self.detokenize(generated_tokens[len(input_ids):])  # Exclude the input from the response
        return response if len(response.strip()) > 0 else None

    def load_model(self, path):
        """Load a state_dict from *path* onto this model's device."""
        self.load_state_dict(torch.load(path, map_location=self.device))

    # Placeholder tokenization method (to be replaced with optimized tokenizer)
    def tokenize(self, text):
        """One token per character; code points are clamped into the embedding
        range so exotic characters cannot index past the vocab_size-entry
        table (the unclamped original raised an index error on ord(c) >= 512)."""
        return [min(ord(c), self.vocab_size - 1) for c in text]

    # Placeholder detokenization method (to be replaced with optimized tokenizer)
    def detokenize(self, tokens):
        """Map token ids (code points) back to characters."""
        return ''.join([chr(t) for t in tokens])
|
Loading…
x
Reference in New Issue
Block a user