Fix: Added code to allow for other data sources to be added

This commit is contained in:
Dan
2024-06-08 09:21:19 -04:00
parent 1fe54ed1ff
commit e3e4b7abe6
3 changed files with 138 additions and 5 deletions

View File

@ -118,7 +118,7 @@ class GPT(nn.Module):
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens, temperature=1.0):
def generate(self, idx, max_new_tokens, temperature):
for _ in range(max_new_tokens):
idx_cond = idx[:, -block_size:]
logits, _ = self(idx_cond)