# Source code for pettingzoo.utils.conversions

# pyright: reportGeneralTypeIssues=false
import copy
import warnings
from collections import defaultdict
from typing import Callable, Dict, Optional

from pettingzoo.utils import agent_selector
from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType, ParallelEnv
from pettingzoo.utils.wrappers import OrderEnforcingWrapper


def parallel_wrapper_fn(env_fn: Callable) -> Callable:
    """Wrap an AEC-environment factory so it produces Parallel environments.

    The returned factory forwards all keyword arguments to ``env_fn`` and
    wraps the resulting AEC environment in ``aec_to_parallel_wrapper``.
    """

    def par_fn(**kwargs):
        return aec_to_parallel_wrapper(env_fn(**kwargs))

    return par_fn


def aec_wrapper_fn(par_env_fn: Callable) -> Callable:
    """Converts class(pettingzoo.utils.env.ParallelEnv) -> class(pettingzoo.utils.env.AECEnv).

    Args:
        par_env_fn: The class to be wrapped.

    Example:
        class my_par_class(pettingzoo.utils.env.ParallelEnv):
            ...

        my_aec_class = aec_wrapper_fn(my_par_class)

    Note: applies the `OrderEnforcingWrapper` wrapper
    """

    def aec_fn(**kwargs):
        return parallel_to_aec(par_env_fn(**kwargs))

    return aec_fn


def aec_to_parallel(
    aec_env: AECEnv[AgentID, ObsType, ActionType]
) -> ParallelEnv[AgentID, ObsType, ActionType]:
    """Converts an AEC environment to a Parallel environment.

    In the case of an existing Parallel environment wrapped using a
    `parallel_to_aec_wrapper`, this function will return the original
    Parallel environment. Otherwise, it will apply the
    `aec_to_parallel_wrapper` to convert the environment.
    """
    wraps_parallel = isinstance(aec_env, OrderEnforcingWrapper) and isinstance(
        aec_env.env, parallel_to_aec_wrapper
    )
    if wraps_parallel:
        # Unwrap back to the original Parallel environment instead of
        # stacking another conversion layer on top.
        return aec_env.env.env
    return aec_to_parallel_wrapper(aec_env)
def parallel_to_aec(
    par_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
) -> AECEnv[AgentID, ObsType, Optional[ActionType]]:
    """Converts a Parallel environment to an AEC environment.

    In the case of an existing AEC environment wrapped using a
    `aec_to_parallel_wrapper`, this function will return the original AEC
    environment. Otherwise, it will apply the `parallel_to_aec_wrapper` to
    convert the environment.
    """
    if isinstance(par_env, aec_to_parallel_wrapper):
        # Unwrap back to the original AEC environment instead of stacking
        # another conversion layer on top.
        return par_env.aec_env
    return OrderEnforcingWrapper(parallel_to_aec_wrapper(par_env))
def turn_based_aec_to_parallel(
    aec_env: AECEnv[AgentID, ObsType, Optional[ActionType]]
) -> ParallelEnv[AgentID, ObsType, Optional[ActionType]]:
    """Converts a turn-based AEC environment to a Parallel environment.

    In the case of an existing AEC environment wrapped using a
    `parallel_to_aec_wrapper`, this function will return the original Parallel
    environment. Otherwise, it will apply the
    `turn_based_aec_to_parallel_wrapper` to convert the environment.
    """
    if isinstance(aec_env, parallel_to_aec_wrapper):
        return aec_env.env
    else:
        par_env = turn_based_aec_to_parallel_wrapper(aec_env)
        return par_env


def to_parallel(
    aec_env: AECEnv[AgentID, ObsType, ActionType]
) -> ParallelEnv[AgentID, ObsType, ActionType]:
    """Deprecated alias for `aec_to_parallel`."""
    warnings.warn(
        "The `to_parallel` function is deprecated. Use the `aec_to_parallel` function instead."
    )
    return aec_to_parallel(aec_env)


def from_parallel(
    par_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
) -> AECEnv[AgentID, ObsType, Optional[ActionType]]:
    """Deprecated alias for `parallel_to_aec`."""
    warnings.warn(
        "The `from_parallel` function is deprecated. Use the `parallel_to_aec` function instead."
    )
    return parallel_to_aec(par_env)


class aec_to_parallel_wrapper(ParallelEnv[AgentID, ObsType, ActionType]):
    """Converts an AEC environment into a Parallel environment."""

    def __init__(self, aec_env):
        # Only AEC envs that declare themselves parallelizable may be
        # converted: all agent deaths must happen at the end of a cycle.
        assert aec_env.metadata.get("is_parallelizable", False), (
            "Converting from an AEC environment to a Parallel environment "
            "with the to_parallel wrapper is not generally safe "
            "(the AEC environment should only update once at the end "
            "of each cycle). If you have confirmed that your AEC environment "
            "can be converted in this way, then please set the `is_parallelizable` "
            "key in your metadata to True"
        )

        self.aec_env = aec_env

        try:
            self.possible_agents = aec_env.possible_agents
        except AttributeError:
            pass

        self.metadata = aec_env.metadata

        try:
            self.render_mode = (
                self.aec_env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{aec_env}` does not have a `render_mode` defined."
            )

        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = self.aec_env.state_space
        except AttributeError:
            pass

    @property
    def observation_spaces(self):
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        return self.aec_env.observation_space(agent)

    def action_space(self, agent):
        return self.aec_env.action_space(agent)

    @property
    def unwrapped(self):
        return self.aec_env.unwrapped

    def reset(self, seed=None, options=None):
        """Reset the underlying AEC env and return observations/infos for all live agents."""
        self.aec_env.reset(seed=seed, options=options)
        self.agents = self.aec_env.agents[:]
        # Dead agents (already terminated/truncated at reset) get no observation.
        observations = {
            agent: self.aec_env.observe(agent)
            for agent in self.aec_env.agents
            if not (self.aec_env.terminations[agent] or self.aec_env.truncations[agent])
        }
        infos = dict(**self.aec_env.infos)
        return observations, infos

    def step(self, actions):
        """Step every agent once (one full AEC cycle) and aggregate the results."""
        rewards = defaultdict(int)
        terminations = {}
        truncations = {}
        infos = {}
        observations = {}
        for agent in self.aec_env.agents:
            if agent != self.aec_env.agent_selection:
                if self.aec_env.terminations[agent] or self.aec_env.truncations[agent]:
                    raise AssertionError(
                        f"expected agent {agent} got termination or truncation agent {self.aec_env.agent_selection}. Parallel environment wrapper expects all agent death (setting an agent's self.terminations or self.truncations entry to True) to happen only at the end of a cycle."
                    )
                else:
                    raise AssertionError(
                        f"expected agent {agent} got agent {self.aec_env.agent_selection}, Parallel environment wrapper expects agents to step in a cycle."
                    )
            obs, rew, termination, truncation, info = self.aec_env.last()
            self.aec_env.step(actions[agent])
            # Accumulate per-step rewards across the whole cycle.
            for agent in self.aec_env.agents:
                rewards[agent] += self.aec_env.rewards[agent]

        terminations = dict(**self.aec_env.terminations)
        truncations = dict(**self.aec_env.truncations)
        infos = dict(**self.aec_env.infos)
        observations = {
            agent: self.aec_env.observe(agent) for agent in self.aec_env.agents
        }
        # Flush any trailing dead-agent steps (they take a None action).
        while self.aec_env.agents and (
            self.aec_env.terminations[self.aec_env.agent_selection]
            or self.aec_env.truncations[self.aec_env.agent_selection]
        ):
            self.aec_env.step(None)

        self.agents = self.aec_env.agents
        return observations, rewards, terminations, truncations, infos

    def render(self):
        return self.aec_env.render()

    def state(self):
        return self.aec_env.state()

    def close(self):
        return self.aec_env.close()


class parallel_to_aec_wrapper(AECEnv[AgentID, ObsType, Optional[ActionType]]):
    """Converts a Parallel environment into an AEC environment."""

    def __init__(
        self, parallel_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
    ):
        self.env = parallel_env
        self.metadata = {**parallel_env.metadata}
        # The resulting AEC env is parallelizable by construction.
        self.metadata["is_parallelizable"] = True
        try:
            self.render_mode = (
                self.env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{parallel_env}` does not have a `render_mode` defined."
            )

        try:
            self.possible_agents = parallel_env.possible_agents
        except AttributeError:
            pass

        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = (
                self.env.state_space  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            pass

    @property
    def unwrapped(self):
        return self.env.unwrapped

    @property
    def observation_spaces(self):
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        return self.env.observation_space(agent)

    def action_space(self, agent):
        return self.env.action_space(agent)

    def reset(self, seed=None, options=None):
        """Reset the underlying Parallel env and initialize AEC bookkeeping."""
        self._observations, self.infos = self.env.reset(seed=seed, options=options)
        self.agents = self.env.agents[:]
        self._live_agents = self.agents[:]
        self._actions: Dict[AgentID, Optional[ActionType]] = {
            agent: None for agent in self.agents
        }
        self._agent_selector = agent_selector(self._live_agents)
        self.agent_selection = self._agent_selector.reset()
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.rewards = {agent: 0 for agent in self.agents}

        # Every environment needs to return infos that contain self.agents as their keys
        if not self.infos:
            warnings.warn(
                "The `infos` dictionary returned by `env.reset` was empty. OverwritingAgent IDs will be used as keys"
            )
            self.infos = {agent: {} for agent in self.agents}
        elif set(self.infos.keys()) != set(self.agents):
            # BUG FIX: the original built `{agent: {self.infos.copy()}}` — a set
            # literal containing a dict, which raises
            # `TypeError: unhashable type: 'dict'`. The intent (per the warning
            # text below) is to duplicate the current info dict for each agent.
            self.infos = {agent: self.infos.copy() for agent in self.agents}
            warnings.warn(
                f"The `infos` dictionary returned by `env.reset()` is not valid: must contain keys for each agent defined in self.agents: {self.agents}. Overwriting with current info duplicated for each agent: {self.infos}"
            )
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.new_agents = []
        self.new_values = {}

    def observe(self, agent):
        return self._observations[agent]

    def state(self):
        return self.env.state()

    def add_new_agent(self, new_agent):
        """Register an agent created mid-episode and make it the next to act."""
        self._agent_selector._current_agent = len(self._agent_selector.agent_order)
        self._agent_selector.agent_order.append(new_agent)
        self.agent_selection = self._agent_selector.next()
        self.agents.append(new_agent)
        self.terminations[new_agent] = False
        self.truncations[new_agent] = False
        self.infos[new_agent] = {}
        self.rewards[new_agent] = 0
        self._cumulative_rewards[new_agent] = 0

    def step(self, action: Optional[ActionType]):
        """Buffer one agent's action; step the Parallel env once per full cycle."""
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            del self._actions[self.agent_selection]
            assert action is None
            self._was_dead_step(action)
            return
        self._actions[self.agent_selection] = action
        if self._agent_selector.is_last():
            # Last agent of the cycle: dispatch all buffered actions at once.
            obss, rews, terminations, truncations, infos = self.env.step(self._actions)

            self._observations = copy.copy(obss)
            self.terminations = copy.copy(terminations)
            self.truncations = copy.copy(truncations)
            self.infos = copy.copy(infos)
            self.rewards = copy.copy(rews)
            self._cumulative_rewards = copy.copy(rews)

            env_agent_set = set(self.env.agents)

            # Dead agents (observed but no longer live) are appended in a
            # deterministic (string-sorted) order so they can take their
            # final dead step.
            self.agents = self.env.agents + [
                agent
                for agent in sorted(self._observations.keys(), key=lambda x: str(x))
                if agent not in env_agent_set
            ]

            if len(self.env.agents):
                self._agent_selector = agent_selector(self.env.agents)
                self.agent_selection = self._agent_selector.reset()

            self._deads_step_first()
        else:
            if self._agent_selector.is_first():
                self._clear_rewards()

            self.agent_selection = self._agent_selector.next()

    def last(self, observe=True):
        """Return (observation, cumulative reward, termination, truncation, info) for the selected agent."""
        agent = self.agent_selection
        observation = self.observe(agent) if observe else None
        return (
            observation,
            self._cumulative_rewards[agent],
            self.terminations[agent],
            self.truncations[agent],
            self.infos[agent],
        )

    def render(self):
        return self.env.render()

    def close(self):
        self.env.close()

    def __str__(self):
        return str(self.env)


class turn_based_aec_to_parallel_wrapper(
    ParallelEnv[AgentID, ObsType, Optional[ActionType]]
):
    """Converts a turn-based AEC environment into a Parallel environment.

    Only the currently selected agent's action is applied each step; the
    active agent is exposed via ``infos[agent]["active_agent"]``.
    """

    def __init__(self, aec_env: AECEnv[AgentID, ObsType, Optional[ActionType]]):
        self.aec_env = aec_env

        try:
            self.possible_agents = aec_env.possible_agents
        except AttributeError:
            pass

        self.metadata = aec_env.metadata

        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = (
                self.aec_env.state_space  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            pass

        try:
            self.render_mode = (
                self.aec_env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{aec_env}` does not have a `render_mode` defined."
            )

    @property
    def unwrapped(self):
        return self.aec_env.unwrapped

    @property
    def observation_spaces(self):
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        return self.aec_env.observation_space(agent)

    def action_space(self, agent):
        return self.aec_env.action_space(agent)

    def reset(self, seed=None, options=None):
        self.aec_env.reset(seed=seed, options=options)
        self.agents = self.aec_env.agents[:]
        observations = {
            agent: self.aec_env.observe(agent)
            for agent in self.aec_env.agents
            if not (self.aec_env.terminations[agent] or self.aec_env.truncations[agent])
        }

        infos = {**self.aec_env.infos}
        return observations, infos

    def step(self, actions):
        """Apply the active agent's action and return the 5-tuple parallel result."""
        if not self.agents:
            # BUG FIX: the original returned only four empty dicts here, while
            # the normal path (and the ParallelEnv.step contract) returns five
            # values: observations, rewards, terminations, truncations, infos.
            return {}, {}, {}, {}, {}
        self.aec_env.step(actions[self.aec_env.agent_selection])
        rewards = {**self.aec_env.rewards}
        terminations = {**self.aec_env.terminations}
        truncations = {**self.aec_env.truncations}
        infos = {**self.aec_env.infos}
        observations = {
            agent: self.aec_env.observe(agent) for agent in self.aec_env.agents
        }
        # Skip over any dead agents (they take a None step) until a live agent
        # is selected or no agents remain.
        while self.aec_env.agents:
            if (
                self.aec_env.terminations[self.aec_env.agent_selection]
                or self.aec_env.truncations[self.aec_env.agent_selection]
            ):
                self.aec_env.step(None)
            else:
                break

        # no need to update data after null step (nothing should change other than the active agent)
        for agent in self.aec_env.agents:
            infos[agent]["active_agent"] = self.aec_env.agent_selection

        self.agents = self.aec_env.agents
        return observations, rewards, terminations, truncations, infos

    def render(self):
        return self.aec_env.render()

    def state(self):
        return self.aec_env.state()

    def close(self):
        return self.aec_env.close()