import copy
import os
import warnings

import torch
import torch.nn as nn
from moviepy import ImageSequenceClip

from agents.base import Agent
from models import QNet
from replay_buffer import ReplayBuffer


class DQNAgent(Agent):
    """Skeleton DQN agent: an episode loop and an evaluation loop around stub hooks.

    ``select_action``, ``store``, ``update``, ``save`` and ``load`` are
    placeholders in this class. ``self.env``, ``self.device``, ``self.qnet``
    and ``self.synchronize`` are used below but not defined here.
    NOTE(review): presumably the ``Agent`` base class or a subclass provides
    them -- confirm.
    """

    agent_key = 'dqn'

    def __init__(self, config, env_name, device=None):
        """Read hyper-parameters from ``config`` and create the replay buffer.

        Args:
            config: Mapping queried with ``.get``; recognised keys are
                ``epsilon`` (default 0.1), ``buffer_size`` (10000),
                ``batch_size`` (32), ``sync_interval`` (20) and
                ``train_interval`` (1).
            env_name: Forwarded unchanged to ``Agent.__init__``.
            device: Forwarded unchanged to ``Agent.__init__``.
        """
        super().__init__(config=config, env_name=env_name, device=device)

        self.epsilon = config.get("epsilon", 0.1)
        self.buffer_size = config.get("buffer_size", 10000)
        self.batch_size = config.get("batch_size", 32)
        # Target sync period, measured in episodes (see ``train``).
        self.sync_interval = config.get("sync_interval", 20)
        # Update period, measured in environment steps (see ``train``).
        self.train_interval = config.get("train_interval", 1)

        self.replay = ReplayBuffer(self.buffer_size, self.batch_size, device=self.device)
        # Environment steps taken by ``train`` across all episodes.
        self.global_steps = 0

    def select_action(self, state):
        """Placeholder for the behaviour policy; currently returns ``None``.

        ``train`` unpacks the result as ``(action, info)``, so an
        implementation must return a 2-tuple.
        """
        pass

    def store(self, *args, **kwargs):
        """Placeholder for recording a transition; currently does nothing.

        ``train`` calls this as
        ``store(state, action, reward, next_state, terminated)``.
        """
        # states expected to be tensors already
        pass

    def update(self):
        """Placeholder learning step; always returns ``{"loss": 0}``."""
        return {"loss": 0}

    def train(self, state, episode):
        """Run one episode starting from ``state``.

        Each step selects an action, steps ``self.env``, hands the transition
        to ``store`` and, every ``train_interval`` global steps, calls
        ``update``. After the episode, ``self.synchronize()`` is attempted
        when ``episode`` is a multiple of ``sync_interval``.

        Args:
            state: Initial observation passed to ``select_action``.
                NOTE(review): presumably already a tensor on ``self.device``,
                like the ``next_state`` built here -- confirm with the caller.
            episode: Episode index, used only for the sync schedule.

        Returns:
            dict with ``'reward'`` (sum of step rewards) and ``'loss'`` (the
            ``'loss'`` entry of the most recent ``update`` result, or ``None``
            if no update ran or it returned a falsy value).
        """
        # Clamp so a zero/negative config value cannot cause a modulo-by-zero.
        update_every = max(1, self.train_interval)

        terminated = False
        truncated = False
        total_reward = 0
        loss = None  # stays None when no update is due during this episode
        while not terminated and not truncated:
            action, _info = self.select_action(state)
            next_s, reward, terminated, truncated, _ = self.env.step(action)
            next_state = torch.tensor(next_s, device=self.device).float()

            # Pass the transition on; ``terminated`` (not ``truncated``) is the
            # done flag so a time-limit cut-off is not treated as a terminal
            # state. NOTE(review): adjust if ``store`` expects another schema.
            self.store(state, action, reward, next_state, terminated)

            # DQN update every ``train_interval`` environment steps.
            self.global_steps += 1
            if self.global_steps % update_every == 0:
                loss = self.update()

            state = next_state
            total_reward += reward

        # end episode updates
        # synchronize target network periodically
        if episode % self.sync_interval == 0:
            try:
                self.synchronize()
            except Exception as exc:
                # Still best-effort (training continues), but no longer silent.
                warnings.warn(
                    f"target network synchronization failed: {exc!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        return {
            'reward': total_reward,
            'loss': loss['loss'] if loss else None
        }

    def save(self, path):
        """Placeholder for persisting the agent; currently does nothing."""
        pass

    def load(self, path):
        """Placeholder for restoring the agent; currently does nothing."""
        pass

    def evaluate(self, out_path, num_episodes=1):
        """Run greedy episodes, writing one MP4 per episode into ``out_path``.

        Actions are the argmax of ``self.qnet(state)`` computed under
        ``torch.no_grad()``. A frame from ``self.env.render()`` is captured
        before every step and the frames are encoded at 24 fps with libx264.

        Args:
            out_path: Directory for the videos; created if it does not exist.
            num_episodes: Number of episodes to run.

        Returns:
            List with the total reward of each episode.
        """
        # Make sure the target directory exists before any video is written.
        if out_path:
            os.makedirs(out_path, exist_ok=True)

        rewards = []
        for episode in range(num_episodes):
            state, _ = self.env.reset()
            state = torch.tensor(state, device=self.device).float()
            terminated = False
            truncated = False
            total = 0
            frames = []
            while not terminated and not truncated:
                frames.append(self.env.render())
                with torch.no_grad():
                    # ``state`` was created on ``self.device`` already.
                    qs = self.qnet(state)
                    action = int(torch.argmax(qs).item())
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                state = torch.tensor(next_state, device=self.device).float()
                total += reward
            rewards.append(total)
            clip = ImageSequenceClip(sequence=frames, fps=24)
            clip.write_videofile(os.path.join(out_path, f"./{self.agent_key}_evaluate_{episode}.mp4"), codec="libx264")
        return rewards
