Added another learning source for Nora. Also added the requirements.
train.py (+7 lines)
@@ -37,6 +37,13 @@ def train(
     device = config.device
     model.to(device)

+    # ─── ensure optimizer state is on the same device ───
+    # (this moves any loaded CPU buffers for Adam/AdamW into CUDA)
+    for state in optimizer.state.values():
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                state[k] = v.to(device)
+
     criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])
     scaler = GradScaler()
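For context, a minimal, self-contained sketch of the failure mode the new loop guards against: when the optimizer is built (or its checkpoint loaded) while the parameters are still on CPU, a later model.to(device) moves the parameters but not the optimizer's internal buffers (e.g. Adam's exp_avg / exp_avg_sq), so the next optimizer.step() mixes CPU and CUDA tensors. The nn.Linear model, AdamW optimizer, and learning rate below are illustrative stand-ins, not the repository's actual train() objects.

# Sketch only: stand-in model/optimizer, not the repository's train() code.
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Linear(16, 4)                                   # created on CPU
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Take one step so the optimizer actually has state to move.
model(torch.randn(2, 16)).sum().backward()
optimizer.step()

model.to(device)   # parameters move in place; optimizer.state stays on CPU

# The loop from the diff: walk every per-parameter state dict and move
# each tensor onto the training device.
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)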