Fix: Added code to allow for other data sources to be added
This commit is contained in:
@ -118,7 +118,7 @@ class GPT(nn.Module):
|
||||
loss = F.cross_entropy(logits, targets)
|
||||
return logits, loss
|
||||
|
||||
def generate(self, idx, max_new_tokens, temperature=1.0):
|
||||
def generate(self, idx, max_new_tokens, temperature):
|
||||
for _ in range(max_new_tokens):
|
||||
idx_cond = idx[:, -block_size:]
|
||||
logits, _ = self(idx_cond)
|
||||
|
Reference in New Issue
Block a user