diff --git a/model/brain.py b/model/brain.py index 0d8ecf6..9646247 100644 --- a/model/brain.py +++ b/model/brain.py @@ -3,6 +3,8 @@ import torch import torch.nn.functional as F from model.memory import save_dream from model.brain_state import model, tokenizer, DEVICE +from model.journal import record_to_journal +from model.trainer import train_on_message from context.context import get_recent_context recent_dreams = [] @@ -59,9 +61,7 @@ def daydream(): if score > 0.45: save_dream(sentence, score) - from model.journal import record_to_journal record_to_journal(sentence) - from model.trainer import train_on_message train_on_message(sentence) if len(recent_dreams) > 10: diff --git a/model/dynamic_expand.py b/model/dynamic_expand.py index 5b69c68..7593ba6 100644 --- a/model/dynamic_expand.py +++ b/model/dynamic_expand.py @@ -4,20 +4,24 @@ from model.brain_state import model, tokenizer, DEVICE optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) +_last_expansion_vocab_size = 0 + def get_optimizer(): return optimizer def expand_model_if_needed(): - global model, optimizer + global model, optimizer, _last_expansion_vocab_size current_vocab_size = len(tokenizer.vocab) + 10 - old_vocab_size = model.head.out_features + if current_vocab_size - _last_expansion_vocab_size < 5: + return # Only expand every 5 words + + old_vocab_size = model.head.out_features if current_vocab_size <= old_vocab_size: return # No expansion needed - print(f"Expanding model from {old_vocab_size} -> {current_vocab_size}") old_state = model.state_dict() diff --git a/model/trainer.py b/model/trainer.py index 155298f..5cd6597 100644 --- a/model/trainer.py +++ b/model/trainer.py @@ -12,7 +12,7 @@ def log_loss(value: float): f.write(f"{time.time()},{round(value, 4)}\n") -def train_on_message(text: str): +def train_on_message(text: str, source: str = "user"): expand_model_if_needed() model.train() @@ -45,4 +45,4 @@ def train_on_message(text: str): opt.step() log_loss(loss.item()) - add_to_context(text) + add_to_context(text, source=source) diff --git a/reader/reader.py b/reader/reader.py index d8b0166..d14f1cf 100644 --- a/reader/reader.py +++ b/reader/reader.py @@ -46,6 +46,6 @@ async def read_books_forever(): save_progress(progress) if is_valid_line(line): - train_on_message(line) + train_on_message(line, source="book") set_next_action(READ_DELAY, "Reading") await asyncio.sleep(READ_DELAY)