Fixed dynamic expand
This commit is contained in:
parent
5f74b2c64c
commit
a9b4871420
@ -3,6 +3,8 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from model.memory import save_dream
|
from model.memory import save_dream
|
||||||
from model.brain_state import model, tokenizer, DEVICE
|
from model.brain_state import model, tokenizer, DEVICE
|
||||||
|
from model.journal import record_to_journal
|
||||||
|
from model.trainer import train_on_message
|
||||||
from context.context import get_recent_context
|
from context.context import get_recent_context
|
||||||
|
|
||||||
recent_dreams = []
|
recent_dreams = []
|
||||||
@ -59,9 +61,7 @@ def daydream():
|
|||||||
|
|
||||||
if score > 0.45:
|
if score > 0.45:
|
||||||
save_dream(sentence, score)
|
save_dream(sentence, score)
|
||||||
from model.journal import record_to_journal
|
|
||||||
record_to_journal(sentence)
|
record_to_journal(sentence)
|
||||||
from model.trainer import train_on_message
|
|
||||||
train_on_message(sentence)
|
train_on_message(sentence)
|
||||||
|
|
||||||
if len(recent_dreams) > 10:
|
if len(recent_dreams) > 10:
|
||||||
|
@ -4,20 +4,24 @@ from model.brain_state import model, tokenizer, DEVICE
|
|||||||
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||||
|
|
||||||
|
_last_expansion_vocab_size = 0
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer():
|
def get_optimizer():
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
def expand_model_if_needed():
|
def expand_model_if_needed():
|
||||||
global model, optimizer
|
global model, optimizer, _last_expansion_vocab_size
|
||||||
|
|
||||||
current_vocab_size = len(tokenizer.vocab) + 10
|
current_vocab_size = len(tokenizer.vocab) + 10
|
||||||
old_vocab_size = model.head.out_features
|
|
||||||
|
|
||||||
|
if current_vocab_size - _last_expansion_vocab_size < 5:
|
||||||
|
return # Only expand every 5 words
|
||||||
|
|
||||||
|
old_vocab_size = model.head.out_features
|
||||||
if current_vocab_size <= old_vocab_size:
|
if current_vocab_size <= old_vocab_size:
|
||||||
return # No expansion needed
|
return # No expansion needed
|
||||||
|
|
||||||
print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
|
print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}")
|
||||||
|
|
||||||
old_state = model.state_dict()
|
old_state = model.state_dict()
|
||||||
|
@ -12,7 +12,7 @@ def log_loss(value: float):
|
|||||||
f.write(f"{time.time()},{round(value, 4)}\n")
|
f.write(f"{time.time()},{round(value, 4)}\n")
|
||||||
|
|
||||||
|
|
||||||
def train_on_message(text: str):
|
def train_on_message(text: str, source: str = "user"):
|
||||||
expand_model_if_needed()
|
expand_model_if_needed()
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
@ -45,4 +45,4 @@ def train_on_message(text: str):
|
|||||||
opt.step()
|
opt.step()
|
||||||
|
|
||||||
log_loss(loss.item())
|
log_loss(loss.item())
|
||||||
add_to_context(text)
|
add_to_context(text, source=source)
|
||||||
|
@ -46,6 +46,6 @@ async def read_books_forever():
|
|||||||
save_progress(progress)
|
save_progress(progress)
|
||||||
|
|
||||||
if is_valid_line(line):
|
if is_valid_line(line):
|
||||||
train_on_message(line)
|
train_on_message(line, source="book")
|
||||||
set_next_action(READ_DELAY, "Reading")
|
set_next_action(READ_DELAY, "Reading")
|
||||||
await asyncio.sleep(READ_DELAY)
|
await asyncio.sleep(READ_DELAY)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user