"""
Base environment definitions.

See docs/api.md for API documentation.
See docs/dev_docs.md for additional documentation and an example environment.
"""
from __future__ import annotations
import warnings
from typing import Any, Dict, Generic, Iterable, Iterator, TypeVar
import gymnasium.spaces
import numpy as np
ObsType = TypeVar("ObsType")
ActionType = TypeVar("ActionType")
AgentID = TypeVar("AgentID")
# deprecated
ObsDict = Dict[AgentID, ObsType]
# deprecated
ActionDict = Dict[AgentID, ActionType]
"""
Base environment definitions
See docs/api.md for api documentation
See docs/dev_docs.md for additional documentation and an example environment.
"""
class AECEnv(Generic[AgentID, ObsType, ActionType]):
"""The AECEnv steps agents one at a time.
    If you are unsure if you have implemented an AECEnv correctly, try running
    the `api_test` documented in the Developer documentation on the website.
"""
metadata: dict[str, Any] # Metadata for the environment
# All agents that may appear in the environment
possible_agents: list[AgentID]
agents: list[AgentID] # Agents active at any given time
observation_spaces: dict[
AgentID, gymnasium.spaces.Space
] # Observation space for each agent
# Action space for each agent
action_spaces: dict[AgentID, gymnasium.spaces.Space]
# Whether each agent has just reached a terminal state
terminations: dict[AgentID, bool]
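    # Whether each agent has just been truncated (e.g. reached a time or step limit)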
truncations: dict[AgentID, bool]
rewards: dict[AgentID, float] # Reward from the last step for each agent
# Cumulative rewards for each agent
_cumulative_rewards: dict[AgentID, float]
infos: dict[
AgentID, dict[str, Any]
] # Additional information from the last step for each agent
agent_selection: AgentID # The agent currently being stepped
def __init__(self):
pass
def step(self, action: ActionType) -> None:
"""Accepts and executes the action of the current agent_selection in the environment.
Automatically switches control to the next agent.
"""
raise NotImplementedError
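    # A hedged skeleton of a typical step() override (hypothetical subclass; the
    # helper calls are the real AECEnv helpers, but the game logic and the way the
    # next agent is chosen are placeholders):
    #
    #     def step(self, action):
    #         if self.terminations[self.agent_selection] or self.truncations[self.agent_selection]:
    #             self._was_dead_step(action)
    #             return
    #         agent = self.agent_selection
    #         self._cumulative_rewards[agent] = 0  # reward from the previous step has been consumed
    #         # ... apply `action` to the game state and fill self.rewards,
    #         # self.terminations and self.truncations here ...
    #         self.agent_selection = ...  # environment-specific next-agent logic
    #         self._accumulate_rewards()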
def reset(
self,
seed: int | None = None,
options: dict | None = None,
) -> None:
"""Resets the environment to a starting state."""
raise NotImplementedError
# TODO: Remove `Optional` type below
def observe(self, agent: AgentID) -> ObsType | None:
"""Returns the observation an agent currently can make.
`last()` calls this function.
"""
raise NotImplementedError
def render(self) -> None | np.ndarray | str | list:
"""Renders the environment as specified by self.render_mode.
Render mode can be `human` to display a window.
Other render modes in the default environments are `'rgb_array'`
which returns a numpy array and is supported by all environments outside of classic,
and `'ansi'` which returns the strings printed (specific to classic environments).
"""
raise NotImplementedError
def state(self) -> np.ndarray:
"""State returns a global view of the environment.
It is appropriate for centralized training decentralized execution methods like QMIX
"""
raise NotImplementedError(
"state() method has not been implemented in the environment {}.".format(
self.metadata.get("name", self.__class__.__name__)
)
)
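    # A minimal sketch of one possible state() override (hypothetical; assumes every
    # agent observation is a fixed-size numpy array of the same shape):
    #
    #     def state(self):
    #         return np.stack([self.observe(agent) for agent in self.possible_agents])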
def close(self):
"""Closes any resources that should be released.
Closes the rendering window, subprocesses, network connections,
or any other resources that should be released.
"""
pass
def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
"""Takes in agent and returns the observation space for that agent.
MUST return the same value for the same agent name
Default implementation is to return the observation_spaces dict
"""
warnings.warn(
"Your environment should override the observation_space function. Attempting to use the observation_spaces dict attribute."
)
return self.observation_spaces[agent]
def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
"""Takes in agent and returns the action space for that agent.
MUST return the same value for the same agent name
Default implementation is to return the action_spaces dict
"""
warnings.warn(
"Your environment should override the action_space function. Attempting to use the action_spaces dict attribute."
)
return self.action_spaces[agent]
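    # A minimal override sketch (hypothetical subclass; the Box/Discrete spaces are
    # illustrative only). Caching the space objects keeps the "same value for the
    # same agent name" contract:
    #
    #     class MyEnv(AECEnv):
    #         def __init__(self):
    #             super().__init__()
    #             self.possible_agents = ["player_0", "player_1"]
    #             self._obs_spaces = {
    #                 a: gymnasium.spaces.Box(0.0, 1.0, (4,)) for a in self.possible_agents
    #             }
    #             self._act_spaces = {a: gymnasium.spaces.Discrete(3) for a in self.possible_agents}
    #
    #         def observation_space(self, agent):
    #             return self._obs_spaces[agent]
    #
    #         def action_space(self, agent):
    #             return self._act_spaces[agent]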
@property
def num_agents(self) -> int:
return len(self.agents)
@property
def max_num_agents(self) -> int:
return len(self.possible_agents)
def _deads_step_first(self) -> AgentID:
"""Makes .agent_selection point to first terminated agent.
Stores old value of agent_selection so that _was_dead_step can restore the variable after the dead agent steps.
"""
_deads_order = [
agent
for agent in self.agents
if (self.terminations[agent] or self.truncations[agent])
]
if _deads_order:
self._skip_agent_selection = self.agent_selection
self.agent_selection = _deads_order[0]
return self.agent_selection
def _clear_rewards(self) -> None:
"""Clears all items in .rewards."""
for agent in self.rewards:
self.rewards[agent] = 0
def _accumulate_rewards(self) -> None:
"""Adds .rewards dictionary to ._cumulative_rewards dictionary.
Typically called near the end of a step() method
"""
for agent, reward in self.rewards.items():
self._cumulative_rewards[agent] += reward
def agent_iter(self, max_iter: int = 2**63) -> AECIterable:
"""Yields the current agent (self.agent_selection).
Needs to be used in a loop where you step() each iteration.
"""
return AECIterable(self, max_iter)
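    # A minimal usage sketch of the AEC loop (assumes `env` is a concrete AECEnv
    # implementation; the random policy is a placeholder):
    #
    #     env.reset(seed=42)
    #     for agent in env.agent_iter():
    #         observation, reward, termination, truncation, info = env.last()
    #         if termination or truncation:
    #             action = None  # dead agents must receive a None action
    #         else:
    #             action = env.action_space(agent).sample()  # insert your policy here
    #         env.step(action)
    #     env.close()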
def last(
self, observe: bool = True
) -> tuple[ObsType | None, float, bool, bool, dict[str, Any]]:
"""Returns observation, cumulative reward, terminated, truncated, info for the current agent (specified by self.agent_selection)."""
agent = self.agent_selection
assert agent is not None
observation = self.observe(agent) if observe else None
return (
observation,
self._cumulative_rewards[agent],
self.terminations[agent],
self.truncations[agent],
self.infos[agent],
)
def _was_dead_step(self, action: ActionType) -> None:
"""Helper function that performs step() for dead agents.
Does the following:
1. Removes dead agent from .agents, .terminations, .truncations, .rewards, ._cumulative_rewards, and .infos
        2. Loads next agent into .agent_selection: if another agent is dead, loads that one, otherwise loads the next live agent
        3. Clears the rewards dict
Examples:
Highly recommended to use at the beginning of step as follows:
def step(self, action):
if (self.terminations[self.agent_selection] or self.truncations[self.agent_selection]):
                self._was_dead_step(action)
return
# main contents of step
"""
if action is not None:
raise ValueError("when an agent is dead, the only valid action is None")
# removes dead agent
agent = self.agent_selection
assert (
self.terminations[agent] or self.truncations[agent]
), "an agent that was not dead as attempted to be removed"
del self.terminations[agent]
del self.truncations[agent]
del self.rewards[agent]
del self._cumulative_rewards[agent]
del self.infos[agent]
self.agents.remove(agent)
        # finds the next dead agent, or loads the next live agent (stored in _skip_agent_selection)
_deads_order = [
agent
for agent in self.agents
if (self.terminations[agent] or self.truncations[agent])
]
if _deads_order:
if getattr(self, "_skip_agent_selection", None) is None:
self._skip_agent_selection = self.agent_selection
self.agent_selection = _deads_order[0]
else:
if getattr(self, "_skip_agent_selection", None) is not None:
assert self._skip_agent_selection is not None
self.agent_selection = self._skip_agent_selection
self._skip_agent_selection = None
self._clear_rewards()
def __str__(self) -> str:
"""Returns a name which looks like: `space_invaders_v1`."""
if hasattr(self, "metadata"):
return self.metadata.get("name", self.__class__.__name__)
else:
return self.__class__.__name__
@property
def unwrapped(self) -> AECEnv[AgentID, ObsType, ActionType]:
return self
class AECIterable(Iterable[AgentID], Generic[AgentID, ObsType, ActionType]):
def __init__(self, env, max_iter):
self.env = env
self.max_iter = max_iter
def __iter__(self) -> AECIterator[AgentID, ObsType, ActionType]:
return AECIterator(self.env, self.max_iter)
class AECIterator(Iterator[AgentID], Generic[AgentID, ObsType, ActionType]):
def __init__(self, env: AECEnv[AgentID, ObsType, ActionType], max_iter: int):
self.env = env
self.iters_til_term = max_iter
def __next__(self) -> AgentID:
if not self.env.agents or self.iters_til_term <= 0:
raise StopIteration
self.iters_til_term -= 1
return self.env.agent_selection
def __iter__(self) -> AECIterator[AgentID, ObsType, ActionType]:
return self
class ParallelEnv(Generic[AgentID, ObsType, ActionType]):
"""Parallel environment class.
It steps every live agent at once. If you are unsure if you
have implemented a ParallelEnv correctly, try running the `parallel_api_test` in
the Developer documentation on the website.
"""
metadata: dict[str, Any]
agents: list[AgentID]
possible_agents: list[AgentID]
observation_spaces: dict[
AgentID, gymnasium.spaces.Space
] # Observation space for each agent
action_spaces: dict[AgentID, gymnasium.spaces.Space]
def reset(
self,
seed: int | None = None,
options: dict | None = None,
) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
"""Resets the environment.
And returns a dictionary of observations (keyed by the agent name)
"""
raise NotImplementedError
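    # A hedged sketch of a reset() override for a ParallelEnv subclass (hypothetical;
    # `_get_obs` is an illustrative helper, not part of the API):
    #
    #     def reset(self, seed=None, options=None):
    #         self.agents = self.possible_agents[:]
    #         observations = {agent: self._get_obs(agent) for agent in self.agents}
    #         infos = {agent: {} for agent in self.agents}
    #         return observations, infos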
def step(
self, actions: dict[AgentID, ActionType]
) -> tuple[
dict[AgentID, ObsType],
dict[AgentID, float],
dict[AgentID, bool],
dict[AgentID, bool],
dict[AgentID, dict],
]:
"""Receives a dictionary of actions keyed by the agent name.
Returns the observation dictionary, reward dictionary, terminated dictionary, truncated dictionary
and info dictionary, where each dictionary is keyed by the agent.
"""
raise NotImplementedError
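    # A minimal usage sketch of the parallel loop (assumes `env` is a concrete
    # ParallelEnv implementation; the random policy is a placeholder):
    #
    #     observations, infos = env.reset(seed=42)
    #     while env.agents:
    #         actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    #         observations, rewards, terminations, truncations, infos = env.step(actions)
    #     env.close()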
def render(self) -> None | np.ndarray | str | list:
"""Displays a rendered frame from the environment, if supported.
Alternate render modes in the default environments are `'rgb_array'`
which returns a numpy array and is supported by all environments outside
of classic, and `'ansi'` which returns the strings printed
(specific to classic environments).
"""
raise NotImplementedError
def close(self):
"""Closes the rendering window."""
pass
def state(self) -> np.ndarray:
"""Returns the state.
State returns a global view of the environment appropriate for
        centralized training, decentralized execution methods like QMIX.
"""
raise NotImplementedError(
"state() method has not been implemented in the environment {}.".format(
self.metadata.get("name", self.__class__.__name__)
)
)
def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
"""Takes in agent and returns the observation space for that agent.
MUST return the same value for the same agent name
Default implementation is to return the observation_spaces dict
"""
warnings.warn(
"Your environment should override the observation_space function. Attempting to use the observation_spaces dict attribute."
)
return self.observation_spaces[agent]
def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
"""Takes in agent and returns the action space for that agent.
MUST return the same value for the same agent name
Default implementation is to return the action_spaces dict
"""
warnings.warn(
"Your environment should override the action_space function. Attempting to use the action_spaces dict attribute."
)
return self.action_spaces[agent]
@property
def num_agents(self) -> int:
return len(self.agents)
@property
def max_num_agents(self) -> int:
return len(self.possible_agents)
def __str__(self) -> str:
"""Returns the name.
Which looks like: "space_invaders_v1" by default
"""
if hasattr(self, "metadata"):
return self.metadata.get("name", self.__class__.__name__)
else:
return self.__class__.__name__
@property
    def unwrapped(self) -> ParallelEnv[AgentID, ObsType, ActionType]:
return self