SB3: PPO for Waterworld
This tutorial shows how to train agents using Proximal Policy Optimization (PPO) on the Waterworld environment (Parallel).
We use SuperSuit to create vectorized environments, running them in multiple processes to speed up training (see SB3’s vector environments documentation).
After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved to and loaded from disk (see SB3’s model saving documentation).
Note
This environment has a continuous, 1-dimensional (flat vector) observation space, so we use an MLP feature extractor rather than a CNN.
Environment Setup
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.
pettingzoo[sisl]>=1.24.0
stable-baselines3>=2.0.0
supersuit>=3.9.0
pymunk
Code
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the Discord server.
Training and Evaluation
"""Uses Stable-Baselines3 to train agents to play the Waterworld environment using SuperSuit vector envs.
For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
Author: Elliot (https://github.com/elliottower)
"""
from __future__ import annotations
import glob
import os
import time
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from pettingzoo.sisl import waterworld_v4
def train_butterfly_supersuit(
    env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs
):
    """Train a single PPO model that plays as every agent of a Parallel env.

    Args:
        env_fn: PettingZoo environment module exposing ``parallel_env``
            (e.g. ``waterworld_v4``).
        steps: Total number of timesteps passed to ``model.learn``.
        seed: Seed for the initial ``env.reset`` call. Note that the PPO
            model itself is not seeded here.
        **env_kwargs: Forwarded unchanged to ``env_fn.parallel_env``.

    The trained model is saved via ``model.save`` under the name
    ``<env name>_<YYYYmmdd-HHMMSS>`` in the current working directory.
    """
    env = env_fn.parallel_env(**env_kwargs)
    env.reset(seed=seed)

    # Read the name once from the raw PettingZoo env, before SuperSuit's
    # vector wrappers are layered on top of it.
    env_name = env.metadata["name"]
    print(f"Starting training on {str(env_name)}.")

    # Expose each agent as a slot of a vector env, then concatenate 8 copies
    # (spread over 2 CPUs) into a single SB3-compatible VecEnv.
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")

    try:
        # Waterworld's observations are flat continuous vectors (shape (242,)
        # with default settings), so an MLP policy is used rather than a CNN.
        model = PPO(
            MlpPolicy,
            env,
            verbose=3,
            learning_rate=1e-3,
            batch_size=256,
        )
        model.learn(total_timesteps=steps)
        model.save(f"{env_name}_{time.strftime('%Y%m%d-%H%M%S')}")
        print("Model has been saved.")
        print(f"Finished training on {str(env_name)}.")
    finally:
        # Always shut down the vectorized env (and any worker processes),
        # even when model construction, training or saving raises.
        env.close()
def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs):
    """Evaluate the most recently saved PPO model, controlling every agent.

    Training uses the Parallel API, but evaluation uses the AEC API. SB3
    models are single-agent, so the same model picks the action for every
    agent in turn.

    Args:
        env_fn: PettingZoo environment module exposing ``env`` (AEC API).
        num_games: Number of games to play; game ``i`` is reset with seed ``i``.
        render_mode: Passed to ``env_fn.env`` (e.g. ``None`` or ``"human"``).
        **env_kwargs: Forwarded unchanged to ``env_fn.env``.

    Returns:
        The mean, over ``env.possible_agents``, of each agent's reward summed
        across all ``num_games`` games.

    If no ``<env name>*.zip`` file exists in the current directory, prints
    "Policy not found." and exits the process with status 0.
    """
    import sys

    env = env_fn.env(render_mode=render_mode, **env_kwargs)
    env_name = env.metadata["name"]
    print(
        f"\nStarting evaluation on {str(env_name)} (num_games={num_games}, render_mode={render_mode})"
    )

    rewards = {agent: 0 for agent in env.possible_agents}
    try:
        policies = glob.glob(f"{env_name}*.zip")
        if not policies:
            print("Policy not found.")
            # The finally clause below still closes the env before exiting.
            sys.exit(0)
        # Most recently created checkpoint wins.
        latest_policy = max(policies, key=os.path.getctime)
        model = PPO.load(latest_policy)

        for i in range(num_games):
            env.reset(seed=i)
            for agent in env.agent_iter():
                obs, reward, termination, truncation, info = env.last()
                # env.rewards holds the rewards produced by the previous step;
                # add them up for every live agent.
                for a in env.agents:
                    rewards[a] += env.rewards[a]
                # The game ends as soon as the acting agent is done.
                if termination or truncation:
                    break
                act = model.predict(obs, deterministic=True)[0]
                env.step(act)
    finally:
        env.close()

    avg_reward = sum(rewards.values()) / len(rewards)
    print("Rewards: ", rewards)
    print(f"Avg reward: {avg_reward}")
    return avg_reward
if __name__ == "__main__":
    # Waterworld (SISL) with its default configuration.
    env_fn = waterworld_v4
    env_kwargs = {}

    # Train one shared policy for all agents (roughly 3 minutes on a GPU).
    train_butterfly_supersuit(env_fn, steps=196_608, seed=0, **env_kwargs)

    # First a headless run over 10 games (the average reward should be
    # positive but can vary significantly), then 2 games rendered for a human.
    for n_games, mode in ((10, None), (2, "human")):
        eval(env_fn, num_games=n_games, render_mode=mode, **env_kwargs)