# RubyOld/config.py
#
# 20 lines
# 509 B
# Python

import os
import torch
from dotenv import load_dotenv
# Load variables from a .env file (python-dotenv) into the process environment
# before any os.getenv() lookups further down in this module run.
load_dotenv()
def _env_number(name, default, cast):
    """Return environment variable *name* converted with *cast*.

    Args:
        name: Name of the environment variable to read.
        default: Value used when the variable is not set.
        cast: Numeric constructor (``int`` or ``float``) applied to the value.

    Returns:
        ``cast(value)`` for the variable's value, or ``cast(default)`` when
        the variable is unset.

    Raises:
        ValueError: If the variable is set to text that *cast* rejects. The
            message names the variable so a typo in the environment or the
            ``.env`` file can be traced without guessing.
    """
    raw = os.getenv(name)
    if raw is None:
        return cast(default)
    try:
        return cast(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name}={raw!r} is not a valid "
            f"{cast.__name__}"
        ) from err


class Config:
    """Settings read from environment variables, with built-in defaults.

    Every attribute is a class attribute evaluated once, when the class body
    runs (i.e. when this module is imported). Changing the environment
    afterwards does not alter the values.

    Environment variable -> attribute (default):
        MODEL_DIM     -> model_dim    (256)
        NUM_LAYERS    -> num_layers   (4)
        HEADS         -> num_heads    (8)
        VOCAB_SIZE    -> vocab_size   (30000)
        CONTEXT_SIZE  -> context_size (512)
        BATCH_SIZE    -> batch_size   (8)
        LEARNING_RATE -> lr           (1e-4)
    """

    model_dim: int = _env_number("MODEL_DIM", 256, int)
    num_layers: int = _env_number("NUM_LAYERS", 4, int)
    # NOTE: the variable is spelled HEADS (no NUM_ prefix), unlike NUM_LAYERS.
    num_heads: int = _env_number("HEADS", 8, int)
    vocab_size: int = _env_number("VOCAB_SIZE", 30000, int)
    context_size: int = _env_number("CONTEXT_SIZE", 512, int)
    batch_size: int = _env_number("BATCH_SIZE", 8, int)
    lr: float = _env_number("LEARNING_RATE", 1e-4, float)
    # "cuda" when torch reports an available CUDA device, otherwise "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level instance of Config.
cfg = Config()