# pyright: reportGeneralTypeIssues=false
import copy
import warnings
from collections import defaultdict
from typing import Callable, Dict, Optional
from pettingzoo.utils import AgentSelector
from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType, ParallelEnv
from pettingzoo.utils.wrappers import OrderEnforcingWrapper
def parallel_wrapper_fn(env_fn: Callable) -> Callable:
    """Wrap an AEC-environment factory so it produces parallel environments.

    Args:
        env_fn: a factory (taking only keyword arguments) that builds an
            AEC environment.

    Returns:
        A factory with the same keyword arguments that builds the AEC env
        and converts it with ``aec_to_parallel_wrapper``.
    """

    def par_fn(**kwargs):
        # Build the underlying AEC env, then expose it through the parallel API.
        return aec_to_parallel_wrapper(env_fn(**kwargs))

    return par_fn
def aec_wrapper_fn(par_env_fn: Callable) -> Callable:
    """Converts class(pettingzoo.utils.env.ParallelEnv) -> class(pettingzoo.utils.env.AECEnv).

    Args:
        par_env_fn: The class to be wrapped.

    Example:
        class my_par_class(pettingzoo.utils.env.ParallelEnv):
            ...

        my_aec_class = aec_wrapper_fn(my_par_class)

    Note: applies the `OrderEnforcingWrapper` wrapper
    """

    def aec_fn(**kwargs):
        # parallel_to_aec both converts to the AEC API and applies
        # the OrderEnforcingWrapper.
        return parallel_to_aec(par_env_fn(**kwargs))

    return aec_fn
# [docs]  -- Sphinx "view source" artifact captured from the rendered docs page;
# kept as a comment so the module stays importable (a bare `[docs]` expression
# would raise NameError at import time).
def aec_to_parallel(
    aec_env: AECEnv[AgentID, ObsType, ActionType]
) -> ParallelEnv[AgentID, ObsType, ActionType]:
    """Converts an AEC environment to a Parallel environment.

    If ``aec_env`` is an ``OrderEnforcingWrapper`` around a
    ``parallel_to_aec_wrapper``, the original parallel environment is
    unwrapped and returned directly. Otherwise the environment is wrapped
    with ``aec_to_parallel_wrapper``.
    """
    wraps_parallel = isinstance(aec_env, OrderEnforcingWrapper) and isinstance(
        aec_env.env, parallel_to_aec_wrapper
    )
    if wraps_parallel:
        # Recover the original ParallelEnv instead of stacking wrappers.
        return aec_env.env.env
    return aec_to_parallel_wrapper(aec_env)
# [docs]  -- Sphinx "view source" artifact (see note above L31); commented out
# so it cannot raise NameError at import time.
def parallel_to_aec(
    par_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
) -> AECEnv[AgentID, ObsType, Optional[ActionType]]:
    """Converts a Parallel environment to an AEC environment.

    A parallel env produced by ``aec_to_parallel_wrapper`` is simply
    unwrapped back to its original AEC environment. Anything else is
    wrapped with ``parallel_to_aec_wrapper`` and then order-enforced.
    """
    if isinstance(par_env, aec_to_parallel_wrapper):
        # Undo the earlier AEC -> parallel conversion instead of double-wrapping.
        return par_env.aec_env
    return OrderEnforcingWrapper(parallel_to_aec_wrapper(par_env))
def turn_based_aec_to_parallel(
    aec_env: AECEnv[AgentID, ObsType, Optional[ActionType]]
) -> ParallelEnv[AgentID, ObsType, Optional[ActionType]]:
    """Convert a strictly turn-based AEC environment to the parallel API.

    An env that is already a ``parallel_to_aec_wrapper`` is unwrapped to
    the parallel env it holds; otherwise it is wrapped with
    ``turn_based_aec_to_parallel_wrapper``.
    """
    if isinstance(aec_env, parallel_to_aec_wrapper):
        return aec_env.env
    return turn_based_aec_to_parallel_wrapper(aec_env)
def to_parallel(
    aec_env: AECEnv[AgentID, ObsType, ActionType]
) -> ParallelEnv[AgentID, ObsType, ActionType]:
    """Deprecated alias for ``aec_to_parallel``; emits a DeprecationWarning-style notice."""
    message = "The `to_parallel` function is deprecated. Use the `aec_to_parallel` function instead."
    warnings.warn(message)
    return aec_to_parallel(aec_env)
def from_parallel(
    par_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
) -> AECEnv[AgentID, ObsType, Optional[ActionType]]:
    """Deprecated alias for ``parallel_to_aec``; emits a deprecation notice."""
    message = "The `from_parallel` function is deprecated. Use the `parallel_to_aec` function instead."
    warnings.warn(message)
    return parallel_to_aec(par_env)
class aec_to_parallel_wrapper(ParallelEnv[AgentID, ObsType, ActionType]):
    """Converts an AEC environment into a Parallel environment.

    Requires the wrapped env to declare ``is_parallelizable`` in its
    metadata: the conversion assumes agent state only changes once per full
    cycle of agents, so a parallel ``step`` can replay one AEC cycle.
    """

    def __init__(self, aec_env):
        """Wrap ``aec_env`` and mirror its agent/metadata attributes.

        Raises:
            AssertionError: if the env does not set the
                ``is_parallelizable`` metadata flag to a truthy value.
        """
        assert aec_env.metadata.get("is_parallelizable", False), (
            "Converting from an AEC environment to a Parallel environment "
            "with the to_parallel wrapper is not generally safe "
            "(the AEC environment should only update once at the end "
            "of each cycle). If you have confirmed that your AEC environment "
            "can be converted in this way, then please set the `is_parallelizable` "
            "key in your metadata to True"
        )
        self.aec_env = aec_env
        # Not every AEC env defines possible_agents; mirror it when present.
        try:
            self.possible_agents = aec_env.possible_agents
        except AttributeError:
            pass
        self.metadata = aec_env.metadata
        # Mirror render_mode when the base env defines it; warn otherwise.
        try:
            self.render_mode = (
                self.aec_env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{aec_env}` does not have a `render_mode` defined."
            )
        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = self.aec_env.state_space
        except AttributeError:
            pass

    @property
    def observation_spaces(self):
        """Deprecated per-agent observation-space dict, built on demand.

        Raises:
            AttributeError: if neither `possible_agents` nor a usable
                `observation_space` method exists on the base env.
        """
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        """Deprecated per-agent action-space dict, built on demand."""
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        """Delegate to the wrapped env's observation space for `agent`."""
        return self.aec_env.observation_space(agent)

    def action_space(self, agent):
        """Delegate to the wrapped env's action space for `agent`."""
        return self.aec_env.action_space(agent)

    @property
    def unwrapped(self):
        """The innermost (unwrapped) environment."""
        return self.aec_env.unwrapped

    def reset(self, seed=None, options=None):
        """Reset the wrapped AEC env and return (observations, infos).

        Observations are gathered only for agents that are neither
        terminated nor truncated after the reset.
        """
        self.aec_env.reset(seed=seed, options=options)
        self.agents = self.aec_env.agents[:]
        observations = {
            agent: self.aec_env.observe(agent)
            for agent in self.aec_env.agents
            if not (self.aec_env.terminations[agent] or self.aec_env.truncations[agent])
        }
        # Shallow-copy infos so callers cannot mutate the env's own dict.
        infos = dict(**self.aec_env.infos)
        return observations, infos

    def step(self, actions):
        """Run one full AEC cycle given a dict of per-agent actions.

        Steps each live agent in the wrapped env's selection order,
        accumulating rewards across the cycle, then flushes any dead agents
        (stepping None) so the next cycle starts with a live agent.

        Raises:
            AssertionError: if the wrapped env's agent_selection ever
                diverges from the expected cycle order, or an agent dies
                mid-cycle.
        """
        # Rewards accumulate over the whole cycle; missing agents start at 0.
        rewards = defaultdict(int)
        terminations = {}
        truncations = {}
        infos = {}
        observations = {}
        for agent in self.aec_env.agents:
            if agent != self.aec_env.agent_selection:
                # The env stepped out of order: distinguish mid-cycle agent
                # death from a plain ordering violation for the error message.
                if self.aec_env.terminations[agent] or self.aec_env.truncations[agent]:
                    raise AssertionError(
                        f"expected agent {agent} got termination or truncation agent {self.aec_env.agent_selection}. Parallel environment wrapper expects all agent death (setting an agent's self.terminations or self.truncations entry to True) to happen only at the end of a cycle."
                    )
                else:
                    raise AssertionError(
                        f"expected agent {agent} got agent {self.aec_env.agent_selection}, Parallel environment wrapper expects agents to step in a cycle."
                    )
            # NOTE(review): the unpacked values are unused; presumably kept
            # because last() -> observe() may have side effects in some envs
            # — confirm before removing.
            obs, rew, termination, truncation, info = self.aec_env.last()
            self.aec_env.step(actions[agent])
            # Accumulate per-step rewards for every agent after each step
            # (AEC rewards are per-step and reset). This inner loop shadows
            # the outer `agent` variable; the outer loop re-reads
            # aec_env.agents so iteration still proceeds correctly.
            for agent in self.aec_env.agents:
                rewards[agent] += self.aec_env.rewards[agent]

        # Snapshot end-of-cycle state into plain dicts for the parallel API.
        terminations = dict(**self.aec_env.terminations)
        truncations = dict(**self.aec_env.truncations)
        infos = dict(**self.aec_env.infos)
        observations = {
            agent: self.aec_env.observe(agent) for agent in self.aec_env.agents
        }
        # Flush agents that died at the end of the cycle: AEC requires a
        # None step for each dead agent before live agents can act again.
        while self.aec_env.agents and (
            self.aec_env.terminations[self.aec_env.agent_selection]
            or self.aec_env.truncations[self.aec_env.agent_selection]
        ):
            self.aec_env.step(None)

        self.agents = self.aec_env.agents
        return observations, rewards, terminations, truncations, infos

    def render(self):
        """Delegate rendering to the wrapped env."""
        return self.aec_env.render()

    def state(self):
        """Delegate the global state to the wrapped env."""
        return self.aec_env.state()

    def close(self):
        """Close the wrapped env."""
        return self.aec_env.close()
class parallel_to_aec_wrapper(AECEnv[AgentID, ObsType, Optional[ActionType]]):
    """Converts a Parallel environment into an AEC environment.

    Actions are buffered one agent at a time and forwarded to the wrapped
    parallel env in a single batched ``step`` once every live agent has
    acted.
    """

    def __init__(
        self, parallel_env: ParallelEnv[AgentID, ObsType, Optional[ActionType]]
    ):
        """Wrap ``parallel_env`` and mirror its agent/metadata attributes."""
        self.env = parallel_env
        # Copy metadata so setting the flag below does not mutate the
        # wrapped env's own dict.
        self.metadata = {**parallel_env.metadata}
        # A converted parallel env can always be converted back safely.
        self.metadata["is_parallelizable"] = True
        try:
            self.render_mode = (
                self.env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{parallel_env}` does not have a `render_mode` defined."
            )
        # Not every env defines possible_agents; mirror it when present.
        try:
            self.possible_agents = parallel_env.possible_agents
        except AttributeError:
            pass
        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = (
                self.env.state_space  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            pass

    @property
    def unwrapped(self):
        """The innermost (unwrapped) environment."""
        return self.env.unwrapped

    @property
    def observation_spaces(self):
        """Deprecated per-agent observation-space dict, built on demand.

        Raises:
            AttributeError: if neither `possible_agents` nor a usable
                `observation_space` method exists on the base env.
        """
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        """Deprecated per-agent action-space dict, built on demand."""
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        """Delegate to the wrapped env's observation space for `agent`."""
        return self.env.observation_space(agent)

    def action_space(self, agent):
        """Delegate to the wrapped env's action space for `agent`."""
        return self.env.action_space(agent)

    def reset(self, seed=None, options=None):
        """Reset the wrapped parallel env and initialize AEC bookkeeping."""
        self._observations, self.infos = self.env.reset(seed=seed, options=options)
        self.agents = self.env.agents[:]
        self._live_agents = self.agents[:]
        # Per-agent action buffer, flushed when the cycle completes.
        self._actions: Dict[AgentID, Optional[ActionType]] = {
            agent: None for agent in self.agents
        }
        self._agent_selector = AgentSelector(self._live_agents)
        self.agent_selection = self._agent_selector.reset()
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.rewards = {agent: 0 for agent in self.agents}
        # Every environment needs to return infos that contain self.agents as their keys
        if not self.infos:
            warnings.warn(
                "The `infos` dictionary returned by `env.reset` was empty. Overwriting with empty info dictionaries; agent IDs will be used as keys"
            )
            self.infos = {agent: {} for agent in self.agents}
        elif set(self.infos.keys()) != set(self.agents):
            # BUG FIX: the original built `{agent: {self.infos.copy()} ...}`,
            # i.e. a SET containing a dict, which raises
            # `TypeError: unhashable type: 'dict'`. Duplicate the shared info
            # dict per agent instead.
            self.infos = {agent: self.infos.copy() for agent in self.agents}
            warnings.warn(
                f"The `infos` dictionary returned by `env.reset()` is not valid: must contain keys for each agent defined in self.agents: {self.agents}. Overwriting with current info duplicated for each agent: {self.infos}"
            )
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.new_agents = []
        self.new_values = {}

    def observe(self, agent):
        """Return the observation cached from the last parallel step/reset."""
        return self._observations[agent]

    def state(self):
        """Delegate the global state to the wrapped env."""
        return self.env.state()

    def add_new_agent(self, new_agent):
        """Register `new_agent` mid-episode and select it next.

        Appends the agent to the selector's order and initializes all
        per-agent bookkeeping dicts with default values.
        """
        # Point the selector at the new slot so next() selects new_agent.
        self._agent_selector._current_agent = len(self._agent_selector.agent_order)
        self._agent_selector.agent_order.append(new_agent)
        self.agent_selection = self._agent_selector.next()
        self.agents.append(new_agent)
        self.terminations[new_agent] = False
        self.truncations[new_agent] = False
        self.infos[new_agent] = {}
        self.rewards[new_agent] = 0
        self._cumulative_rewards[new_agent] = 0

    def step(self, action: Optional[ActionType]):
        """Buffer `action` for the selected agent; step the parallel env once
        the last live agent has acted.

        Dead agents must step with ``None`` per the AEC contract.
        """
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            # Dead agent: drop its buffered slot and run the standard
            # AEC dead-step bookkeeping.
            del self._actions[self.agent_selection]
            assert action is None
            self._was_dead_step(action)
            return
        self._actions[self.agent_selection] = action
        if self._agent_selector.is_last():
            # Cycle complete: forward all buffered actions in one batch.
            obss, rews, terminations, truncations, infos = self.env.step(self._actions)

            self._observations = copy.copy(obss)
            self.terminations = copy.copy(terminations)
            self.truncations = copy.copy(truncations)
            self.infos = copy.copy(infos)
            self.rewards = copy.copy(rews)
            self._cumulative_rewards = copy.copy(rews)

            env_agent_set = set(self.env.agents)

            # Keep agents that only appear in observations (e.g. just-dead
            # agents) at the end of the list, in a deterministic order.
            self.agents = self.env.agents + [
                agent
                for agent in sorted(self._observations.keys(), key=lambda x: str(x))
                if agent not in env_agent_set
            ]

            if len(self.env.agents):
                self._agent_selector = AgentSelector(self.env.agents)
                self.agent_selection = self._agent_selector.reset()

            # Dead agents take their final (None) steps first.
            self._deads_step_first()
        else:
            if self._agent_selector.is_first():
                self._clear_rewards()

            self.agent_selection = self._agent_selector.next()

    def last(self, observe=True):
        """Return (observation, cumulative_reward, termination, truncation,
        info) for the currently selected agent."""
        agent = self.agent_selection
        observation = self.observe(agent) if observe else None
        return (
            observation,
            self._cumulative_rewards[agent],
            self.terminations[agent],
            self.truncations[agent],
            self.infos[agent],
        )

    def render(self):
        """Delegate rendering to the wrapped env."""
        return self.env.render()

    def close(self):
        """Close the wrapped env."""
        self.env.close()

    def __str__(self):
        return str(self.env)
class turn_based_aec_to_parallel_wrapper(
    ParallelEnv[AgentID, ObsType, Optional[ActionType]]
):
    """Exposes a strictly turn-based AEC environment through the parallel API.

    Each parallel ``step`` forwards only the currently selected agent's
    action to the wrapped AEC env; the active agent is reported in every
    agent's info dict under the ``"active_agent"`` key.
    """

    def __init__(self, aec_env: AECEnv[AgentID, ObsType, Optional[ActionType]]):
        """Wrap ``aec_env`` and mirror its agent/metadata attributes."""
        self.aec_env = aec_env
        # Not every AEC env defines possible_agents; mirror it when present.
        try:
            self.possible_agents = aec_env.possible_agents
        except AttributeError:
            pass
        self.metadata = aec_env.metadata

        # Not every environment has the .state_space attribute implemented
        try:
            self.state_space = (
                self.aec_env.state_space  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            pass
        try:
            self.render_mode = (
                self.aec_env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
            )
        except AttributeError:
            warnings.warn(
                f"The base environment `{aec_env}` does not have a `render_mode` defined."
            )

    @property
    def unwrapped(self):
        """The innermost (unwrapped) environment."""
        return self.aec_env.unwrapped

    @property
    def observation_spaces(self):
        """Deprecated per-agent observation-space dict, built on demand.

        Raises:
            AttributeError: if neither `possible_agents` nor a usable
                `observation_space` method exists on the base env.
        """
        warnings.warn(
            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
        )
        try:
            return {
                agent: self.observation_space(agent) for agent in self.possible_agents
            }
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
            ) from e

    @property
    def action_spaces(self):
        """Deprecated per-agent action-space dict, built on demand."""
        warnings.warn(
            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
        )
        try:
            return {agent: self.action_space(agent) for agent in self.possible_agents}
        except AttributeError as e:
            raise AttributeError(
                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
            ) from e

    def observation_space(self, agent):
        """Delegate to the wrapped env's observation space for `agent`."""
        return self.aec_env.observation_space(agent)

    def action_space(self, agent):
        """Delegate to the wrapped env's action space for `agent`."""
        return self.aec_env.action_space(agent)

    def reset(self, seed=None, options=None):
        """Reset the wrapped AEC env and return (observations, infos).

        Observations are gathered only for agents that are neither
        terminated nor truncated after the reset.
        """
        self.aec_env.reset(seed=seed, options=options)
        self.agents = self.aec_env.agents[:]
        observations = {
            agent: self.aec_env.observe(agent)
            for agent in self.aec_env.agents
            if not (self.aec_env.terminations[agent] or self.aec_env.truncations[agent])
        }

        infos = {**self.aec_env.infos}
        return observations, infos

    def step(self, actions):
        """Step the active agent and return the five parallel-API dicts.

        Only ``actions[self.aec_env.agent_selection]`` is used; other
        entries are ignored. Dead agents are flushed (stepped with None)
        until a live agent is selected.
        """
        if not self.agents:
            # BUG FIX: the parallel API returns FIVE dicts (observations,
            # rewards, terminations, truncations, infos); the original
            # returned only four here, breaking 5-way unpacking by callers.
            return {}, {}, {}, {}, {}
        self.aec_env.step(actions[self.aec_env.agent_selection])

        # Snapshot post-step state into plain dicts for the parallel API.
        rewards = {**self.aec_env.rewards}
        terminations = {**self.aec_env.terminations}
        truncations = {**self.aec_env.truncations}
        infos = {**self.aec_env.infos}
        observations = {
            agent: self.aec_env.observe(agent) for agent in self.aec_env.agents
        }

        # Flush dead agents so the next selected agent is live.
        while self.aec_env.agents:
            if (
                self.aec_env.terminations[self.aec_env.agent_selection]
                or self.aec_env.truncations[self.aec_env.agent_selection]
            ):
                self.aec_env.step(None)
            else:
                break

        # no need to update data after null step (nothing should change other than the active agent)
        for agent in self.aec_env.agents:
            infos[agent]["active_agent"] = self.aec_env.agent_selection

        self.agents = self.aec_env.agents
        return observations, rewards, terminations, truncations, infos

    def render(self):
        """Delegate rendering to the wrapped env."""
        return self.aec_env.render()

    def state(self):
        """Delegate the global state to the wrapped env."""
        return self.aec_env.state()

    def close(self):
        """Close the wrapped env."""
        return self.aec_env.close()