Updated context and increased her brain capacity
This commit is contained in:
parent
6ccb52dc72
commit
26fbf85a90
@ -45,10 +45,10 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class TinyTransformer(nn.Module):
|
||||
def __init__(self, vocab_size, embed_dim=256, depth=4, heads=8):
|
||||
def __init__(self, vocab_size, embed_dim=512, depth=4, heads=16):
|
||||
super().__init__()
|
||||
self.token_embed = nn.Embedding(vocab_size, embed_dim)
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, 128, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, 256, embed_dim))
|
||||
self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, heads) for _ in range(depth)])
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
self.head = nn.Linear(embed_dim, vocab_size)
|
||||
|
@ -35,7 +35,7 @@ def train_on_message(text: str, source: str = "user"):
|
||||
try:
|
||||
model.train()
|
||||
|
||||
context_texts = get_recent_context(10)
|
||||
context_texts = get_recent_context(30)
|
||||
augmented_text = "<start> " + " ".join(context_texts + [text]) + " <end>"
|
||||
|
||||
tokens = tokenizer.tokenize(augmented_text)
|
||||
|
Loading…
x
Reference in New Issue
Block a user