20 lines
509 B
Python
20 lines
509 B
Python
import os
|
|
import torch
|
|
from dotenv import load_dotenv
|
|
|
|
# Populate os.environ from a .env file (if one is found) before the Config
# class below reads its settings via os.getenv. With python-dotenv's default
# (override=False), variables already set in the real environment win.
load_dotenv()
|
|
|
|
|
|
def _env(names, default, cast):
    """Read a setting from the environment and convert it with *cast*.

    Args:
        names: A variable name, or a tuple of names checked in order; the
            first one that is set wins.
        default: Value passed to *cast* when none of *names* is set.
        cast: Converter such as ``int`` or ``float``.

    Returns:
        ``cast(value)`` for the first set variable, else ``cast(default)``.

    Raises:
        ValueError: If the variable's value cannot be converted. Unlike the
            bare ``int()``/``float()`` error, the message names the variable.
    """
    if isinstance(names, str):
        names = (names,)
    for name in names:
        raw = os.getenv(name)
        if raw is None:
            continue
        try:
            return cast(raw)
        except ValueError as err:
            raise ValueError(
                f"environment variable {name}={raw!r} is not a valid {cast.__name__}"
            ) from err
    return cast(default)


class Config:
    """Model/training settings, overridable through environment variables.

    All values are class attributes computed once, when this class body runs
    (i.e. at import time); later changes to ``os.environ`` are not picked up.
    """

    model_dim: int = _env("MODEL_DIM", 256, int)
    num_layers: int = _env("NUM_LAYERS", 4, int)
    # "HEADS" is the original variable name and keeps priority; "NUM_HEADS"
    # is accepted as an alias consistent with NUM_LAYERS.
    num_heads: int = _env(("HEADS", "NUM_HEADS"), 8, int)
    vocab_size: int = _env("VOCAB_SIZE", 30000, int)
    context_size: int = _env("CONTEXT_SIZE", 512, int)
    batch_size: int = _env("BATCH_SIZE", 8, int)
    lr: float = _env("LEARNING_RATE", 1e-4, float)
    # Picked at import time: CUDA if torch reports it available, otherwise CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
# Module-level Config instance for other modules to import and share.
# Its attribute values were fixed when the Config class body ran at import time.
cfg = Config()
|