Expand training context (10 → 30 recent messages, 128 → 256 positional embeddings) and increase model capacity (embed_dim 256 → 512, heads 8 → 16)
This commit is contained in:
parent
6ccb52dc72
commit
26fbf85a90
@ -45,10 +45,10 @@ class TransformerBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TinyTransformer(nn.Module):
|
class TinyTransformer(nn.Module):
|
||||||
def __init__(self, vocab_size, embed_dim=256, depth=4, heads=8):
|
def __init__(self, vocab_size, embed_dim=512, depth=4, heads=16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.token_embed = nn.Embedding(vocab_size, embed_dim)
|
self.token_embed = nn.Embedding(vocab_size, embed_dim)
|
||||||
self.pos_embed = nn.Parameter(torch.randn(1, 128, embed_dim))
|
self.pos_embed = nn.Parameter(torch.randn(1, 256, embed_dim))
|
||||||
self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, heads) for _ in range(depth)])
|
self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, heads) for _ in range(depth)])
|
||||||
self.norm = nn.LayerNorm(embed_dim)
|
self.norm = nn.LayerNorm(embed_dim)
|
||||||
self.head = nn.Linear(embed_dim, vocab_size)
|
self.head = nn.Linear(embed_dim, vocab_size)
|
||||||
|
@ -35,7 +35,7 @@ def train_on_message(text: str, source: str = "user"):
|
|||||||
try:
|
try:
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
context_texts = get_recent_context(10)
|
context_texts = get_recent_context(30)
|
||||||
augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
|
augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(augmented_text)
|
tokens = tokenizer.tokenize(augmented_text)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user