Source code for pettingzoo.utils.conversions

# pyright: reportGeneralTypeIssues=false
import copy
import warnings
from collections import defaultdict
from typing import Callable, Dict, Optional

from pettingzoo.utils import AgentSelector
from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType, ParallelEnv
from pettingzoo.utils.wrappers import OrderEnforcingWrapper


def parallel_wrapper_fn(env_fn: Callable) -> Callable:
    """Turn an AEC environment factory into a Parallel environment factory.

    Args:
        env_fn: Callable that accepts keyword arguments and returns an AEC
            environment.

    Returns:
        A callable taking the same keyword arguments as ``env_fn`` that
        builds the environment and wraps it in `aec_to_parallel_wrapper`.
    """

    def par_fn(**kwargs):
        # Build the AEC environment and wrap it in a single step.
        return aec_to_parallel_wrapper(env_fn(**kwargs))

    return par_fn


def aec_wrapper_fn(par_env_fn: Callable) -> Callable:
    """Turn a Parallel environment factory into an AEC environment factory.

    Args:
        par_env_fn: The class (or any callable accepting keyword arguments)
            that produces the Parallel environment to be wrapped.

    Returns:
        A callable taking the same keyword arguments as ``par_env_fn`` whose
        result has been passed through `parallel_to_aec`.

    Example:
        class my_par_class(pettingzoo.utils.env.ParallelEnv):
            ...

        my_aec_class = aec_wrapper_fn(my_par_class)

    Note: the conversion goes through `parallel_to_aec`, which applies the
    `OrderEnforcingWrapper` wrapper whenever it has to build a new AEC wrapper.
    """

    def aec_fn(**kwargs):
        # Instantiate the parallel environment, then convert it.
        return parallel_to_aec(par_env_fn(**kwargs))

    return aec_fn


def aec_to_parallel(
    aec_env: AECEnv[AgentID, ObsType, ActionType]
) -> ParallelEnv[AgentID, ObsType, ActionType]:
    """Converts an AEC environment to a Parallel environment.

    If ``aec_env`` is an `OrderEnforcingWrapper` whose inner environment is a
    `parallel_to_aec_wrapper`, the Parallel environment held by that inner
    wrapper is returned as-is. In every other case the environment is wrapped
    in a new `aec_to_parallel_wrapper`.

    Args:
        aec_env: The AEC environment to convert.

    Returns:
        The recovered original Parallel environment, or a new
        `aec_to_parallel_wrapper` around ``aec_env``.
    """
    wraps_converted_parallel_env = isinstance(
        aec_env, OrderEnforcingWrapper
    ) and isinstance(aec_env.env, parallel_to_aec_wrapper)
    if wraps_converted_parallel_env:
        # Peel off both layers: OrderEnforcingWrapper -> parallel_to_aec_wrapper.
        return aec_env.env.env
    return aec_to_parallel_wrapper(aec_env)


def parallel_to_aec(
    par_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
) -> AECEnv[AgentID, ObsType, Optional[ActionType]]:
    """Converts a Parallel environment to an AEC environment.

    If ``par_env`` is an `aec_to_parallel_wrapper`, the AEC environment it
    holds is returned as-is. Otherwise the environment is wrapped in a
    `parallel_to_aec_wrapper`, which is in turn wrapped in an
    `OrderEnforcingWrapper`.

    Args:
        par_env: The Parallel environment to convert.

    Returns:
        The recovered original AEC environment, or a new order-enforced
        `parallel_to_aec_wrapper` around ``par_env``.
    """
    if isinstance(par_env, aec_to_parallel_wrapper):
        # Already a converted AEC environment: hand back the original.
        return par_env.aec_env
    return OrderEnforcingWrapper(parallel_to_aec_wrapper(par_env))


def turn_based_aec_to_parallel(
    aec_env: AECEnv[AgentID, ObsType, Optional[ActionType]]
) -> ParallelEnv[AgentID, ObsType, Optional[ActionType]]:
    """Converts a turn-based AEC environment to a Parallel environment.

    If ``aec_env`` is itself a `parallel_to_aec_wrapper`, the Parallel
    environment it holds is returned. Otherwise the environment is wrapped in
    a `turn_based_aec_to_parallel_wrapper`.

    Args:
        aec_env: The AEC environment to convert.

    Returns:
        The recovered Parallel environment, or a new
        `turn_based_aec_to_parallel_wrapper` around ``aec_env``.
    """
    if isinstance(aec_env, parallel_to_aec_wrapper):
        return aec_env.env
    return turn_based_aec_to_parallel_wrapper(aec_env)


def to_parallel(
    aec_env: AECEnv[AgentID, ObsType, ActionType]
) -> ParallelEnv[AgentID, ObsType, ActionType]:
    """Deprecated alias of `aec_to_parallel`; emits a warning and delegates."""
    warnings.warn(
        "The `to_parallel` function is deprecated. Use the `aec_to_parallel` function instead."
    )
    return aec_to_parallel(aec_env)


def from_parallel(
    par_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
) -> AECEnv[AgentID, ObsType, Optional[ActionType]]:
    """Deprecated alias of `parallel_to_aec`; emits a warning and delegates."""
    warnings.warn(
        "The `from_parallel` function is deprecated. Use the `parallel_to_aec` function instead."
    )
    return parallel_to_aec(par_env)


class aec_to_parallel_wrapper(ParallelEnv[AgentID, ObsType, ActionType]):
    """Converts an AEC environment into a Parallel environment.

    One call to `step` steps the wrapped AEC environment once for every agent
    in its current agent list, in order, and returns the aggregated result.
    """

    def __init__(self, aec_env):
        """Wrap ``aec_env``.

        Args:
            aec_env: The AEC environment to wrap. Its ``metadata`` must have
                ``is_parallelizable`` set to a truthy value.

        Raises:
            AssertionError: If ``is_parallelizable`` is missing or falsy.
        """
        assert aec_env.metadata.get("is_parallelizable", False), (
            "Converting from an AEC environment to a Parallel environment "
            "with the to_parallel wrapper is not generally safe "
            "(the AEC environment should only update once at the end "
            "of each cycle). If you have confirmed that your AEC environment "
            "can be converted in this way, then please set the `is_parallelizable` "
            "key in your metadata to True"
        )

        self.aec_env = aec_env

        # possible_agents is optional on the wrapped environment.
        try:
            self.possible_agents = aec_env.possible_agents
        except AttributeError:
            pass

        self.metadata = aec_env.metadata

        try:
            self.render_mode = (
                self.aec_env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{aec_env}` does not have a `render_mode` defined."
            )

        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = self.aec_env.state_space
        except AttributeError:
            pass

    @property
    def observation_spaces(self):
        """Deprecated dict of observation spaces keyed by possible agent."""
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        """Deprecated dict of action spaces keyed by possible agent."""
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        """Return the wrapped environment's observation space for ``agent``."""
        return self.aec_env.observation_space(agent)

    def action_space(self, agent):
        """Return the wrapped environment's action space for ``agent``."""
        return self.aec_env.action_space(agent)

    @property
    def unwrapped(self):
        """Return the wrapped environment's ``unwrapped`` attribute."""
        return self.aec_env.unwrapped

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Forwarded to the wrapped environment's ``reset``.

        Returns:
            ``(observations, infos)``. Observations are gathered only for
            agents that are neither terminated nor truncated after the reset;
            ``infos`` is a shallow copy of the wrapped environment's infos.
        """
        self.aec_env.reset(seed=seed, options=options)
        self.agents = self.aec_env.agents[:]
        observations = {
            agent: self.aec_env.observe(agent)
            for agent in self.aec_env.agents
            if not (self.aec_env.terminations[agent] or self.aec_env.truncations[agent])
        }

        # dict(mapping) rather than dict(**mapping): the latter raises
        # TypeError when agent IDs are not strings.
        infos = dict(self.aec_env.infos)
        return observations, infos

    def step(self, actions):
        """Step every agent of the wrapped environment once.

        Args:
            actions: Mapping from agent to action; it is indexed with every
                agent in the wrapped environment's current agent list.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``.
            ``rewards`` is a ``defaultdict(int)`` accumulating, per agent, the
            wrapped environment's ``rewards`` entry after each individual
            step. The other values are read after the last step.

        Raises:
            AssertionError: If the wrapped environment's ``agent_selection``
                does not follow the order of its agent list.
        """
        rewards = defaultdict(int)
        for agent in self.aec_env.agents:
            if agent != self.aec_env.agent_selection:
                if self.aec_env.terminations[agent] or self.aec_env.truncations[agent]:
                    raise AssertionError(
                        f"expected agent {agent} got termination or truncation agent {self.aec_env.agent_selection}. Parallel environment wrapper expects all agent death (setting an agent's self.terminations or self.truncations entry to True) to happen only at the end of a cycle."
                    )
                else:
                    raise AssertionError(
                        f"expected agent {agent} got agent {self.aec_env.agent_selection}, Parallel environment wrapper expects agents to step in a cycle."
                    )
            # The result is unused; the call is kept in case the wrapped
            # environment relies on it - TODO confirm whether it can be dropped.
            self.aec_env.last()
            self.aec_env.step(actions[agent])
            for rewarded_agent in self.aec_env.agents:
                rewards[rewarded_agent] += self.aec_env.rewards[rewarded_agent]

        # dict(mapping) rather than dict(**mapping): the latter raises
        # TypeError when agent IDs are not strings.
        terminations = dict(self.aec_env.terminations)
        truncations = dict(self.aec_env.truncations)
        infos = dict(self.aec_env.infos)
        observations = {
            agent: self.aec_env.observe(agent) for agent in self.aec_env.agents
        }

        # Feed None actions while the selected agent is terminated/truncated.
        while self.aec_env.agents and (
            self.aec_env.terminations[self.aec_env.agent_selection]
            or self.aec_env.truncations[self.aec_env.agent_selection]
        ):
            self.aec_env.step(None)

        self.agents = self.aec_env.agents
        return observations, rewards, terminations, truncations, infos

    def render(self):
        """Render via the wrapped environment."""
        return self.aec_env.render()

    def state(self):
        """Return the wrapped environment's state."""
        return self.aec_env.state()

    def close(self):
        """Close the wrapped environment."""
        return self.aec_env.close()


class parallel_to_aec_wrapper(AECEnv[AgentID, ObsType, Optional[ActionType]]):
    """Converts a Parallel environment into an AEC environment.

    Actions passed to `step` are buffered per agent; the wrapped Parallel
    environment is stepped once the agent selector reports the last agent.
    """

    def __init__(
        self, parallel_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
    ):
        """Wrap ``parallel_env``.

        Args:
            parallel_env: The Parallel environment to wrap. Its metadata is
                copied and ``is_parallelizable`` is set to True on the copy.
        """
        self.env = parallel_env

        self.metadata = {**parallel_env.metadata}
        self.metadata["is_parallelizable"] = True

        try:
            self.render_mode = (
                self.env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{parallel_env}` does not have a `render_mode` defined."
            )

        # possible_agents is optional on the wrapped environment.
        try:
            self.possible_agents = parallel_env.possible_agents
        except AttributeError:
            pass

        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = (
                self.env.state_space  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            pass

    @property
    def unwrapped(self):
        """Return the wrapped environment's ``unwrapped`` attribute."""
        return self.env.unwrapped

    @property
    def observation_spaces(self):
        """Deprecated dict of observation spaces keyed by possible agent."""
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        """Deprecated dict of action spaces keyed by possible agent."""
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        """Return the wrapped environment's observation space for ``agent``."""
        return self.env.observation_space(agent)

    def action_space(self, agent):
        """Return the wrapped environment's action space for ``agent``."""
        return self.env.action_space(agent)

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment and all per-agent bookkeeping.

        If the wrapped environment returns empty infos, or infos whose keys
        do not match its agents, a warning is emitted and ``self.infos`` is
        rebuilt with one entry per agent.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Forwarded to the wrapped environment's ``reset``.
        """
        self._observations, self.infos = self.env.reset(seed=seed, options=options)
        self.agents = self.env.agents[:]
        self._live_agents = self.agents[:]
        self._actions: Dict[AgentID, Optional[ActionType]] = {
            agent: None for agent in self.agents
        }
        self._agent_selector = AgentSelector(self._live_agents)
        self.agent_selection = self._agent_selector.reset()
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.rewards = {agent: 0 for agent in self.agents}

        # Every environment needs to return infos that contain self.agents as their keys
        if not self.infos:
            warnings.warn(
                "The `infos` dictionary returned by `env.reset` was empty. Overwriting. Agent IDs will be used as keys"
            )
            self.infos = {agent: {} for agent in self.agents}
        elif set(self.infos.keys()) != set(self.agents):
            # Give every agent its own shallow copy of the returned infos.
            # (A set literal around the copy would raise TypeError because
            # dicts are unhashable.)
            returned_infos = self.infos
            self.infos = {agent: returned_infos.copy() for agent in self.agents}
            warnings.warn(
                f"The `infos` dictionary returned by `env.reset()` is not valid: must contain keys for each agent defined in self.agents: {self.agents}. Overwriting with current info duplicated for each agent: {self.infos}"
            )

        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.new_agents = []
        self.new_values = {}

    def observe(self, agent):
        """Return the most recently stored observation for ``agent``."""
        return self._observations[agent]

    def state(self):
        """Return the wrapped environment's state."""
        return self.env.state()

    def add_new_agent(self, new_agent):
        """Append ``new_agent`` to the selector order and select it next.

        Also registers the agent in ``agents`` and initialises its
        termination, truncation, info, reward and cumulative-reward entries.
        """
        self._agent_selector._current_agent = len(self._agent_selector.agent_order)
        self._agent_selector.agent_order.append(new_agent)
        self.agent_selection = self._agent_selector.next()
        self.agents.append(new_agent)
        self.terminations[new_agent] = False
        self.truncations[new_agent] = False
        self.infos[new_agent] = {}
        self.rewards[new_agent] = 0
        self._cumulative_rewards[new_agent] = 0

    def step(self, action: Optional[ActionType]):
        """Record ``action`` for the selected agent and advance the turn.

        When the selected agent is terminated or truncated, its buffered
        action is dropped and the step is delegated to ``_was_dead_step``
        (not defined in this class; presumably inherited from `AECEnv`).
        Otherwise the action is buffered, and once the selector reports the
        last agent the wrapped Parallel environment is stepped with all
        buffered actions.

        Args:
            action: Action for ``self.agent_selection``; must be None when
                that agent is terminated or truncated.
        """
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            del self._actions[self.agent_selection]
            assert action is None
            self._was_dead_step(action)
            return
        self._actions[self.agent_selection] = action
        if self._agent_selector.is_last():
            obss, rews, terminations, truncations, infos = self.env.step(self._actions)

            self._observations = copy.copy(obss)
            self.terminations = copy.copy(terminations)
            self.truncations = copy.copy(truncations)
            self.infos = copy.copy(infos)
            self.rewards = copy.copy(rews)
            self._cumulative_rewards = copy.copy(rews)

            env_agent_set = set(self.env.agents)

            # Agents that were observed but are no longer in env.agents are
            # appended after the live ones, ordered by their string form.
            self.agents = self.env.agents + [
                agent
                for agent in sorted(self._observations.keys(), key=str)
                if agent not in env_agent_set
            ]

            if len(self.env.agents):
                self._agent_selector = AgentSelector(self.env.agents)
                self.agent_selection = self._agent_selector.reset()

            self._deads_step_first()
        else:
            if self._agent_selector.is_first():
                self._clear_rewards()

            self.agent_selection = self._agent_selector.next()

    def last(self, observe=True):
        """Return data for the currently selected agent.

        Args:
            observe: When falsy, the observation slot is None.

        Returns:
            ``(observation, cumulative_reward, termination, truncation, info)``
            for ``self.agent_selection``.
        """
        agent = self.agent_selection
        observation = self.observe(agent) if observe else None
        return (
            observation,
            self._cumulative_rewards[agent],
            self.terminations[agent],
            self.truncations[agent],
            self.infos[agent],
        )

    def render(self):
        """Render via the wrapped environment."""
        return self.env.render()

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def __str__(self):
        """Return the string form of the wrapped environment."""
        return str(self.env)


class turn_based_aec_to_parallel_wrapper(
    ParallelEnv[AgentID, ObsType, Optional[ActionType]]
):
    """Exposes a turn-based AEC environment through the Parallel interface.

    Each `step` forwards only the action of the wrapped environment's
    currently selected agent and reports the next selected agent under the
    ``"active_agent"`` key of every remaining agent's info dict.
    """

    def __init__(self, aec_env: AECEnv[AgentID, ObsType, Optional[ActionType]]):
        """Wrap ``aec_env``.

        Args:
            aec_env: The AEC environment to wrap.
        """
        self.aec_env = aec_env

        # possible_agents is optional on the wrapped environment.
        try:
            self.possible_agents = aec_env.possible_agents
        except AttributeError:
            pass

        self.metadata = aec_env.metadata

        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = (
                self.aec_env.state_space  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            pass

        try:
            self.render_mode = (
                self.aec_env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{aec_env}` does not have a `render_mode` defined."
            )

    @property
    def unwrapped(self):
        """Return the wrapped environment's ``unwrapped`` attribute."""
        return self.aec_env.unwrapped

    @property
    def observation_spaces(self):
        """Deprecated dict of observation spaces keyed by possible agent."""
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        """Deprecated dict of action spaces keyed by possible agent."""
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        """Return the wrapped environment's observation space for ``agent``."""
        return self.aec_env.observation_space(agent)

    def action_space(self, agent):
        """Return the wrapped environment's action space for ``agent``."""
        return self.aec_env.action_space(agent)

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Forwarded to the wrapped environment's ``reset``.

        Returns:
            ``(observations, infos)``. Observations are gathered only for
            agents that are neither terminated nor truncated after the reset;
            ``infos`` is a shallow copy of the wrapped environment's infos.
        """
        self.aec_env.reset(seed=seed, options=options)
        self.agents = self.aec_env.agents[:]
        observations = {
            agent: self.aec_env.observe(agent)
            for agent in self.aec_env.agents
            if not (self.aec_env.terminations[agent] or self.aec_env.truncations[agent])
        }
        infos = {**self.aec_env.infos}
        return observations, infos

    def step(self, actions):
        """Step the wrapped environment with the selected agent's action.

        Args:
            actions: Mapping from agent to action; only the entry for the
                wrapped environment's ``agent_selection`` is read.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``.
            Five empty dicts are returned when no agents remain.
        """
        if not self.agents:
            # Same arity as the normal return value, so callers can always
            # unpack five items.
            return {}, {}, {}, {}, {}
        self.aec_env.step(actions[self.aec_env.agent_selection])
        rewards = {**self.aec_env.rewards}
        terminations = {**self.aec_env.terminations}
        truncations = {**self.aec_env.truncations}
        infos = {**self.aec_env.infos}
        observations = {
            agent: self.aec_env.observe(agent) for agent in self.aec_env.agents
        }

        while self.aec_env.agents:
            if (
                self.aec_env.terminations[self.aec_env.agent_selection]
                or self.aec_env.truncations[self.aec_env.agent_selection]
            ):
                self.aec_env.step(None)
            else:
                break
            # no need to update data after null step (nothing should change other than the active agent)

        # NOTE: `infos` is a shallow copy, so this write also lands in the
        # per-agent info dicts owned by the wrapped environment.
        for agent in self.aec_env.agents:
            infos[agent]["active_agent"] = self.aec_env.agent_selection
        self.agents = self.aec_env.agents
        return observations, rewards, terminations, truncations, infos

    def render(self):
        """Render via the wrapped environment."""
        return self.aec_env.render()

    def state(self):
        """Return the wrapped environment's state."""
        return self.aec_env.state()

    def close(self):
        """Close the wrapped environment."""
        return self.aec_env.close()