import os
import torch
from torch.distributions.categorical import Categorical
from models import PolicyNet
import random
from agents.base import Agent
from moviepy import ImageSequenceClip


class PGAgent(Agent):
    """Interface-compatible Policy Gradient agent.

    Placeholder implementation of the ``Agent`` interface: action selection,
    transition storage, the learning update and checkpointing are stubs,
    while ``train`` and ``evaluate`` provide the episode loops around them.

    NOTE(review): ``self.env`` and ``self.device`` are used below but never
    assigned in this class; presumably the ``Agent`` base class sets them up
    -- confirm against ``agents.base``.
    """

    # Short identifier for this agent; used in evaluation video file names.
    agent_key = "pg"

    def __init__(self, config, env_name, device=None):
        """Forward all construction arguments to the ``Agent`` base class."""
        super().__init__(config=config, env_name=env_name, device=device)

    def select_action(self, state):
        """Return an ``(action, info)`` pair for ``state``.

        Stub: always returns action ``0`` and an empty info dict.
        """
        return 0, {}

    def store(self, *args, **kwargs):
        """Record a transition. Stub: accepts anything and does nothing."""
        pass

    def update(self):
        """Run the end-of-episode update and return its metrics.

        Stub: returns a constant ``{"loss": 0}``.
        """
        return {"loss": 0}

    def train(self, state, episode=None):
        """Play one training episode starting from ``state``, then update.

        The environment is stepped until it reports ``terminated`` or
        ``truncated``. This method does not reset the environment itself.

        Args:
            state: Current observation, passed as-is to ``select_action``.
            episode: Unused here; kept for interface compatibility.

        Returns:
            dict with ``'reward'`` (sum of step rewards for the episode) and
            ``'loss'`` (the ``'loss'`` entry of ``update()``'s result).
        """
        terminated = False
        truncated = False
        total_reward = 0
        while not (terminated or truncated):
            action, info = self.select_action(state)
            next_s, reward, terminated, truncated, _ = self.env.step(action)
            next_state = torch.tensor(next_s, device=self.device).float()
            # NOTE(review): store() is called without the transition data;
            # harmless with the stub above, but a real implementation will
            # presumably need state/action/reward passed in -- confirm.
            self.store()

            state = next_state
            total_reward += reward

        # end episode updates
        metrics = self.update()
        return {
            'reward': total_reward,
            'loss': metrics['loss']
        }

    def save(self, path):
        """Persist the agent to ``path``. Stub: does nothing."""
        pass

    def load(self, path):
        """Restore the agent from ``path``. Stub: does nothing."""
        pass

    def evaluate(self, out_path, num_episodes=1):
        """Roll out ``num_episodes`` episodes and save each one as a video.

        Each episode is written to
        ``<out_path>/<agent_key>_evaluate_<episode>.mp4`` at 24 fps.

        Args:
            out_path: Directory for the videos; created if it does not exist.
            num_episodes: Number of evaluation episodes to run.

        Returns:
            list of total rewards, one per episode.
        """
        # Create the output directory up front so write_videofile does not
        # fail on a missing path. An empty out_path means "current directory"
        # for os.path.join, and os.makedirs("") would raise, so skip it.
        if out_path:
            os.makedirs(out_path, exist_ok=True)

        rewards = []
        for episode in range(num_episodes):
            state, _ = self.env.reset()
            state = torch.tensor(state, device=self.device).float()
            total = 0
            terminated = False
            truncated = False
            frames = []
            # Evaluation never back-propagates, so skip autograd bookkeeping.
            with torch.no_grad():
                while not (terminated or truncated):
                    # NOTE(review): assumes the env renders image arrays
                    # (e.g. render_mode="rgb_array") -- confirm where the
                    # environment is constructed.
                    frames.append(self.env.render())
                    action, _ = self.select_action(state)
                    next_state, reward, terminated, truncated, _ = self.env.step(action)
                    state = torch.tensor(next_state, device=self.device).float()
                    total += reward
            rewards.append(total)

            video_path = os.path.join(out_path, f"./{self.agent_key}_evaluate_{episode}.mp4")
            clip = ImageSequenceClip(sequence=frames, fps=24)
            try:
                clip.write_videofile(video_path, codec="libx264")
            finally:
                # Release the clip's resources even if encoding fails.
                clip.close()
        return rewards
