SB3: Action Masked PPO for Connect Four
This tutorial shows how to train agents using Maskable Proximal Policy Optimization (PPO) on the Connect Four environment (AEC).
It creates a custom wrapper that converts the environment into a Gymnasium-like environment compatible with SB3 action masking.
After training and evaluation, this script launches two demo games using human rendering. Trained models are saved to and loaded from disk (see SB3’s documentation for more information).
Note
This environment has a discrete action space, and each observation comes with an illegal-action mask, so we use SB3-Contrib's maskable MLP policy (MaskableActorCriticPolicy).
Warning
The SB3ActionMaskWrapper wrapper assumes that the action space and observation space are the same for each agent; this assumption may not hold for custom environments.
Environment Setup
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.
pettingzoo[classic]>=1.24.0
stable-baselines3>=2.0.0
sb3-contrib>=2.0.0
Code
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the Discord server.
Training and Evaluation
"""Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking.
For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.org/api/aec/#action-masking
For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html
Author: Elliot (https://github.com/elliottower)
"""
import glob
import os
import time
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
import pettingzoo.utils
from pettingzoo.classic import connect_four_v3
class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper):
    """Adapt a PettingZoo AEC environment to SB3's single-agent, action-masked API.

    SB3 expects ``observation_space`` / ``action_space`` to be plain attributes and
    ``reset`` / ``step`` to follow Gymnasium's signatures. This wrapper provides
    both, strips the ``"action_mask"`` entry from observations, and exposes the
    mask separately through :meth:`action_mask`.
    """

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and pin a single obs/action space for all agents.

        The spaces of the first possible agent are stored as instance attributes
        (shadowing the per-agent methods) because SB3 is designed for single-agent
        RL and does not expect the spaces to be functions.

        Returns:
            ``(observation, info)`` for the agent that moves first. ``info`` is
            always an empty dict, since the AEC ``reset`` itself returns nothing.
        """
        super().reset(seed, options)

        reference_agent = self.possible_agents[0]
        full_obs_space = super().observation_space(reference_agent)
        # Keep only the raw observation part of the space; the mask is served
        # through action_mask() instead.
        self.observation_space = full_obs_space["observation"]
        self.action_space = super().action_space(reference_agent)

        first_obs = self.observe(self.agent_selection)
        return first_obs, {}

    def step(self, action):
        """Apply ``action`` for the current agent, Gymnasium style.

        Returns:
            ``(observation, reward, termination, truncation, info)``. The
            observation belongs to the agent whose turn is next (it decides the
            following action); the other four entries describe the agent that
            just acted.
        """
        acting_agent = self.agent_selection
        super().step(action)

        upcoming_obs = self.observe(self.agent_selection)
        reward = self._cumulative_rewards[acting_agent]
        terminated = self.terminations[acting_agent]
        truncated = self.truncations[acting_agent]
        info = self.infos[acting_agent]
        return upcoming_obs, reward, terminated, truncated, info

    def observe(self, agent):
        """Return the raw observation for ``agent``, without the action mask."""
        full_obs = super().observe(agent)
        return full_obs["observation"]

    def action_mask(self):
        """Return the action mask for the agent whose turn it currently is."""
        full_obs = super().observe(self.agent_selection)
        return full_obs["action_mask"]
def mask_fn(env):
    """Return the legal-action mask for ``env``'s current state.

    Serves as the callback given to SB3's ``ActionMasker``. It delegates to the
    environment's own ``action_mask()`` method; adapt this function if your
    environment exposes its mask differently.
    """
    get_mask = env.action_mask
    return get_mask()
def train_action_mask(env_fn, steps=10_000, seed=0, **env_kwargs):
    """Train a single model to play as each agent in a zero-sum game environment using invalid action masking.

    One MaskablePPO policy is shared by every agent (self-play). The trained model
    is saved in the current directory as ``<env name>_<YYYYmmdd-HHMMSS>.zip``.

    Args:
        env_fn: PettingZoo environment module exposing ``env(**kwargs)``
            (e.g. ``connect_four_v3``).
        steps: Total number of environment timesteps to train for.
        seed: Seed for the initial environment reset and for SB3's random
            number generators (including policy-weight initialization).
        **env_kwargs: Extra keyword arguments forwarded to ``env_fn.env``.
    """
    env = env_fn.env(**env_kwargs)

    print(f"Starting training on {str(env.metadata['name'])}.")

    # Custom wrapper to convert PettingZoo envs to work with SB3 action masking
    env = SB3ActionMaskWrapper(env)

    env.reset(seed=seed)  # Must call reset() in order to re-define the spaces

    env = ActionMasker(env, mask_fn)  # Wrap to enable masking (SB3 function)
    try:
        # MaskablePPO behaves the same as SB3's PPO unless the env is wrapped
        # with ActionMasker. If the wrapper is detected, the masks are retrieved
        # automatically and used during learning (MaskablePPO takes no separate
        # action_mask_fn argument).
        # Passing ``seed`` to the constructor seeds the RNGs *before* the policy
        # network is built, so the initial weights are reproducible as well.
        # Calling set_random_seed() after construction would leave them unseeded.
        model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, seed=seed)
        model.learn(total_timesteps=steps)

        model.save(
            f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}"
        )
        print("Model has been saved.")

        print(f"Finished training on {str(env.unwrapped.metadata['name'])}.\n")
    finally:
        # Release the environment (and any render window) even if training fails.
        env.close()
def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs):
    """Evaluate the most recently saved model against a uniformly random agent.

    The random agent plays as ``possible_agents[0]`` and the trained model as
    ``possible_agents[1]``. Both choose only legal actions, using the action mask
    that comes with each observation.

    Args:
        env_fn: PettingZoo environment module exposing ``env(**kwargs)``.
        num_games: Number of games to play.
        render_mode: Forwarded to ``env_fn.env`` (e.g. ``"human"`` to watch).
        **env_kwargs: Extra keyword arguments forwarded to ``env_fn.env``.

    Returns:
        ``(round_rewards, total_rewards, winrate, scores)``, where:

        - ``round_rewards`` lists a copy of ``env.rewards`` taken at the end of
          each game.
        - ``total_rewards`` sums those rewards per agent.
        - ``scores`` accumulates the winner's reward in every game whose two
          rewards differ.
        - ``winrate`` is the trained agent's share of ``scores``, or 0 when
          nothing was scored.

    If no saved policy matching the environment name is found, prints a message
    and exits the process with status 0.
    """
    import sys

    env = env_fn.env(render_mode=render_mode, **env_kwargs)

    print(
        f"Starting evaluation vs a random agent. Trained agent will play as {env.possible_agents[1]}."
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        print("Policy not found.")
        env.close()
        # The bare ``exit`` builtin is meant for the interactive shell only;
        # programs should use sys.exit.
        sys.exit(0)

    model = MaskablePPO.load(latest_policy)

    random_agent = env.possible_agents[0]
    trained_agent = env.possible_agents[1]

    scores = {agent: 0 for agent in env.possible_agents}
    total_rewards = {agent: 0 for agent in env.possible_agents}
    round_rewards = []

    for i in range(num_games):
        env.reset(seed=i)
        env.action_space(random_agent).seed(i)

        for agent in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()

            # Separate observation and action mask by key, not by dict ordering.
            observation = obs["observation"]
            action_mask = obs["action_mask"]

            if termination or truncation:
                # If there is a winner, keep track, otherwise don't change the scores (tie)
                if env.rewards[random_agent] != env.rewards[trained_agent]:
                    winner = max(env.rewards, key=env.rewards.get)
                    # only tracks the largest reward (winner of game)
                    scores[winner] += env.rewards[winner]
                # Also track negative and positive rewards (penalizes illegal moves)
                for a in env.possible_agents:
                    total_rewards[a] += env.rewards[a]
                # List of rewards by round, for reference. Store a snapshot so later
                # resets/steps of the env cannot alter recorded results.
                round_rewards.append(dict(env.rewards))
                break

            if agent == random_agent:
                act = env.action_space(agent).sample(action_mask)
            else:
                # PettingZoo expects integer actions; predict() returns an array value.
                # TODO: change chess to cast actions to type int?
                act = int(
                    model.predict(
                        observation, action_masks=action_mask, deterministic=True
                    )[0]
                )
            env.step(act)
    env.close()

    # Avoid dividing by zero
    scored_total = sum(scores.values())
    winrate = 0 if scored_total == 0 else scores[trained_agent] / scored_total

    print("Rewards by round: ", round_rewards)
    print("Total rewards (incl. negative rewards): ", total_rewards)
    print("Winrate: ", winrate)
    print("Final scores: ", scores)
    return round_rewards, total_rewards, winrate, scores
if __name__ == "__main__":
    env_fn = connect_four_v3
    env_kwargs = {}

    # Training-length notes (evaluated against a random agent):
    #   10k steps -> winrate 0.76, loss on the order of 1e-03
    #   20k steps -> winrate 0.86, loss on the order of 1e-04
    #   40k steps -> winrate 0.86, loss on the order of 7e-06

    # Self-play training; takes roughly 20 seconds on a laptop CPU.
    train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs)

    # First 100 headless games vs a random agent (winrate should be ~80%),
    # then 2 games rendered for a human to watch.
    for games, mode in ((100, None), (2, "human")):
        eval_action_mask(env_fn, num_games=games, render_mode=mode, **env_kwargs)
Testing other PettingZoo Classic environments
The following script uses pytest to test all other PettingZoo Classic environments that support action masking.
This code yields decent results on simpler environments like Connect Four, while more difficult environments such as Chess or Hanabi will likely take much more training time and hyperparameter tuning.
"""Tests that action masking code works properly with all PettingZoo classic environments."""
import pytest
from pettingzoo.classic import (
chess_v6,
gin_rummy_v4,
go_v5,
hanabi_v5,
leduc_holdem_v4,
texas_holdem_no_limit_v6,
texas_holdem_v4,
tictactoe_v3,
)
# Skip the whole module when the SB3 libraries are not installed.
pytest.importorskip("stable_baselines3")
pytest.importorskip("sb3_contrib")

# Note: Connect Four is tested in sb3_connect_four_action_mask.py
# Note: Rock-Paper-Scissors has no action masking and does not seem to learn well playing against itself

# These environments do better than random even after the minimum number of timesteps
EASY_ENVS = [
    gin_rummy_v4,
    texas_holdem_no_limit_v6,  # texas holdem human rendered game ends instantly, but with random actions it works fine
    tictactoe_v3,
    leduc_holdem_v4,
]

# More difficult environments which will likely take more training time
MEDIUM_ENVS = [
    hanabi_v5,  # even with 10x as many steps, total score seems to always be tied between the two agents
    texas_holdem_v4,  # this performs poorly with updates to SB3 wrapper
    chess_v6,  # difficult to train because games take so long, performance varies heavily
]

# Most difficult environments to train agents for (and the longest games)
HARD_ENVS = [
    # TODO: test board_size to see if a smaller go board is more easily solvable
    go_v5,  # difficult to train because games take so long, may be another issue causing poor performance
]
@pytest.mark.parametrize("env_fn", EASY_ENVS)
def test_action_mask_easy(env_fn):
    """Self-play training on an easy environment should beat a random opponent.

    Passes if the trained agent's winrate exceeds 50%, or if its total reward
    exceeds the random agent's.
    """
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    steps = 8192 * 4

    # Train a model against itself (takes ~2 minutes on GPU)
    train_action_mask(env_fn, steps=steps, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    # Read the agent names from one env (closed right away) instead of building
    # a fresh, never-closed env for each lookup. eval_action_mask plays the
    # random agent as possible_agents[0] and the trained agent as [1].
    probe_env = env_fn.env(**env_kwargs)
    random_agent = probe_env.possible_agents[0]
    trained_agent = probe_env.possible_agents[1]
    probe_env.close()

    assert winrate > 0.5 or (
        total_rewards[trained_agent] > total_rewards[random_agent]
    ), "Trained policy should outperform random actions"

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)
# @pytest.mark.skip(
#     reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI"
# )
@pytest.mark.parametrize("env_fn", MEDIUM_ENVS)
def test_action_mask_medium(env_fn):
    """Short self-play training on a medium environment should stay below 75% winrate.

    This is a sanity bound for environments that are known to need far more
    training than the 8192 steps used here.
    """
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    # Train a model against itself
    train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    assert (
        winrate < 0.75
    ), "Policy should not perform better than 75% winrate"  # 30-40% for leduc, 0% for hanabi

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)
# @pytest.mark.skip(
#     reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI"
# )
@pytest.mark.parametrize("env_fn", HARD_ENVS)
def test_action_mask_hard(env_fn):
    """Short self-play training on a hard environment should stay below 50% winrate.

    The previous check (``winrate > 0``) contradicted both its own message and
    the recorded 0% result for go, so it would fail on the expected outcome.
    """
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    # Train a model against itself
    train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    assert (
        winrate < 0.5
    ), "Policy should not perform better than 50% winrate"  # 0% for go

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)