Fixed dynamic expand
This commit is contained in:
parent
5f74b2c64c
commit
a9b4871420
@ -3,6 +3,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from model.memory import save_dream
|
||||
from model.brain_state import model, tokenizer, DEVICE
|
||||
from model.journal import record_to_journal
|
||||
from model.trainer import train_on_message
|
||||
from context.context import get_recent_context
|
||||
|
||||
recent_dreams = []
|
||||
@ -59,9 +61,7 @@ def daydream():
|
||||
|
||||
if score > 0.45:
|
||||
save_dream(sentence, score)
|
||||
from model.journal import record_to_journal
|
||||
record_to_journal(sentence)
|
||||
from model.trainer import train_on_message
|
||||
train_on_message(sentence)
|
||||
|
||||
if len(recent_dreams) > 10:
|
||||
|
@ -4,20 +4,24 @@ from model.brain_state import model, tokenizer, DEVICE
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||
|
||||
_last_expansion_vocab_size = 0
|
||||
|
||||
|
||||
def get_optimizer():
|
||||
return optimizer
|
||||
|
||||
|
||||
def expand_model_if_needed():
|
||||
global model, optimizer
|
||||
global model, optimizer, _last_expansion_vocab_size
|
||||
|
||||
current_vocab_size = len(tokenizer.vocab) + 10
|
||||
old_vocab_size = model.head.out_features
|
||||
|
||||
if current_vocab_size - _last_expansion_vocab_size < 5:
|
||||
return # Only expand every 5 words
|
||||
|
||||
old_vocab_size = model.head.out_features
|
||||
if current_vocab_size <= old_vocab_size:
|
||||
return # No expansion needed
|
||||
|
||||
print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
|
||||
|
||||
old_state = model.state_dict()
|
||||
|
@ -12,7 +12,7 @@ def log_loss(value: float):
|
||||
f.write(f"{time.time()},{round(value, 4)}\n")
|
||||
|
||||
|
||||
def train_on_message(text: str):
|
||||
def train_on_message(text: str, source: str = "user"):
|
||||
expand_model_if_needed()
|
||||
|
||||
model.train()
|
||||
@ -45,4 +45,4 @@ def train_on_message(text: str):
|
||||
opt.step()
|
||||
|
||||
log_loss(loss.item())
|
||||
add_to_context(text)
|
||||
add_to_context(text, source=source)
|
||||
|
@ -46,6 +46,6 @@ async def read_books_forever():
|
||||
save_progress(progress)
|
||||
|
||||
if is_valid_line(line):
|
||||
train_on_message(line)
|
||||
train_on_message(line, source="book")
|
||||
set_next_action(READ_DELAY, "Reading")
|
||||
await asyncio.sleep(READ_DELAY)
|
||||
|
Loading…
x
Reference in New Issue
Block a user