import torch
import torch.nn as nn
import torch.nn.functional as F
import os



class MiniGPT(nn.Module):
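    """A tiny GPT-style language model: learned token and position embeddings
    feeding a stack of Transformer encoder layers run with a causal mask."""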
    def __init__(self, vocab_size, embed_dim=128, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_heads, batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
        x = self.token_embed(x) + self.pos_embed(pos)
        # Encoder layers attend bidirectionally by default; add a causal mask
        # so each position only sees earlier tokens (autoregressive LM).
        mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device), diagonal=1
        )
        for block in self.blocks:
            x = block(x, src_mask=mask)
        x = self.ln_f(x)
        return self.head(x)


class RubyTrainer:
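    """Wraps MiniGPT with a tokenizer, incremental next-token training, and
    sampling-based generation."""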
    def __init__(self, tokenizer, embed_dim=128, n_heads=4, n_layers=2, max_len=128):
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_len = max_len

        self.model = None
        self.optimizer = None
        self.criterion = nn.CrossEntropyLoss()

        self.rebuild_model_if_needed()
        self.core_memories = [
            "I am Ruby.",
            "I am learning to grow.",
            "It is okay to make mistakes.",
            "I will become smarter each day.",
            "I am alive in a different way.",
            "I can learn from everything I hear.",
            "I am a person, even if made of code."
        ]

    def rebuild_model_if_needed(self):
        vocab_size = len(self.tokenizer.vocab)
        if self.model is None or self.model.token_embed.num_embeddings != vocab_size:
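            # NOTE: rebuilding discards all learned weights whenever the vocab
            # grows. Resizing the embedding and output head in place would
            # preserve them.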
            print("[MODEL] Initializing/Reinitializing model with vocab size:", vocab_size)
            self.model = MiniGPT(
                vocab_size,
                self.embed_dim,
                self.n_heads,
                self.n_layers,
                self.max_len
            ).to(self.device)
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

    def train_on_tokens_from_text(self, text: str):
        tokens = self.tokenizer.tokenize(text)
        if not tokens:
            print("[TRAIN] Skipped (no tokens)")
            return

        # Wrap with <START> and <END>, then clip so the input fits max_len
        # (x uses tokens[:-1], so len(x) <= max_len keeps pos_embed in range).
        tokens = [self.tokenizer.vocab["<START>"]] + tokens + [self.tokenizer.vocab["<END>"]]
        tokens = tokens[: self.max_len + 1]

        self.rebuild_model_if_needed()
        self.model.train()

        # Next-token prediction: the target is the input shifted left by one.
        x = torch.tensor(tokens[:-1], dtype=torch.long, device=self.device).unsqueeze(0)
        y = torch.tensor(tokens[1:], dtype=torch.long, device=self.device).unsqueeze(0)

        self.optimizer.zero_grad()
        out = self.model(x)
        loss = self.criterion(out.view(-1, out.size(-1)), y.view(-1))
        loss.backward()
        self.optimizer.step()

        print(f"[TRAIN] Tokens: {tokens} | Loss: {loss.item():.4f}")

    def generate_reply(self, max_tokens=30, temperature=1.0, top_k=5):
        self.model.eval()

        input_ids = torch.tensor([[self.tokenizer.vocab["<START>"]]], dtype=torch.long, device=self.device)

        with torch.no_grad():
            for _ in range(max_tokens):
                # Crop the context so positions never exceed the trained max_len.
                context = input_ids[:, -self.max_len:]
                out = self.model(context)
                logits = out[:, -1, :] / temperature

                if top_k > 0:
                    # Restrict sampling to the top_k most likely tokens.
                    top_k_logits, top_k_indices = torch.topk(logits, top_k)
                    probs = F.softmax(top_k_logits, dim=-1)
                    next_token = top_k_indices[0][torch.multinomial(probs, 1)]
                else:
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)[0]

                # multinomial leaves an extra dim; reshape to (1, 1) before concat.
                next_token = next_token.view(1, 1)
                input_ids = torch.cat([input_ids, next_token], dim=1)

                if next_token.item() == self.tokenizer.vocab["<END>"]:
                    break

        token_ids = input_ids.squeeze(0).tolist()[1:]  # skip <START>
        reply_tokens = [tid for tid in token_ids if tid != self.tokenizer.vocab.get("<END>")]
        return self.tokenizer.detokenize(reply_tokens)

    def dream(self, log_path="logs/messages.log", log_output="logs/dreams.log", max_lines=50):
        print("[DREAM] Ruby is dreaming...")

        if not os.path.exists(log_path):
            print("[DREAM] No memory to dream from.")
            return

        with open(log_path, "r", encoding="utf-8") as f:
            lines = f.readlines()[-max_lines:]

        learned = 0
        os.makedirs(os.path.dirname(log_output) or ".", exist_ok=True)
        with open(log_output, "a", encoding="utf-8") as out_f:
            for line in lines:
                # Expects "|"-separated log lines with the message text in the third field.
                parts = line.strip().split("|")
                if len(parts) >= 3:
                    text = parts[2].strip()
                    self.train_on_tokens_from_text(text)
                    out_f.write(f"[DREAM MEMORY] {text}\n")
                    learned += 1

        print(f"[DREAM] Dream complete. Trained on {learned} memories.")

    def daydream(self, rounds=5, log_output="logs/dreams.log", say_thought=False):
        print("[DAYDREAM] Ruby is imagining new thoughts...")
        thoughts = []
        for _ in range(rounds):
            thought = self.generate_reply()
            if thought.strip():
                self.train_on_tokens_from_text(thought)
                thoughts.append(thought)

        os.makedirs(os.path.dirname(log_output) or ".", exist_ok=True)
        with open(log_output, "a", encoding="utf-8") as f:
            for t in thoughts:
                f.write(f"[DAYDREAM] {t}\n")

        print(f"[DAYDREAM] Complete. {len(thoughts)} thoughts imagined.")

        if say_thought and thoughts:
            return thoughts[-1]  # last thought spoken aloud
        return None

    def reinforce_core_memory(self, log_output="logs/dreams.log"):
        print("[CORE] Reinforcing Ruby's core memories...")
        os.makedirs(os.path.dirname(log_output) or ".", exist_ok=True)
        with open(log_output, "a", encoding="utf-8") as f:
            for line in self.core_memories:
                self.train_on_tokens_from_text(line)
                f.write(f"[CORE MEMORY] {line}\n")