from __future__ import annotations
import os
from typing import Any
import gymnasium.spaces
import numpy as np
from pettingzoo.utils.env import AECEnv, AgentID, ParallelEnv
def _check_observation_saveable(
    env: AECEnv[AgentID, Any, Any] | ParallelEnv[AgentID, Any, Any], agent: AgentID
) -> None:
    """Assert that ``agent``'s observation space can be written out as an image.

    The space must be a ``gymnasium.spaces.Box`` whose ``low`` is 0 and whose
    ``high`` is 255 in every element. Its shape must be 2D or 3D, and a 3D
    shape must have 1 or 3 entries on its last axis (``shape[2]``).

    Args:
        env: Environment exposing ``observation_space(agent)``.
        agent: Agent whose observation space is checked.

    Raises:
        AssertionError: On the first requirement the space violates. These are
            plain ``assert`` statements, so they are skipped under ``python -O``.
    """
    space = env.observation_space(agent)

    assert isinstance(space, gymnasium.spaces.Box), (
        "Observations must be Box to save observations as image"
    )

    # Bounds must match the uint8 pixel range exactly, element by element.
    low, high = space.low, space.high
    assert np.all(np.equal(low, 0)) and np.all(np.equal(high, 255)), (
        "Observations must be 0 to 255 to save as image"
    )

    rank = len(space.shape)
    assert rank in (2, 3), "Observations must be 2D or 3D to save as image"

    if rank == 3:
        # The last axis is treated as the colour channel: grayscale or RGB.
        channels = space.shape[2]
        assert channels in (1, 3), (
            "3D observations can only have 1 or 3 channels to save as an image"
        )
def save_observation(
    env: AECEnv[AgentID, Any, Any],
    agent: AgentID | None = None,
    all_agents: bool = False,
    save_dir: str | None = None,
) -> None:
    """Save the current observation of one or all agents as PNG images.

    Each image is written to ``<save_dir>/<env name>/<agent>.png``, where
    ``<env name>`` is ``str(env)`` with ``<`` and ``>`` replaced by ``_``.

    Args:
        env: AEC environment whose observations are saved. Its ``observe``
            method is called, so Parallel environments (which lack it) are not
            supported.
        agent: Agent whose observation is saved. Defaults to
            ``env.agent_selection``. Ignored when ``all_agents`` is true.
        all_agents: If true, save the observation of every agent in
            ``env.agents``.
        save_dir: Root output directory. Defaults to the current working
            directory at call time.

    Raises:
        AssertionError: If an agent's observation space cannot be saved as an
            image, or if an agent's observation is ``None``.
    """
    from PIL import Image

    if save_dir is None:
        # Resolved per call: an ``os.getcwd()`` default in the signature would
        # be frozen at import time and ignore later ``chdir()`` calls.
        save_dir = os.getcwd()

    if all_agents:
        agent_list = list(env.agents)
    else:
        agent_list = [env.agent_selection if agent is None else agent]

    if not agent_list:
        # Nothing to save; do not create an empty output folder.
        return

    # Validate every agent up front so an unsupported observation space does
    # not leave a partially written output folder behind.
    for a in agent_list:
        _check_observation_saveable(env, a)

    # The output folder depends only on the environment, so build it once.
    env_name = str(env).replace("<", "_").replace(">", "_")
    save_folder = os.path.join(save_dir, env_name)
    os.makedirs(save_folder, exist_ok=True)

    for a in agent_list:
        # Parallel envs don't have an observe method, hence the AECEnv-only signature.
        observation = env.observe(a)
        assert (
            observation is not None
        ), "Observation must be different than None to save as an image"
        pixels = np.asarray(observation).astype(np.uint8)
        if pixels.ndim == 3 and pixels.shape[2] == 1:
            # PIL.Image.fromarray cannot build an image from an (H, W, 1) uint8
            # array; drop the singleton channel to get a grayscale image.
            pixels = pixels[..., 0]
        fname = os.path.join(save_folder, str(a) + ".png")
        Image.fromarray(pixels).save(fname)