# Source code for pettingzoo.utils.env

from __future__ import annotations

import warnings
from typing import Any, Dict, Generic, Iterable, Iterator, TypeVar

import gymnasium.spaces
import numpy as np

ObsType = TypeVar("ObsType")
ActionType = TypeVar("ActionType")
AgentID = TypeVar("AgentID")

# deprecated
ObsDict = Dict[AgentID, ObsType]

# deprecated
ActionDict = Dict[AgentID, ActionType]

"""
Base environment definitions

See docs/api.md for api documentation
See docs/dev_docs.md for additional documentation and an example environment.
"""


class AECEnv(Generic[AgentID, ObsType, ActionType]):
    """The AECEnv steps agents one at a time.

    If you are unsure if you have implemented an AECEnv correctly, try running
    the `api_test` documented in the Developer documentation on the website.
    """

    metadata: dict[str, Any]  # Metadata for the environment

    # All agents that may appear in the environment
    possible_agents: list[AgentID]
    agents: list[AgentID]  # Agents active at any given time

    observation_spaces: dict[
        AgentID, gymnasium.spaces.Space
    ]  # Observation space for each agent
    # Action space for each agent
    action_spaces: dict[AgentID, gymnasium.spaces.Space]

    # Whether each agent has just reached a terminal state
    terminations: dict[AgentID, bool]
    truncations: dict[AgentID, bool]
    rewards: dict[AgentID, float]  # Reward from the last step for each agent
    # Cumulative rewards for each agent
    _cumulative_rewards: dict[AgentID, float]
    infos: dict[
        AgentID, dict[str, Any]
    ]  # Additional information from the last step for each agent

    agent_selection: AgentID  # The agent currently being stepped

    def __init__(self):
        pass

    def step(self, action: ActionType) -> None:
        """Accepts and executes the action of the current agent_selection in the environment.

        Automatically switches control to the next agent.
        """
        raise NotImplementedError

    def reset(
        self,
        seed: int | None = None,
        options: dict | None = None,
    ) -> None:
        """Resets the environment to a starting state."""
        raise NotImplementedError

    # TODO: Remove `Optional` type below
    def observe(self, agent: AgentID) -> ObsType | None:
        """Returns the observation an agent currently can make.

        `last()` calls this function.
        """
        raise NotImplementedError

    def render(self) -> None | np.ndarray | str | list:
        """Renders the environment as specified by self.render_mode.

        Render mode can be `human` to display a window.
        Other render modes in the default environments are `'rgb_array'`
        which returns a numpy array and is supported by all environments outside of classic,
        and `'ansi'` which returns the strings printed (specific to classic environments).
        """
        raise NotImplementedError

    def state(self) -> np.ndarray:
        """State returns a global view of the environment.

        It is appropriate for centralized training decentralized execution methods like QMIX
        """
        raise NotImplementedError(
            "state() method has not been implemented in the environment {}.".format(
                self.metadata.get("name", self.__class__.__name__)
            )
        )

    def close(self):
        """Closes any resources that should be released.

        Closes the rendering window, subprocesses, network connections,
        or any other resources that should be released.
        """
        pass

    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        """Takes in agent and returns the observation space for that agent.

        MUST return the same value for the same agent name

        Default implementation is to return the observation_spaces dict
        """
        warnings.warn(
            "Your environment should override the observation_space function. Attempting to use the observation_spaces dict attribute."
        )
        return self.observation_spaces[agent]

    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        """Takes in agent and returns the action space for that agent.

        MUST return the same value for the same agent name

        Default implementation is to return the action_spaces dict
        """
        warnings.warn(
            "Your environment should override the action_space function. Attempting to use the action_spaces dict attribute."
        )
        return self.action_spaces[agent]

    @property
    def num_agents(self) -> int:
        return len(self.agents)

    @property
    def max_num_agents(self) -> int:
        return len(self.possible_agents)

    def _deads_step_first(self) -> AgentID:
        """Makes .agent_selection point to the first terminated agent.

        Stores old value of agent_selection so that _was_dead_step can restore
        the variable after the dead agent steps.
        """
        _deads_order = [
            agent
            for agent in self.agents
            if (self.terminations[agent] or self.truncations[agent])
        ]
        if _deads_order:
            self._skip_agent_selection = self.agent_selection
            self.agent_selection = _deads_order[0]
        return self.agent_selection

    def _clear_rewards(self) -> None:
        """Clears all items in .rewards."""
        for agent in self.rewards:
            self.rewards[agent] = 0

    def _accumulate_rewards(self) -> None:
        """Adds .rewards dictionary to ._cumulative_rewards dictionary.

        Typically called near the end of a step() method
        """
        for agent, reward in self.rewards.items():
            self._cumulative_rewards[agent] += reward

    def agent_iter(self, max_iter: int = 2**63) -> AECIterable:
        """Yields the current agent (self.agent_selection).

        Needs to be used in a loop where you step() each iteration.
        """
        return AECIterable(self, max_iter)

    def last(
        self, observe: bool = True
    ) -> tuple[ObsType | None, float, bool, bool, dict[str, Any]]:
        """Returns observation, cumulative reward, terminated, truncated, info for the current agent (specified by self.agent_selection)."""
        agent = self.agent_selection
        assert agent is not None
        observation = self.observe(agent) if observe else None
        return (
            observation,
            self._cumulative_rewards[agent],
            self.terminations[agent],
            self.truncations[agent],
            self.infos[agent],
        )

    def _was_dead_step(self, action: ActionType) -> None:
        """Helper function that performs step() for dead agents.

        Does the following:

        1. Removes dead agent from .agents, .terminations, .truncations, .rewards, ._cumulative_rewards, and .infos
        2. Loads next agent into .agent_selection: if another agent is dead, loads that one, otherwise loads next live agent
        3. Clears the rewards dict

        Examples:
            Highly recommended to use at the beginning of step as follows:

            def step(self, action):
                if (self.terminations[self.agent_selection] or self.truncations[self.agent_selection]):
                    self._was_dead_step(action)
                    return
                # main contents of step
        """
        if action is not None:
            raise ValueError("when an agent is dead, the only valid action is None")

        # removes dead agent
        agent = self.agent_selection
        assert (
            self.terminations[agent] or self.truncations[agent]
        ), "an agent that was not dead was attempted to be removed"
        del self.terminations[agent]
        del self.truncations[agent]
        del self.rewards[agent]
        del self._cumulative_rewards[agent]
        del self.infos[agent]
        self.agents.remove(agent)

        # finds next dead agent or loads next live agent (stored in _skip_agent_selection)
        _deads_order = [
            agent
            for agent in self.agents
            if (self.terminations[agent] or self.truncations[agent])
        ]
        if _deads_order:
            if getattr(self, "_skip_agent_selection", None) is None:
                self._skip_agent_selection = self.agent_selection
            self.agent_selection = _deads_order[0]
        else:
            if getattr(self, "_skip_agent_selection", None) is not None:
                assert self._skip_agent_selection is not None
                self.agent_selection = self._skip_agent_selection
            self._skip_agent_selection = None
        self._clear_rewards()

    def __str__(self) -> str:
        """Returns a name which looks like: `space_invaders_v1`."""
        if hasattr(self, "metadata"):
            return self.metadata.get("name", self.__class__.__name__)
        else:
            return self.__class__.__name__

    @property
    def unwrapped(self) -> AECEnv[AgentID, ObsType, ActionType]:
        return self
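

# A minimal usage sketch (not part of this module), assuming `env` is a concrete
# AECEnv implementation. It shows the canonical AEC loop: `agent_iter()` yields the
# agent that should act next, `last()` returns the values produced by the previous
# `step()`, and dead agents must be stepped with a None action.
#
#     env.reset(seed=42)
#     for agent in env.agent_iter():
#         observation, reward, termination, truncation, info = env.last()
#         if termination or truncation:
#             action = None  # dead agents may only receive None
#         else:
#             action = env.action_space(agent).sample()  # replace with a policy
#         env.step(action)
#     env.close()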


class AECIterable(Iterable[AgentID], Generic[AgentID, ObsType, ActionType]):
    def __init__(self, env, max_iter):
        self.env = env
        self.max_iter = max_iter

    def __iter__(self) -> AECIterator[AgentID, ObsType, ActionType]:
        return AECIterator(self.env, self.max_iter)


class AECIterator(Iterator[AgentID], Generic[AgentID, ObsType, ActionType]):
    def __init__(self, env: AECEnv[AgentID, ObsType, ActionType], max_iter: int):
        self.env = env
        self.iters_til_term = max_iter

    def __next__(self) -> AgentID:
        if not self.env.agents or self.iters_til_term <= 0:
            raise StopIteration
        self.iters_til_term -= 1
        return self.env.agent_selection

    def __iter__(self) -> AECIterator[AgentID, ObsType, ActionType]:
        return self


class ParallelEnv(Generic[AgentID, ObsType, ActionType]):
    """Parallel environment class.

    It steps every live agent at once. If you are unsure if you have implemented
    a ParallelEnv correctly, try running the `parallel_api_test` in the Developer
    documentation on the website.
    """

    metadata: dict[str, Any]

    agents: list[AgentID]
    possible_agents: list[AgentID]
    observation_spaces: dict[
        AgentID, gymnasium.spaces.Space
    ]  # Observation space for each agent
    action_spaces: dict[AgentID, gymnasium.spaces.Space]  # Action space for each agent

    def reset(
        self,
        seed: int | None = None,
        options: dict | None = None,
    ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
        """Resets the environment.

        Returns a dictionary of observations and a dictionary of infos (both keyed by the agent name).
        """
        raise NotImplementedError

    def step(
        self, actions: dict[AgentID, ActionType]
    ) -> tuple[
        dict[AgentID, ObsType],
        dict[AgentID, float],
        dict[AgentID, bool],
        dict[AgentID, bool],
        dict[AgentID, dict],
    ]:
        """Receives a dictionary of actions keyed by the agent name.

        Returns the observation dictionary, reward dictionary, terminated dictionary,
        truncated dictionary and info dictionary, where each dictionary is keyed by the agent.
        """
        raise NotImplementedError

    def render(self) -> None | np.ndarray | str | list:
        """Displays a rendered frame from the environment, if supported.

        Alternate render modes in the default environments are `'rgb_array'`
        which returns a numpy array and is supported by all environments outside of classic,
        and `'ansi'` which returns the strings printed (specific to classic environments).
        """
        raise NotImplementedError

    def close(self):
        """Closes the rendering window."""
        pass

    def state(self) -> np.ndarray:
        """Returns the state.

        State returns a global view of the environment appropriate for
        centralized training decentralized execution methods like QMIX
        """
        raise NotImplementedError(
            "state() method has not been implemented in the environment {}.".format(
                self.metadata.get("name", self.__class__.__name__)
            )
        )

    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        """Takes in agent and returns the observation space for that agent.

        MUST return the same value for the same agent name

        Default implementation is to return the observation_spaces dict
        """
        warnings.warn(
            "Your environment should override the observation_space function. Attempting to use the observation_spaces dict attribute."
        )
        return self.observation_spaces[agent]

    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        """Takes in agent and returns the action space for that agent.

        MUST return the same value for the same agent name

        Default implementation is to return the action_spaces dict
        """
        warnings.warn(
            "Your environment should override the action_space function. Attempting to use the action_spaces dict attribute."
        )
        return self.action_spaces[agent]

    @property
    def num_agents(self) -> int:
        return len(self.agents)

    @property
    def max_num_agents(self) -> int:
        return len(self.possible_agents)

    def __str__(self) -> str:
        """Returns the name.

        Which looks like: "space_invaders_v1" by default
        """
        if hasattr(self, "metadata"):
            return self.metadata.get("name", self.__class__.__name__)
        else:
            return self.__class__.__name__

    @property
    def unwrapped(self) -> ParallelEnv:
        return self
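

# A minimal usage sketch (not part of this module), assuming `parallel_env` is a
# concrete ParallelEnv implementation. Every live agent is stepped at once with a
# dictionary of actions, and the loop ends when `parallel_env.agents` is empty.
#
#     observations, infos = parallel_env.reset(seed=42)
#     while parallel_env.agents:
#         actions = {
#             agent: parallel_env.action_space(agent).sample()  # replace with a policy
#             for agent in parallel_env.agents
#         }
#         observations, rewards, terminations, truncations, infos = parallel_env.step(actions)
#     parallel_env.close()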