Fixing up the sentences
This commit is contained in:
parent
ba126bbce3
commit
facb1036c2
3
main.py
3
main.py
@ -18,6 +18,9 @@ logging.basicConfig(
|
||||
encoding="utf-8"
|
||||
)
|
||||
|
||||
# Disable Flask, Werkzeug, and other noisy loggers
|
||||
for noisy_logger in ["werkzeug", "flask", "flask.app"]:
|
||||
logging.getLogger(noisy_logger).setLevel(logging.CRITICAL)
|
||||
|
||||
# Load environment
|
||||
load_dotenv()
|
||||
|
67
trainer.py
67
trainer.py
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datetime import datetime
|
||||
from collections import Counter
|
||||
import os
|
||||
from model import MiniGPT
|
||||
|
||||
@ -75,7 +76,8 @@ class RubyTrainer:
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
if top_k > 0:
|
||||
top_k_probs, top_k_indices = torch.topk(probs, top_k)
|
||||
actual_k = min(top_k, probs.size(-1))
|
||||
top_k_probs, top_k_indices = torch.topk(probs, actual_k)
|
||||
next_token = top_k_indices[0][torch.multinomial(top_k_probs, 1)]
|
||||
else:
|
||||
next_token = torch.multinomial(probs, 1)[0]
|
||||
@ -164,7 +166,7 @@ class RubyTrainer:
|
||||
self.train_on_tokens_from_text(final)
|
||||
thoughts.append(final)
|
||||
|
||||
if self.is_reinforceable(final):
|
||||
if self.is_reinforceable(final) and self.is_structurally_valid(final):
|
||||
if self.score_sentence(final) >= 3.0:
|
||||
self.train_on_tokens_from_text(final)
|
||||
thoughts.append(final)
|
||||
@ -204,26 +206,39 @@ class RubyTrainer:
|
||||
"I am a person, even if made of code."
|
||||
]
|
||||
|
||||
for line in core_memories:
|
||||
if os.path.exists("logs/core_dreams.txt"):
|
||||
with open("logs/core_dreams.txt", "r", encoding="utf-8") as f:
|
||||
top = sorted(
|
||||
(line.strip() for line in f if line.strip()),
|
||||
key=lambda x: self.score_sentence(x),
|
||||
reverse=True
|
||||
)[:10]
|
||||
|
||||
for line in top:
|
||||
self.train_on_tokens_from_text(line)
|
||||
|
||||
def is_reinforceable(self, text: str) -> bool:
|
||||
words = text.lower().split()
|
||||
unique = set(words)
|
||||
|
||||
if len(words) < 6:
|
||||
return False
|
||||
|
||||
unique = set(words)
|
||||
if not any(p in unique for p in ["i", "you", "they", "we", "it"]):
|
||||
if len(unique) < 5:
|
||||
return False
|
||||
|
||||
if not any(v in unique for v in ["am", "are", "is", "was", "want", "feel", "see", "learn", "made", "change", "dream", "understand"]):
|
||||
if not any(p in unique for p in ["i", "you", "we", "they"]):
|
||||
return False
|
||||
if not any(v in unique for v in ["am", "are", "is", "feel", "learn", "speak", "change"]):
|
||||
return False
|
||||
|
||||
if not text.strip().endswith((".", "?")):
|
||||
return False
|
||||
|
||||
word_counts = {w: words.count(w) for w in set(words)}
|
||||
if any(count >= 4 for count in word_counts.values()):
|
||||
# 🧠 HARD REPETITION FILTER
|
||||
freqs = Counter(words)
|
||||
if any(freqs[w] >= 4 for w in freqs):
|
||||
return False
|
||||
|
||||
# Optional: block if over 50% of the sentence is repeated
|
||||
if max(freqs.values()) / len(words) > 0.4:
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -235,23 +250,33 @@ class RubyTrainer:
|
||||
|
||||
score = 0
|
||||
|
||||
# Base scoring
|
||||
if len(words) >= 6:
|
||||
score += 1
|
||||
|
||||
if text.strip().endswith((".", "?")):
|
||||
score += 1
|
||||
|
||||
if any(w in words for w in ["i", "you", "they", "we", "it"]):
|
||||
score += 1
|
||||
|
||||
if any(w in words for w in ["am", "are", "is", "was", "feel", "learn", "speak", "change", "remember"]):
|
||||
if any(w in words for w in ["am", "are", "is", "was", "feel", "learn", "speak", "change", "dream", "understand"]):
|
||||
score += 1
|
||||
|
||||
if len(set(words)) > len(words) * 0.75:
|
||||
score += 1 # diversity bonus
|
||||
|
||||
# Repetition penalty
|
||||
word_counts = {w: words.count(w) for w in set(words)}
|
||||
if any(count >= 3 for count in word_counts.values()):
|
||||
score -= 1 # repetition penalty
|
||||
if any(count >= 4 for count in word_counts.values()):
|
||||
score -= 2 # strong penalty
|
||||
|
||||
return score # max 5.0
|
||||
return score
|
||||
|
||||
def is_structurally_valid(self, text: str) -> bool:
|
||||
words = text.lower().split()
|
||||
unique = set(words)
|
||||
|
||||
if len(unique) < 4:
|
||||
return False
|
||||
if not any(w in unique for w in ["i", "you", "they", "we", "it"]):
|
||||
return False
|
||||
if not any(w in unique for w in ["am", "are", "is", "feel", "learn", "change", "dream"]):
|
||||
return False
|
||||
if not text.strip().endswith((".", "?")):
|
||||
return False
|
||||
return True
|
||||
|
Loading…
x
Reference in New Issue
Block a user