SB3: Action Masked PPO for Connect Four

This tutorial shows how to train agents using Maskable Proximal Policy Optimization (PPO) on the Connect Four environment (AEC).

It creates a custom wrapper that converts the PettingZoo environment into a Gymnasium-like environment compatible with SB3 action masking.

After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved to and loaded from disk (see SB3's documentation for more information).

Note

This environment has a discrete action space and provides an illegal action mask alongside each observation, so we use a maskable MLP policy (MaskableActorCriticPolicy).

Warning

The SB3ActionMaskWrapper wrapper assumes that the action space and observation space are the same for every agent; this assumption may not hold for custom environments.

Environment Setup

To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.

pettingzoo[classic]>=1.24.0
stable-baselines3>=2.0.0
sb3-contrib>=2.0.0

Code

The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the Discord server.

Training and Evaluation

"""Uses Stable-Baselines3 to train agents in the Connect Four environment using invalid action masking.

For information about invalid action masking in PettingZoo, see https://pettingzoo.farama.org/api/aec/#action-masking
For more information about invalid action masking in SB3, see https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html

Author: Elliot (https://github.com/elliottower)
"""
import glob
import os
import time

from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker

import pettingzoo.utils
from pettingzoo.classic import connect_four_v3


class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper):
    """Adapt a PettingZoo AEC environment to the single-agent interface used by SB3.

    After ``reset`` the wrapper exposes fixed ``observation_space`` and
    ``action_space`` attributes (taken from the first possible agent), strips
    the action mask from every observation, and offers ``action_mask()`` so the
    mask can be fetched separately (e.g. by ``ActionMasker``).
    """

    def reset(self, seed=None, options=None):
        """Reset the environment and return a Gymnasium-style ``(observation, info)`` pair.

        SB3 is single-agent and expects the spaces to be plain attributes rather
        than per-agent functions, so they are (re)assigned here from the first
        possible agent.
        """
        super().reset(seed, options)

        reference_agent = self.possible_agents[0]
        # The per-agent observation space holds an "observation" entry next to
        # the action mask; only that entry is exposed to SB3.
        full_observation_space = super().observation_space(reference_agent)
        self.observation_space = full_observation_space["observation"]
        self.action_space = super().action_space(reference_agent)

        # PettingZoo AEC reset() does not return anything by default, so build
        # the Gymnasium-style return value here.
        return self.observe(self.agent_selection), {}

    def step(self, action):
        """Play one move and return a Gymnasium-style 5-tuple.

        The observation belongs to the agent whose turn comes next (it is what
        the policy acts on), while reward, termination, truncation and info
        describe the outcome for the agent that just moved.
        """
        acting_agent = self.agent_selection
        super().step(action)

        observation = self.observe(self.agent_selection)
        reward = self._cumulative_rewards[acting_agent]
        terminated = self.terminations[acting_agent]
        truncated = self.truncations[acting_agent]
        info = self.infos[acting_agent]
        return observation, reward, terminated, truncated, info

    def observe(self, agent):
        """Return the observation for ``agent`` with the action mask removed."""
        full_observation = super().observe(agent)
        return full_observation["observation"]

    def action_mask(self):
        """Return the action mask of the agent currently selected to act."""
        full_observation = super().observe(self.agent_selection)
        return full_observation["action_mask"]


def mask_fn(env):
    """Return the legal-action mask for the environment's current agent.

    This is the callback handed to ``ActionMasker``. Customize it as needed;
    here we simply defer to the environment's own ``action_mask()`` helper.
    """
    current_mask = env.action_mask()
    return current_mask


def train_action_mask(env_fn, steps=10_000, seed=0, **env_kwargs):
    """Train a single model to play as each agent in a zero-sum game environment using invalid action masking.

    Args:
        env_fn: PettingZoo environment module providing ``env(**kwargs)``,
            e.g. ``connect_four_v3``.
        steps: Total number of timesteps passed to ``model.learn``.
        seed: Seed used for the initial environment reset and for the model.
            It is passed to the model constructor so that policy weight
            initialization is seeded as well.
        **env_kwargs: Extra keyword arguments forwarded to ``env_fn.env``.

    The model is saved in the working directory as
    ``<env name>_<YYYYmmdd-HHMMSS>``; ``eval_action_mask`` loads the newest
    matching ``.zip`` file.
    """
    env = env_fn.env(**env_kwargs)
    # Read the name once and reuse it for logging and for the save path,
    # instead of re-querying it through the wrapper stack.
    env_name = env.metadata["name"]

    print(f"Starting training on {env_name}.")

    # Custom wrapper to convert PettingZoo envs to work with SB3 action masking
    env = SB3ActionMaskWrapper(env)

    env.reset(seed=seed)  # Must call reset() in order to (re)define the spaces

    env = ActionMasker(env, mask_fn)  # Wrap to enable masking (sb3-contrib wrapper)
    try:
        # MaskablePPO behaves like SB3's PPO unless the env is wrapped with
        # ActionMasker; when that wrapper is detected, the masks are retrieved
        # automatically during learning. The seed goes to the constructor so it
        # is applied before the policy network is built; calling
        # set_random_seed() afterwards would leave weight initialization unseeded.
        model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, seed=seed)
        model.learn(total_timesteps=steps)

        model.save(f"{env_name}_{time.strftime('%Y%m%d-%H%M%S')}")

        print("Model has been saved.")

        print(f"Finished training on {env_name}.\n")
    finally:
        # Release the environment even if training or saving fails.
        env.close()


def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs):
    """Evaluate the most recently saved model against a random (masked) agent.

    The random agent plays as ``possible_agents[0]`` and samples only legal
    actions; the trained model plays as ``possible_agents[1]``.

    Args:
        env_fn: PettingZoo environment module providing ``env(**kwargs)``.
        num_games: Number of games to play; game ``i`` is seeded with ``i``.
        render_mode: Passed to ``env_fn.env`` (e.g. ``"human"`` to watch).
        **env_kwargs: Extra keyword arguments forwarded to ``env_fn.env``.

    Returns:
        ``(round_rewards, total_rewards, winrate, scores)``. ``round_rewards``
        is a list with one rewards dict per finished game. ``total_rewards``
        sums every agent's final rewards, including negative ones. ``winrate``
        is the trained agent's share of the winners' scores (0 if every game
        was a tie). ``scores`` sums the winning reward per agent.

    Raises:
        SystemExit: with status 1 when no saved policy matching the
            environment name is found.
    """
    env = env_fn.env(render_mode=render_mode, **env_kwargs)
    random_agent = env.possible_agents[0]
    trained_agent = env.possible_agents[1]

    print(
        f"Starting evaluation vs a random agent. Trained agent will play as {env.possible_agents[1]}."
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        # max() raises ValueError on an empty sequence, i.e. no saved model.
        print("Policy not found.")
        env.close()
        # A missing policy is a failure, so exit with a non-zero status.
        raise SystemExit(1) from None

    model = MaskablePPO.load(latest_policy)

    scores = {agent: 0 for agent in env.possible_agents}
    total_rewards = {agent: 0 for agent in env.possible_agents}
    round_rewards = []

    try:
        for i in range(num_games):
            env.reset(seed=i)
            env.action_space(random_agent).seed(i)

            for agent in env.agent_iter():
                obs, reward, termination, truncation, info = env.last()

                # Separate observation and action mask by key rather than
                # relying on the dict's insertion order.
                observation = obs["observation"]
                action_mask = obs["action_mask"]

                if termination or truncation:
                    rewards = env.rewards
                    # If there is a winner, keep track, otherwise don't change the scores (tie)
                    if rewards[random_agent] != rewards[trained_agent]:
                        winner = max(rewards, key=rewards.get)
                        # only tracks the largest reward (winner of game)
                        scores[winner] += rewards[winner]
                    # Also track negative and positive rewards (penalizes illegal moves)
                    for a in env.possible_agents:
                        total_rewards[a] += rewards[a]
                    # Store a snapshot: env.rewards is the environment's own
                    # dict, which it may mutate or rebind later.
                    round_rewards.append(dict(rewards))
                    break

                if agent == random_agent:
                    act = env.action_space(agent).sample(action_mask)
                else:
                    # PettingZoo expects integer actions, so cast the action
                    # returned by predict() to a plain int.
                    act = int(
                        model.predict(
                            observation, action_masks=action_mask, deterministic=True
                        )[0]
                    )
                env.step(act)
    finally:
        # Close the environment (and any render window) even if a game fails.
        env.close()

    total_score = sum(scores.values())
    # Avoid dividing by zero
    if total_score == 0:
        winrate = 0
    else:
        winrate = scores[trained_agent] / total_score
    print("Rewards by round: ", round_rewards)
    print("Total rewards (incl. negative rewards): ", total_rewards)
    print("Winrate: ", winrate)
    print("Final scores: ", scores)
    return round_rewards, total_rewards, winrate, scores


if __name__ == "__main__":
    env_fn = connect_four_v3
    env_kwargs = dict()

    # Evaluation/training hyperparameter notes:
    # 10k steps: Winrate:  0.76, loss order of 1e-03
    # 20k steps: Winrate:  0.86, loss order of 1e-04
    # 40k steps: Winrate:  0.86, loss order of 7e-06

    # Self-play training run (about 20 seconds on a laptop CPU)
    train_action_mask(env_fn, steps=20_480, seed=0, **env_kwargs)

    # First a 100-game headless evaluation against a random agent (winrate
    # should be ~80%), then two games rendered so a human can watch.
    for num_games, render_mode in ((100, None), (2, "human")):
        eval_action_mask(
            env_fn, num_games=num_games, render_mode=render_mode, **env_kwargs
        )

Testing other PettingZoo Classic environments

The following script uses pytest to test the other PettingZoo Classic environments that support action masking.

This code yields decent results on simpler environments like Connect Four, while more difficult environments such as Chess or Hanabi will likely take much more training time and hyperparameter tuning.

"""Tests that action masking code works properly with all PettingZoo classic environments."""

import pytest

from pettingzoo.classic import (
    chess_v6,
    gin_rummy_v4,
    go_v5,
    hanabi_v5,
    leduc_holdem_v4,
    texas_holdem_no_limit_v6,
    texas_holdem_v4,
    tictactoe_v3,
)

# Skip this whole test module when the SB3 packages are not installed.
pytest.importorskip("stable_baselines3")
pytest.importorskip("sb3_contrib")

# Note: Connect Four is tested in sb3_connect_four_action_mask.py
# Note: Rock-Paper-Scissors has no action masking and does not seem to learn well playing against itself

# Each list below parametrizes one of the test functions in this module.

# These environments do better than random even after the minimum number of timesteps
EASY_ENVS = [
    gin_rummy_v4,
    texas_holdem_no_limit_v6,  # texas holdem human rendered game ends instantly, but with random actions it works fine
    tictactoe_v3,
    leduc_holdem_v4,
]

# More difficult environments which will likely take more training time
MEDIUM_ENVS = [
    hanabi_v5,  # even with 10x as many steps, total score seems to always be tied between the two agents
    texas_holdem_v4,  # this performs poorly with updates to SB3 wrapper
    chess_v6,  # difficult to train because games take so long, performance varies heavily
]

# Most difficult environments to train agents for (and longest games)
# TODO: test board_size to see if smaller go board is more easily solvable
HARD_ENVS = [
    go_v5,  # difficult to train because games take so long, may be another issue causing poor performance
]


@pytest.mark.parametrize("env_fn", EASY_ENVS)
def test_action_mask_easy(env_fn):
    """Self-play training on an easy environment should beat a random agent.

    Passes if the trained agent wins more than half of the decided games, or
    otherwise accumulates more total reward than the random agent.
    """
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    steps = 8192 * 4

    # Train a model against itself (takes ~2 minutes on GPU)
    train_action_mask(env_fn, steps=steps, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    # Build one throwaway env only to read the agent names, then close it.
    # eval_action_mask lets agent 0 act randomly and agent 1 use the model.
    probe_env = env_fn.env(**env_kwargs)
    random_agent = probe_env.possible_agents[0]
    trained_agent = probe_env.possible_agents[1]
    probe_env.close()

    assert winrate > 0.5 or (
        total_rewards[trained_agent] > total_rewards[random_agent]
    ), "Trained policy should outperform random actions"

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)


# @pytest.mark.skip(
#     reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI"
# )
@pytest.mark.parametrize("env_fn", MEDIUM_ENVS)
def test_action_mask_medium(env_fn):
    """Smoke-test training and evaluation on a medium-difficulty environment.

    The training budget is short, so the policy is not expected to be strong;
    the check only caps the winrate against a random agent at 75%.
    """
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    kwargs = dict()

    # Short self-play training run
    train_action_mask(env_fn, steps=8192, seed=0, **kwargs)

    # Play 100 evaluation games against a random agent
    _, _, winrate, _ = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **kwargs
    )

    # Observed so far: 30-40% for leduc, 0% for hanabi
    assert winrate < 0.75, "Policy should not perform better than 75% winrate"

    # To watch two rendered games, uncomment:
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **kwargs)


# @pytest.mark.skip(
#     reason="training is compute intensive and hyperparameters have not been tuned, disabled for CI"
# )
@pytest.mark.parametrize("env_fn", HARD_ENVS)
def test_action_mask_hard(env_fn):
    """Smoke-test training and evaluation on the hardest environments.

    Training is short and hyperparameters are untuned, so the policy is only
    checked to stay below a 50% winrate against a random agent (go has been
    observed at 0%).
    """
    from tutorials.SB3.connect_four.sb3_connect_four_action_mask import (
        eval_action_mask,
        train_action_mask,
    )

    env_kwargs = {}

    # Train a model against itself
    train_action_mask(env_fn, steps=8192, seed=0, **env_kwargs)

    # Evaluate 100 games against a random agent
    round_rewards, total_rewards, winrate, scores = eval_action_mask(
        env_fn, num_games=100, render_mode=None, **env_kwargs
    )

    # Upper bound that matches the assertion message; a lower bound such as
    # `winrate > 0` would fail for go, which wins 0% of games at this budget.
    assert (
        winrate < 0.5
    ), "Policy should not perform better than 50% winrate"  # 0% for go

    # Watch two games (disabled by default)
    # eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)