# Source code for pettingzoo.utils.save_observation

from __future__ import annotations

import os
from typing import Any

import gymnasium.spaces
import numpy as np

from pettingzoo.utils.env import AECEnv, AgentID, ParallelEnv

def _check_observation_saveable(
    env: AECEnv[AgentID, Any, Any] | ParallelEnv[AgentID, Any, Any], agent: AgentID
) -> None:
    """Assert that ``agent``'s observation space can be written out as an image.

    Requirements checked, in order:

    * the space is a ``gymnasium.spaces.Box``;
    * every entry of ``low`` equals 0 and every entry of ``high`` equals 255;
    * the shape has rank 2, or rank 3 with a last dimension of 1 or 3.

    Args:
        env: Environment whose ``observation_space`` is queried.
        agent: Agent whose observation space is checked.

    Raises:
        AssertionError: If any requirement is not met. These are ``assert``
            statements, so the checks are skipped when Python runs with ``-O``.
    """
    space = env.observation_space(agent)
    assert isinstance(
        space, gymnasium.spaces.Box
    ), "Observations must be Box to save observations as image"

    # Short-circuits exactly like the original: ``high`` is only inspected
    # when ``low`` is all zeros.
    bounds_ok = np.all(np.equal(space.low, 0)) and np.all(np.equal(space.high, 255))
    assert bounds_ok, "Observations must be 0 to 255 to save as image"

    rank = len(space.shape)
    assert rank in (2, 3), "Observations must be 2D or 3D to save as image"
    if rank == 3:
        channels = space.shape[2]
        assert channels in (
            1,
            3,
        ), "3D observations can only have 1 or 3 channels to save as an image"

# Save an agent's observation as an image. If no agent is specified, the env's currently
# selected agent is used. If all_agents is True, the observations of all agents in the
# environment are saved.
def save_observation(
    env: AECEnv[AgentID, Any, Any],
    agent: AgentID | None = None,
    all_agents: bool = False,
    save_dir: str | None = None,
) -> None:
    """Save agent observations from an AEC environment as PNG images.

    Each observation is written to ``<save_dir>/<env_name>/<agent>.png``, where
    ``<env_name>`` is ``str(env)`` with ``<`` and ``>`` replaced by ``_``.

    Args:
        env: AEC environment to read observations from via ``env.observe``.
            Parallel environments are not supported because they have no
            ``observe`` method.
        agent: Agent whose observation is saved. Defaults to
            ``env.agent_selection``. Ignored when ``all_agents`` is True.
        all_agents: If True, save the observation of every agent in
            ``env.agents``.
        save_dir: Root directory for the images. Defaults to the current
            working directory at the time of the call.

    Raises:
        AssertionError: If an agent's observation space cannot be saved as an
            image (checked by ``_check_observation_saveable``), or if
            ``env.observe`` returns ``None``.
    """
    # Pillow is imported lazily so it is only required when this helper is used.
    from PIL import Image

    # Resolve the default at call time: a ``save_dir=os.getcwd()`` default in the
    # signature would be frozen to the working directory at import time.
    if save_dir is None:
        save_dir = os.getcwd()

    if agent is None:
        agent = env.agent_selection
    agent_list = env.agents[:] if all_agents else [agent]

    # Validate every requested agent up front so a bad observation space fails
    # before any file is written.
    for a in agent_list:
        _check_observation_saveable(env, a)

    # All images for this env share one folder, so build it once.
    env_name = str(env).replace("<", "_").replace(">", "_")
    save_folder = os.path.join(save_dir, env_name)
    if agent_list:
        os.makedirs(save_folder, exist_ok=True)

    for a in agent_list:
        # Parallel envs don't have an observe method, hence AECEnv only.
        observation = env.observe(a)
        assert (
            observation is not None
        ), "Observation must be different than None to save as an image"
        pixels = observation.astype(np.uint8)
        # An (H, W, 1) observation passes the saveability check, but
        # Image.fromarray cannot build an image from a single-channel 3D array;
        # drop the channel axis so it is saved as 2D grayscale.
        if pixels.ndim == 3 and pixels.shape[2] == 1:
            pixels = pixels[:, :, 0]
        im = Image.fromarray(pixels)
        fname = os.path.join(save_folder, str(a) + ".png")
        im.save(fname)