SB3: PPO for Knights-Archers-Zombies
This tutorial shows how to train agents using Proximal Policy Optimization (PPO) on the Knights-Archers-Zombies environment (AEC).
We use SuperSuit to create vectorized environments, leveraging multithreading to speed up training (see SB3’s vector environments documentation).
After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3’s model saving documentation).
If the observation space is visual (`vector_state=False` in `env_kwargs`), we pre-process using color reduction, resizing, and frame stacking, and use a CNN policy.
Note
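When configured with visual observations (`vector_state=False`), this environment has a 3-dimensional observation space, so we use a CNN feature extractor; with the default vector observations, an MLP policy is used instead.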
Note
This environment allows agents to spawn and die, so it requires using SuperSuit’s Black Death wrapper, which provides blank observations to dead agents rather than removing them from the environment.
Environment Setup
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.
pettingzoo[butterfly]>=1.24.0
stable-baselines3>=2.0.0
supersuit>=3.9.0
Code
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the Discord server.
Training and Evaluation
"""Uses Stable-Baselines3 to train agents in the Knights-Archers-Zombies environment using SuperSuit vector envs.
This environment requires using SuperSuit's Black Death wrapper, to handle agent death.
For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
Author: Elliot (https://github.com/elliottower)
"""
from __future__ import annotations
import glob
import os
import time
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy, MlpPolicy
from pettingzoo.butterfly import knights_archers_zombies_v10
def train(env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs):
    """Train a single PPO model shared by every agent of a PettingZoo environment.

    The parallel version of the environment is wrapped with SuperSuit's Black
    Death wrapper (so the number of agents stays constant). If observations are
    visual, they are pre-processed. The result is converted into 8 concatenated
    SB3-compatible vector environments. The trained model is saved in the
    current working directory as ``<env name>_<YYYYmmdd-HHMMSS>``. SB3 adds the
    ``.zip`` extension, which is what ``eval`` globs for.

    Args:
        env_fn: PettingZoo environment module exposing ``parallel_env``.
        steps: Total number of timesteps passed to ``model.learn``.
        seed: Seed used for the initial ``reset`` of the environment.
        **env_kwargs: Keyword arguments forwarded to ``env_fn.parallel_env``.
    """
    env = env_fn.parallel_env(**env_kwargs)

    # Add black death wrapper so the number of agents stays constant.
    # MarkovVectorEnv does not support environments with varying numbers of
    # active agents unless black_death is set to True.
    env = ss.black_death_v3(env)

    # Pre-process using SuperSuit
    visual_observation = not env.unwrapped.vector_state
    if visual_observation:
        # If the observation space is visual, reduce the color channels,
        # resize from 512px to 84px, and apply frame stacking.
        env = ss.color_reduction_v0(env, mode="B")
        env = ss.resize_v1(env, x_size=84, y_size=84)
        env = ss.frame_stack_v1(env, 3)

    env.reset(seed=seed)

    # Read the name once, while the PettingZoo metadata is still directly
    # accessible. After vectorization it is only reachable through wrappers.
    env_name = str(env.metadata["name"])
    print(f"Starting training on {env_name}.")

    env = ss.pettingzoo_env_to_vec_env_v1(env)
    env = ss.concat_vec_envs_v1(env, 8, num_cpus=1, base_class="stable_baselines3")

    try:
        # Use a CNN policy if the observation space is visual
        model = PPO(
            CnnPolicy if visual_observation else MlpPolicy,
            env,
            verbose=3,
            batch_size=256,
        )

        model.learn(total_timesteps=steps)

        model.save(f"{env_name}_{time.strftime('%Y%m%d-%H%M%S')}")

        print("Model has been saved.")

        print(f"Finished training on {env_name}.")
    finally:
        # Release the vectorized environments even if training or saving fails
        env.close()
def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs):
    """Evaluate the most recently saved PPO model against a random agent.

    Games are played in the AEC version of the environment. The first agent in
    ``possible_agents`` acts uniformly at random. Every other agent is controlled
    by the newest (by ctime) ``<env name>*.zip`` model in the current working
    directory.

    Args:
        env_fn: PettingZoo environment module exposing ``env``.
        num_games: Number of games to play. Game ``i`` is reset with seed ``i``.
        render_mode: Render mode forwarded to ``env_fn.env`` (e.g. ``"human"``).
        **env_kwargs: Keyword arguments forwarded to ``env_fn.env``.

    Returns:
        The reward accumulated over all games, averaged over agents. This value
        is not divided by ``num_games``; the per-game figures are printed
        separately.

    Raises:
        SystemExit: With status 0 if no saved model is found.
    """
    env = env_fn.env(render_mode=render_mode, **env_kwargs)

    # Pre-process using SuperSuit. This must mirror the pre-processing used
    # during training so observations match the model's input space.
    visual_observation = not env.unwrapped.vector_state
    if visual_observation:
        # If the observation space is visual, reduce the color channels,
        # resize from 512px to 84px, and apply frame stacking.
        env = ss.color_reduction_v0(env, mode="B")
        env = ss.resize_v1(env, x_size=84, y_size=84)
        env = ss.frame_stack_v1(env, 3)

    print(
        f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
    )

    candidates = glob.glob(f"{env.metadata['name']}*.zip")
    if not candidates:
        print("Policy not found.")
        # Close the environment (which may own a render window) before leaving.
        # The ``exit`` builtin is a site-module helper for the interactive
        # shell, so raise SystemExit directly, keeping the original status 0.
        env.close()
        raise SystemExit(0)
    latest_policy = max(candidates, key=os.path.getctime)

    model = PPO.load(latest_policy)

    rewards = {agent: 0 for agent in env.possible_agents}

    # Note: we evaluate here using an AEC environment, to allow for easy A/B testing against random policies.
    # For example, we can see here that using a random agent for archer_0 results in fewer points than the trained agent.
    for i in range(num_games):
        env.reset(seed=i)
        env.action_space(env.possible_agents[0]).seed(i)

        for agent in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()

            for a in env.agents:
                rewards[a] += env.rewards[a]

            if termination or truncation:
                # This game stops as soon as the acting agent is done
                break

            if agent == env.possible_agents[0]:
                act = env.action_space(agent).sample()
            else:
                act = model.predict(obs, deterministic=True)[0]
            env.step(act)
    env.close()

    avg_reward = sum(rewards.values()) / len(rewards.values())
    avg_reward_per_agent = {
        agent: rewards[agent] / num_games for agent in env.possible_agents
    }
    print(f"Avg reward: {avg_reward}")
    print("Avg reward per agent, per game: ", avg_reward_per_agent)
    print("Full rewards: ", rewards)
    return avg_reward
if __name__ == "__main__":
    env_fn = knights_archers_zombies_v10

    # vector_state=True selects vector observations. Setting it to False
    # switches to visual observations (significantly longer training time).
    env_kwargs = {"max_cycles": 100, "max_zombies": 4, "vector_state": True}

    # Train a model (takes ~5 minutes on a laptop CPU)
    train(env_fn, steps=81_920, seed=0, **env_kwargs)

    # First evaluate 10 games without rendering, then watch 2 games rendered
    # for a human (each run takes ~10 seconds on a laptop CPU).
    for games, mode in ((10, None), (2, "human")):
        eval(env_fn, num_games=games, render_mode=mode, **env_kwargs)