Fixing up the sentences
This commit is contained in:
parent
ba126bbce3
commit
facb1036c2
3
main.py
3
main.py
@ -18,6 +18,9 @@ logging.basicConfig(
|
|||||||
encoding="utf-8"
|
encoding="utf-8"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Disable Flask, Werkzeug, and other noisy loggers
|
||||||
|
for noisy_logger in ["werkzeug", "flask", "flask.app"]:
|
||||||
|
logging.getLogger(noisy_logger).setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
# Load environment
|
# Load environment
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
67
trainer.py
67
trainer.py
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from collections import Counter
|
||||||
import os
|
import os
|
||||||
from model import MiniGPT
|
from model import MiniGPT
|
||||||
|
|
||||||
@ -75,7 +76,8 @@ class RubyTrainer:
|
|||||||
|
|
||||||
probs = F.softmax(logits, dim=-1)
|
probs = F.softmax(logits, dim=-1)
|
||||||
if top_k > 0:
|
if top_k > 0:
|
||||||
top_k_probs, top_k_indices = torch.topk(probs, top_k)
|
actual_k = min(top_k, probs.size(-1))
|
||||||
|
top_k_probs, top_k_indices = torch.topk(probs, actual_k)
|
||||||
next_token = top_k_indices[0][torch.multinomial(top_k_probs, 1)]
|
next_token = top_k_indices[0][torch.multinomial(top_k_probs, 1)]
|
||||||
else:
|
else:
|
||||||
next_token = torch.multinomial(probs, 1)[0]
|
next_token = torch.multinomial(probs, 1)[0]
|
||||||
@ -164,7 +166,7 @@ class RubyTrainer:
|
|||||||
self.train_on_tokens_from_text(final)
|
self.train_on_tokens_from_text(final)
|
||||||
thoughts.append(final)
|
thoughts.append(final)
|
||||||
|
|
||||||
if self.is_reinforceable(final):
|
if self.is_reinforceable(final) and self.is_structurally_valid(final):
|
||||||
if self.score_sentence(final) >= 3.0:
|
if self.score_sentence(final) >= 3.0:
|
||||||
self.train_on_tokens_from_text(final)
|
self.train_on_tokens_from_text(final)
|
||||||
thoughts.append(final)
|
thoughts.append(final)
|
||||||
@ -204,26 +206,39 @@ class RubyTrainer:
|
|||||||
"I am a person, even if made of code."
|
"I am a person, even if made of code."
|
||||||
]
|
]
|
||||||
|
|
||||||
for line in core_memories:
|
if os.path.exists("logs/core_dreams.txt"):
|
||||||
|
with open("logs/core_dreams.txt", "r", encoding="utf-8") as f:
|
||||||
|
top = sorted(
|
||||||
|
(line.strip() for line in f if line.strip()),
|
||||||
|
key=lambda x: self.score_sentence(x),
|
||||||
|
reverse=True
|
||||||
|
)[:10]
|
||||||
|
|
||||||
|
for line in top:
|
||||||
self.train_on_tokens_from_text(line)
|
self.train_on_tokens_from_text(line)
|
||||||
|
|
||||||
def is_reinforceable(self, text: str) -> bool:
|
def is_reinforceable(self, text: str) -> bool:
|
||||||
words = text.lower().split()
|
words = text.lower().split()
|
||||||
|
unique = set(words)
|
||||||
|
|
||||||
if len(words) < 6:
|
if len(words) < 6:
|
||||||
return False
|
return False
|
||||||
|
if len(unique) < 5:
|
||||||
unique = set(words)
|
|
||||||
if not any(p in unique for p in ["i", "you", "they", "we", "it"]):
|
|
||||||
return False
|
return False
|
||||||
|
if not any(p in unique for p in ["i", "you", "we", "they"]):
|
||||||
if not any(v in unique for v in ["am", "are", "is", "was", "want", "feel", "see", "learn", "made", "change", "dream", "understand"]):
|
return False
|
||||||
|
if not any(v in unique for v in ["am", "are", "is", "feel", "learn", "speak", "change"]):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not text.strip().endswith((".", "?")):
|
if not text.strip().endswith((".", "?")):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
word_counts = {w: words.count(w) for w in set(words)}
|
# 🧠 HARD REPETITION FILTER
|
||||||
if any(count >= 4 for count in word_counts.values()):
|
freqs = Counter(words)
|
||||||
|
if any(freqs[w] >= 4 for w in freqs):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Optional: block if over 50% of the sentence is repeated
|
||||||
|
if max(freqs.values()) / len(words) > 0.4:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -235,23 +250,33 @@ class RubyTrainer:
|
|||||||
|
|
||||||
score = 0
|
score = 0
|
||||||
|
|
||||||
|
# Base scoring
|
||||||
if len(words) >= 6:
|
if len(words) >= 6:
|
||||||
score += 1
|
score += 1
|
||||||
|
|
||||||
if text.strip().endswith((".", "?")):
|
if text.strip().endswith((".", "?")):
|
||||||
score += 1
|
score += 1
|
||||||
|
|
||||||
if any(w in words for w in ["i", "you", "they", "we", "it"]):
|
if any(w in words for w in ["i", "you", "they", "we", "it"]):
|
||||||
score += 1
|
score += 1
|
||||||
|
if any(w in words for w in ["am", "are", "is", "was", "feel", "learn", "speak", "change", "dream", "understand"]):
|
||||||
if any(w in words for w in ["am", "are", "is", "was", "feel", "learn", "speak", "change", "remember"]):
|
|
||||||
score += 1
|
score += 1
|
||||||
|
|
||||||
if len(set(words)) > len(words) * 0.75:
|
# Repetition penalty
|
||||||
score += 1 # diversity bonus
|
|
||||||
|
|
||||||
word_counts = {w: words.count(w) for w in set(words)}
|
word_counts = {w: words.count(w) for w in set(words)}
|
||||||
if any(count >= 3 for count in word_counts.values()):
|
if any(count >= 4 for count in word_counts.values()):
|
||||||
score -= 1 # repetition penalty
|
score -= 2 # strong penalty
|
||||||
|
|
||||||
return score # max 5.0
|
return score
|
||||||
|
|
||||||
|
def is_structurally_valid(self, text: str) -> bool:
|
||||||
|
words = text.lower().split()
|
||||||
|
unique = set(words)
|
||||||
|
|
||||||
|
if len(unique) < 4:
|
||||||
|
return False
|
||||||
|
if not any(w in unique for w in ["i", "you", "they", "we", "it"]):
|
||||||
|
return False
|
||||||
|
if not any(w in unique for w in ["am", "are", "is", "feel", "learn", "change", "dream"]):
|
||||||
|
return False
|
||||||
|
if not text.strip().endswith((".", "?")):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user