fixing cuda errors
Parent: 8d7cf38f1b
Commit: c9601f132b
dashboard/dashboard.py

@@ -1,4 +1,9 @@
 from flask import Flask, render_template
+import json
+import os
+import time
+import datetime
+import logging
 from model.brainmap import get_brainmap
 from model.journal import read_journal_entries
 from model.memory import load_dreams
@@ -7,10 +12,6 @@ from model.abstraction import cluster_vocab
 from model.memory import load_dreams
 from model.scheduler import get_time_until_next_action, get_next_action_label
 from context.context import load_context
-import json
-import os
-import time
-import datetime


 app = Flask(__name__)
@@ -126,4 +127,6 @@ def dreams():


 def run_dashboard():
+    log = logging.getLogger('werkzeug')
+    log.setLevel(logging.ERROR)
     app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
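The two added lines silence Werkzeug's per-request access log so it stops interleaving with the bot's console output; only errors from the embedded server get through. The existing `use_reloader=False` is what allows `run_dashboard` to live on a worker thread at all: Werkzeug's auto-reloader installs signal handlers, which Python only permits from the main thread. A minimal sketch of the intended call pattern, mirroring the thread launch that main.py previously made (the surrounding scaffolding is illustrative, not part of the commit):

    import threading

    from dashboard.dashboard import run_dashboard

    # daemon=True lets the process exit without joining the server thread;
    # run_dashboard() must keep use_reloader=False, because the Werkzeug
    # reloader can only be installed from the main thread.
    dashboard_thread = threading.Thread(target=run_dashboard, daemon=True)
    dashboard_thread.start()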
main.py

@@ -11,7 +11,6 @@ from model.rehearsal import simulate_conversation
 from model.scheduler import set_next_action
 from reader.reader import read_books_forever
 from dashboard.dashboard import run_dashboard
-import threading

 load_dotenv()
 TOKEN = os.getenv("DISCORD_TOKEN")
@@ -38,9 +37,6 @@ async def on_message(message):
     response = generate_response()
     await message.channel.send(response)

-# Launch Flask in background
-threading.Thread(target=run_dashboard, daemon=True).start()
-

 async def background_cleanup_loop():
     while True:
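Both main.py hunks remove the import-time launch of the dashboard thread. Starting a thread as a side effect of module import is fragile here: it can fire before the CUDA-backed model and tokenizer are initialized, and it repeats in any process that merely imports the module. The commit does not show where the launch moved; a hedged sketch of the usual alternative, an explicit startup hook invoked after setup, would be:

    # Hypothetical replacement (not shown in this commit): defer background
    # threads to an explicit hook that runs after model and CUDA
    # initialization, instead of starting them at import time.
    import threading

    from dashboard.dashboard import run_dashboard

    def start_background_services():
        threading.Thread(target=run_dashboard, daemon=True).start()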
@@ -1,10 +1,12 @@
 import torch
+import threading
 from model.brain_architecture import TinyTransformer
 from model.brain_state import model, tokenizer, DEVICE

 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

-_last_expansion_vocab_size = 0
+_last_vocab_size = 0
+_expand_lock = threading.Lock()


 def get_optimizer():
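The remaining hunks touch the model-expansion module (its filename is not part of this capture). This first one adds a module-wide `threading.Lock` and renames the throttle counter. Note that `expand_model_if_needed` (below) rebinds the module-level `model` and `optimizer` names: a caller that cached a direct reference keeps stepping the old, smaller network, which is why the module exposes `get_optimizer()` rather than the bare global. A sketch of a training step written against that contract (the module path and loss wiring are assumptions, not shown in the commit):

    import model.dynamic_expansion as dx  # hypothetical path for this module

    def train_step(input_ids, targets, loss_fn):
        # Grow the network first if the tokenizer has picked up new words...
        dx.expand_model_if_needed()
        # ...then re-fetch model and optimizer, since expansion rebinds both.
        net = dx.model
        opt = dx.get_optimizer()

        opt.zero_grad()
        loss = loss_fn(net(input_ids), targets)
        loss.backward()
        opt.step()
        return loss.item()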
@@ -12,28 +14,31 @@ def get_optimizer():


 def expand_model_if_needed():
-    global model, optimizer, _last_expansion_vocab_size
+    global model, optimizer, _last_vocab_size

-    current_vocab_size = len(tokenizer.vocab) + 10
+    with _expand_lock:
+        current_vocab_size = len(tokenizer.vocab) + 10

-    if current_vocab_size - _last_expansion_vocab_size < 5:
-        return  # Only expand every 5 words
+        if current_vocab_size - _last_vocab_size < 10:
+            return  # Expand only after 10 new words collected

-    old_vocab_size = model.head.out_features
-    if current_vocab_size <= old_vocab_size:
-        return  # No expansion needed
-    print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
+        old_vocab_size = model.head.out_features
+
+        if current_vocab_size <= old_vocab_size:
+            return
+
+        # print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")

-    old_state = model.state_dict()
-    new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)
+        old_state = model.state_dict()
+        new_model = TinyTransformer(vocab_size=current_vocab_size).to(DEVICE)

-    # Transfer matching parameters
-    with torch.no_grad():
-        for name, param in new_model.named_parameters():
-            if name in old_state and old_state[name].shape == param.shape:
-                param.copy_(old_state[name])
+        with torch.no_grad():
+            for name, param in new_model.named_parameters():
+                if name in old_state and old_state[name].shape == param.shape:
+                    param.copy_(old_state[name])

-    model = new_model
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+        model = new_model
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+        _last_vocab_size = current_vocab_size

-    print("Expansion complete.")
+        # print("Expansion complete.")
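Serializing the whole expansion under `_expand_lock` is the substance of the CUDA fix: with several threads able to trigger expansion, two interleaved rebuilds, or a forward pass indexing into a head that another thread has just swapped out, can raise device-side asserts that surface as opaque CUDA errors. One side effect of the shape-gated copy is worth noting: tensors whose leading dimension is the vocabulary (the output head, and the token embedding if it is vocab-sized) never match shapes after growth, so they are re-initialized from scratch and their learned rows are discarded. A hedged variant of the copy block that preserves the overlap, assuming the first dimension of any mismatched tensor is the vocab axis:

    # Hedged alternative to the shape-gated copy (not in the commit): copy
    # the overlapping leading rows of vocab-sized tensors so the old
    # vocabulary keeps its learned weights after the head is enlarged.
    with torch.no_grad():
        for name, param in new_model.named_parameters():
            if name not in old_state:
                continue
            old = old_state[name]
            if old.shape == param.shape:
                param.copy_(old)                # unchanged tensors: full copy
            elif old.shape[1:] == param.shape[1:]:
                rows = min(old.shape[0], param.shape[0])
                param[:rows].copy_(old[:rows])  # grown tensors: keep old rows

Creating a fresh Adam on every expansion also discards the optimizer's moment estimates; carrying them over is possible but fiddly, and the commit accepts the reset.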