Source code for pettingzoo.utils.wrappers.base

from __future__ import annotations

import warnings
from typing import Any

import gymnasium.spaces
import numpy as np

from pettingzoo.utils.env import ActionType, AECEnv, ObsType


[docs]class BaseWrapper(AECEnv): """Creates a wrapper around `env` parameter. All AECEnv wrappers should inherit from this base class """ def __init__(self, env: AECEnv): super().__init__() self.env = env try: self.possible_agents = self.env.possible_agents except AttributeError: pass self.metadata = self.env.metadata # we don't want these defined as we don't want them used before they are gotten # self.agent_selection = self.env.agent_selection # self.rewards = self.env.rewards # self.dones = self.env.dones # we don't want to care one way or the other whether environments have an infos or not before reset try: self.infos = self.env.infos except AttributeError: pass # Not every environment has the .state_space attribute implemented try: self.state_space = ( self.env.state_space # pyright: ignore[reportGeneralTypeIssues] ) except AttributeError: pass def __getattr__(self, name: str) -> Any: """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" if name.startswith("_"): raise AttributeError(f"accessing private attribute '{name}' is prohibited") return getattr(self.env, name) @property def observation_spaces(self) -> dict[str, gymnasium.spaces.Space]: warnings.warn( "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead." ) try: return { agent: self.observation_space(agent) for agent in self.possible_agents } except AttributeError as e: raise AttributeError( "The base environment does not have an `observation_spaces` dict attribute. Use the environment's `observation_space` method instead" ) from e @property def action_spaces(self) -> dict[str, gymnasium.spaces.Space]: warnings.warn( "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead." ) try: return {agent: self.action_space(agent) for agent in self.possible_agents} except AttributeError as e: raise AttributeError( "The base environment does not have an action_spaces dict attribute. Use the environment's `action_space` method instead" ) from e def observation_space(self, agent: str) -> gymnasium.spaces.Space: return self.env.observation_space(agent) def action_space(self, agent: str) -> gymnasium.spaces.Space: return self.env.action_space(agent) @property def unwrapped(self) -> AECEnv: return self.env.unwrapped def close(self) -> None: self.env.close() def render(self) -> None | np.ndarray | str | list: return self.env.render() def reset(self, seed: int | None = None, options: dict | None = None): self.env.reset(seed=seed, options=options) self.agent_selection = self.env.agent_selection self.rewards = self.env.rewards self.terminations = self.env.terminations self.truncations = self.env.truncations self.infos = self.env.infos self.agents = self.env.agents self._cumulative_rewards = self.env._cumulative_rewards def observe(self, agent: str) -> ObsType | None: return self.env.observe(agent) def state(self) -> np.ndarray: return self.env.state() def step(self, action: ActionType) -> None: self.env.step(action) self.agent_selection = self.env.agent_selection self.rewards = self.env.rewards self.terminations = self.env.terminations self.truncations = self.env.truncations self.infos = self.env.infos self.agents = self.env.agents self._cumulative_rewards = self.env._cumulative_rewards def __str__(self) -> str: """Returns a name which looks like: "max_observation<space_invaders_v1>".""" return f"{type(self).__name__}<{str(self.env)}>"