Tutorial: Action Masking¶
Introduction¶
In many environments, it is natural for some actions to be invalid at certain times. For example, in a game of chess, it is impossible to move a pawn forward if it is already at the front of the board. In PettingZoo, we can use action masking to prevent invalid actions from being taken.
Action masking is a more natural way of handling invalid actions than having an action have no effect, which was how we handled bumping into walls in the previous tutorial.
Code¶
import functools
import random
from copy import copy
import numpy as np
from gymnasium.spaces import Discrete, MultiDiscrete
from pettingzoo import ParallelEnv
class CustomActionMaskedEnvironment(ParallelEnv):
"""The metadata holds environment constants.
The "name" metadata allows the environment to be pretty printed.
"""
metadata = {
"name": "custom_environment_v0",
}
def __init__(self):
"""The init method takes in environment arguments.
Should define the following attributes:
- escape x and y coordinates
- guard x and y coordinates
- prisoner x and y coordinates
- timestamp
- possible_agents
Note: as of v1.18.1, the action_spaces and observation_spaces attributes are deprecated.
Spaces should be defined in the action_space() and observation_space() methods.
If these methods are not overridden, spaces will be inferred from self.observation_spaces/action_spaces, raising a warning.
These attributes should not be changed after initialization.
"""
self.escape_y = None
self.escape_x = None
self.guard_y = None
self.guard_x = None
self.prisoner_y = None
self.prisoner_x = None
self.timestep = None
self.possible_agents = ["prisoner", "guard"]
def reset(self, seed=None, options=None):
"""Reset set the environment to a starting point.
It needs to initialize the following attributes:
- agents
- timestamp
- prisoner x and y coordinates
- guard x and y coordinates
- escape x and y coordinates
- observation
- infos
And must set up the environment so that render(), step(), and observe() can be called without issues.
"""
self.agents = copy(self.possible_agents)
self.timestep = 0
self.prisoner_x = 0
self.prisoner_y = 0
self.guard_x = 7
self.guard_y = 7
self.escape_x = random.randint(2, 5)
self.escape_y = random.randint(2, 5)
observation = (
self.prisoner_x + 7 * self.prisoner_y,
self.guard_x + 7 * self.guard_y,
self.escape_x + 7 * self.escape_y,
)
observations = {
"prisoner": {"observation": observation, "action_mask": [0, 1, 1, 0]},
"guard": {"observation": observation, "action_mask": [1, 0, 0, 1]},
}
# Get dummy infos. Necessary for proper parallel_to_aec conversion
infos = {a: {} for a in self.agents}
return observations, infos
def step(self, actions):
"""Takes in an action for the current agent (specified by agent_selection).
Needs to update:
- prisoner x and y coordinates
- guard x and y coordinates
- terminations
- truncations
- rewards
- timestamp
- infos
And any internal state used by observe() or render()
"""
# Execute actions
prisoner_action = actions["prisoner"]
guard_action = actions["guard"]
if prisoner_action == 0 and self.prisoner_x > 0:
self.prisoner_x -= 1
elif prisoner_action == 1 and self.prisoner_x < 6:
self.prisoner_x += 1
elif prisoner_action == 2 and self.prisoner_y > 0:
self.prisoner_y -= 1
elif prisoner_action == 3 and self.prisoner_y < 6:
self.prisoner_y += 1
if guard_action == 0 and self.guard_x > 0:
self.guard_x -= 1
elif guard_action == 1 and self.guard_x < 6:
self.guard_x += 1
elif guard_action == 2 and self.guard_y > 0:
self.guard_y -= 1
elif guard_action == 3 and self.guard_y < 6:
self.guard_y += 1
# Generate action masks
prisoner_action_mask = np.ones(4, dtype=np.int8)
if self.prisoner_x == 0:
prisoner_action_mask[0] = 0 # Block left movement
elif self.prisoner_x == 6:
prisoner_action_mask[1] = 0 # Block right movement
if self.prisoner_y == 0:
prisoner_action_mask[2] = 0 # Block down movement
elif self.prisoner_y == 6:
prisoner_action_mask[3] = 0 # Block up movement
guard_action_mask = np.ones(4, dtype=np.int8)
if self.guard_x == 0:
guard_action_mask[0] = 0
elif self.guard_x == 6:
guard_action_mask[1] = 0
if self.guard_y == 0:
guard_action_mask[2] = 0
elif self.guard_y == 6:
guard_action_mask[3] = 0
# Action mask to prevent guard from going over escape cell
if self.guard_x - 1 == self.escape_x:
guard_action_mask[0] = 0
elif self.guard_x + 1 == self.escape_x:
guard_action_mask[1] = 0
if self.guard_y - 1 == self.escape_y:
guard_action_mask[2] = 0
elif self.guard_y + 1 == self.escape_y:
guard_action_mask[3] = 0
# Check termination conditions
terminations = {a: False for a in self.agents}
rewards = {a: 0 for a in self.agents}
if self.prisoner_x == self.guard_x and self.prisoner_y == self.guard_y:
rewards = {"prisoner": -1, "guard": 1}
terminations = {a: True for a in self.agents}
self.agents = []
elif self.prisoner_x == self.escape_x and self.prisoner_y == self.escape_y:
rewards = {"prisoner": 1, "guard": -1}
terminations = {a: True for a in self.agents}
self.agents = []
# Check truncation conditions (overwrites termination conditions)
truncations = {"prisoner": False, "guard": False}
if self.timestep > 100:
rewards = {"prisoner": 0, "guard": 0}
truncations = {"prisoner": True, "guard": True}
self.agents = []
self.timestep += 1
# Get observations
observation = (
self.prisoner_x + 7 * self.prisoner_y,
self.guard_x + 7 * self.guard_y,
self.escape_x + 7 * self.escape_y,
)
observations = {
"prisoner": {
"observation": observation,
"action_mask": prisoner_action_mask,
},
"guard": {"observation": observation, "action_mask": guard_action_mask},
}
# Get dummy infos (not used in this example)
infos = {"prisoner": {}, "guard": {}}
return observations, rewards, terminations, truncations, infos
def render(self):
"""Renders the environment."""
grid = np.zeros((8, 8), dtype=object)
grid[self.prisoner_y, self.prisoner_x] = "P"
grid[self.guard_y, self.guard_x] = "G"
grid[self.escape_y, self.escape_x] = "E"
print(f"{grid} \n")
# Observation space should be defined here.
# lru_cache allows observation and action spaces to be memoized, reducing clock cycles required to get each agent's space.
# If your spaces change over time, remove this line (disable caching).
@functools.lru_cache(maxsize=None)
def observation_space(self, agent):
# gymnasium spaces are defined and documented here: https://gymnasium.farama.org/api/spaces/
return MultiDiscrete([7 * 7 - 1] * 3)
# Action space should be defined here.
# If your spaces change over time, remove this line (disable caching).
@functools.lru_cache(maxsize=None)
def action_space(self, agent):
return Discrete(4)