from __future__ import annotations
from typing import Any
import numpy as np
from pettingzoo.utils.env import (
    ActionType,
    AECEnv,
    AECIterable,
    AECIterator,
    AgentID,
    ObsType,
)
from pettingzoo.utils.env_logger import EnvLogger
from pettingzoo.utils.wrappers.base import BaseWrapper
# [docs]  (Sphinx "view source" link residue from extraction; not part of the module)
class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
    """Checks if function calls or attribute access are in a disallowed order.

    Behaviour implemented by this class:

    * ``AttributeError`` when ``rewards``, ``terminations``, ``truncations``,
      ``infos``, ``agent_selection``, ``num_agents`` or ``agents`` are looked
      up while they are not yet defined on the wrapper (see ``__getattr__``).
    * ``render``, ``step``, ``observe``, ``state`` and ``agent_iter`` report
      through the matching ``EnvLogger.error_*_before_reset`` hook when called
      before ``reset``.
    * ``step`` on an environment whose ``agents`` list is empty only emits
      ``EnvLogger.warn_step_after_terminated_truncated`` and does not forward
      the action.
    * ``agent_iter`` returns an iterable whose iterator asserts that ``step``
      or ``reset`` was called between two consecutive iterations.

    NOTE(review): ``_has_rendered`` is recorded by ``render`` but never read
    in this class, and there is no ``close`` override here, so no
    "close before render" warning is issued from this code.
    """

    def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
        """Wrap ``env`` and initialise the order-tracking flags.

        Args:
            env: The AEC environment to wrap.

        Raises:
            AssertionError: If ``env`` is not an ``AECEnv`` instance.
        """
        assert isinstance(
            env, AECEnv
        ), "OrderEnforcingWrapper is only compatible with AEC environments"
        # The flags are assigned before delegating to the base constructor so
        # they already exist on the instance while it runs.
        self._has_reset = False
        self._has_rendered = False
        self._has_updated = False
        super().__init__(env)

    def __getattr__(self, value: str) -> Any:
        """Resolve, or explain the absence of, an attribute not found normally.

        Python only calls ``__getattr__`` after regular attribute lookup has
        failed, so every branch below handles a name that is currently not
        defined on the wrapper.

        Args:
            value: Name of the attribute being looked up.

        Returns:
            The wrapped environment's ``unwrapped`` or ``render_mode`` value
            for those two names.

        Raises:
            AttributeError: For every other name, with a message tailored to
                the attribute where one is available.
        """
        if value == "unwrapped":
            return self.env.unwrapped
        elif value == "render_mode" and hasattr(self.env, "render_mode"):
            return self.env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
        elif value == "possible_agents":
            # presumably raises AttributeError itself; if it ever returns,
            # this method falls through and yields None -- TODO confirm
            EnvLogger.error_possible_agents_attribute_missing("possible_agents")
        elif value == "observation_spaces":
            # The message names the attribute that was actually requested
            # (it previously blamed `possible_agents` for this lookup).
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environment's `observation_space` method instead"
            )
        elif value == "action_spaces":
            raise AttributeError(
                "The base environment does not have an `action_spaces` dict attribute. Use the environment's `action_space` method instead"
            )
        elif value == "agent_order":
            raise AttributeError(
                "agent_order has been removed from the API. Please consider using agent_iter instead."
            )
        elif value in {
            "rewards",
            "terminations",
            "truncations",
            "infos",
            "agent_selection",
            "num_agents",
            "agents",
        }:
            # Reaching this branch means the attribute is not set on the
            # wrapper yet; presumably the base wrapper fills these in on reset.
            raise AttributeError(f"{value} cannot be accessed before reset")
        else:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{value}'"
            )

    def render(self) -> None | np.ndarray | str | list:
        """Render via the base wrapper, reporting a call made before reset.

        Returns:
            Whatever the base wrapper's ``render`` returns.
        """
        if not self._has_reset:
            # presumably raises; execution only continues if it does not
            EnvLogger.error_render_before_reset()
        self._has_rendered = True
        return super().render()

    def step(self, action: ActionType) -> None:
        """Forward ``action`` to the base wrapper when the call order is valid.

        Before reset the call is reported through
        ``EnvLogger.error_step_before_reset`` and nothing is forwarded. When
        ``self.agents`` is empty a warning is logged and the action is
        dropped. Otherwise the action is passed on. Both of the latter paths
        mark the environment as updated for the ``agent_iter`` check.

        Args:
            action: The action to apply.
        """
        if not self._has_reset:
            EnvLogger.error_step_before_reset()
        elif not self.agents:
            self._has_updated = True
            EnvLogger.warn_step_after_terminated_truncated()
            return None
        else:
            self._has_updated = True
            super().step(action)

    def observe(self, agent: AgentID) -> ObsType | None:
        """Return the base wrapper's observation for ``agent``.

        A call made before reset is reported through
        ``EnvLogger.error_observe_before_reset``.
        """
        if not self._has_reset:
            EnvLogger.error_observe_before_reset()
        return super().observe(agent)

    def state(self) -> np.ndarray:
        """Return the base wrapper's state.

        A call made before reset is reported through
        ``EnvLogger.error_state_before_reset``.
        """
        if not self._has_reset:
            EnvLogger.error_state_before_reset()
        return super().state()

    def agent_iter(
        self, max_iter: int = 2**63
    ) -> AECOrderEnforcingIterable[AgentID, ObsType, ActionType]:
        """Return an order-enforcing iterable over this wrapper.

        A call made before reset is reported through
        ``EnvLogger.error_agent_iter_before_reset``.

        Args:
            max_iter: Passed unchanged to ``AECOrderEnforcingIterable``.
        """
        if not self._has_reset:
            EnvLogger.error_agent_iter_before_reset()
        return AECOrderEnforcingIterable(self, max_iter)

    def reset(self, seed: int | None = None, options: dict | None = None) -> None:
        """Mark the wrapper as reset and updated, then reset the base wrapper.

        Args:
            seed: Forwarded to the base wrapper's ``reset``.
            options: Forwarded to the base wrapper's ``reset``.
        """
        # Flags are raised before delegating, matching the original ordering.
        self._has_reset = True
        self._has_updated = True
        super().reset(seed=seed, options=options)

    def __str__(self) -> str:
        """Return a readable name for the wrapped environment.

        When ``metadata`` is available: the wrapped env's string for a plain
        ``OrderEnforcingWrapper``, or ``SubclassName<env>`` for subclasses.
        Without ``metadata`` the default ``repr`` is used.
        """
        if hasattr(self, "metadata"):
            return (
                str(self.env)
                if self.__class__ is OrderEnforcingWrapper
                else f"{type(self).__name__}<{str(self.env)}>"
            )
        else:
            return repr(self)
class AECOrderEnforcingIterable(AECIterable[AgentID, ObsType, ActionType]):
    """Iterable that hands out ``AECOrderEnforcingIterator`` instances.

    Reads ``self.env`` and ``self.max_iter``; presumably both are stored by
    the ``AECIterable`` constructor -- TODO confirm against the base class.
    """

    def __iter__(self) -> AECOrderEnforcingIterator[AgentID, ObsType, ActionType]:
        """Create a fresh order-enforcing iterator over the stored env."""
        wrapped_env, iteration_cap = self.env, self.max_iter
        return AECOrderEnforcingIterator(wrapped_env, iteration_cap)
class AECOrderEnforcingIterator(AECIterator[AgentID, ObsType, ActionType]):
    """Iterator that insists on a ``step()``/``reset()`` between iterations."""

    def __next__(self) -> AgentID:
        """Return the next agent, then clear the env's ``_has_updated`` flag.

        Raises:
            AssertionError: If the env has no ``_has_updated`` attribute, or
                if that flag is falsy (no ``step()``/``reset()`` happened
                since the previous iteration).
        """
        # Advance the base iterator first so any exception it raises
        # propagates before the ordering checks run.
        upcoming = super().__next__()
        wrapper = self.env
        assert hasattr(
            wrapper, "_has_updated"
        ), "env must be wrapped by OrderEnforcingWrapper"
        updated = wrapper._has_updated  # pyright: ignore[reportGeneralTypeIssues]
        assert updated, "need to call step() or reset() in a loop over `agent_iter`"
        # Consume the flag; the caller has to step or reset to raise it again.
        wrapper._has_updated = False  # pyright: ignore[reportGeneralTypeIssues]
        return upcoming