RLlib: DQN for Simple Poker
This tutorial shows how to train a Deep Q-Network (DQN) agent on the Leduc Hold’em environment (AEC).
After training, run the provided code to watch your trained agent play against itself. See the documentation for more information.
Environment Setup
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.
PettingZoo[classic,butterfly]>=1.24.0
Pillow>=9.4.0
ray[rllib]==2.7.0
SuperSuit>=3.9.0
torch>=1.13.1
tensorflow-probability>=0.19.0
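Code
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLlib. Save the training script as `rllib_leduc_holdem.py`: the rendering script imports `TorchMaskedActions` from that module. If you have any questions, please feel free to ask in the Discord server.
Training the RL agent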
"""Uses Ray's RLlib to train agents to play Leduc Holdem.
Author: Rohan (https://github.com/Rohan138)
"""
import os
import ray
from gymnasium.spaces import Box, Discrete
from ray import tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MAX
from ray.tune.registry import register_env
from pettingzoo.classic import leduc_holdem_v4
# Import torch and torch.nn through RLlib's helper; `torch` is used by
# TorchMaskedActions.forward below to build the action mask.
torch, nn = try_import_torch()
class TorchMaskedActions(DQNTorchModel):
    """DQN Q-model (PyTorch) that masks out illegal actions.

    The incoming observation is a dict with an ``"observation"`` entry, which
    is fed to a fully connected network, and an ``"action_mask"`` entry, which
    has 1 for legal actions and 0 for illegal ones. Outputs for illegal actions
    receive a -1e10 offset, so they are never chosen as the arg-max.
    """

    def __init__(
        self,
        obs_space: Box,
        action_space: Discrete,
        num_outputs,
        model_config,
        name,
        **kw,
    ):
        DQNTorchModel.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kw
        )

        # RLlib passes the flattened observation Box here. When it also keeps
        # the original Dict space (`original_space`), take the "observation"
        # sub-space directly. Slicing the flattened bounds is only a fallback:
        # a gymnasium Dict sorts its keys, so "action_mask" comes first in the
        # flat vector and `low[:obs_len]` would pick up the mask's bounds.
        sub_spaces = getattr(
            getattr(obs_space, "original_space", None), "spaces", None
        )
        if sub_spaces is not None and "observation" in sub_spaces:
            orig_obs_space = sub_spaces["observation"]
        else:
            obs_len = obs_space.shape[0] - action_space.n
            orig_obs_space = Box(
                shape=(obs_len,),
                low=obs_space.low[:obs_len],
                high=obs_space.high[:obs_len],
            )

        # Network that maps the raw observation to one output per action.
        self.action_embed_model = TorchFC(
            orig_obs_space,
            action_space,
            action_space.n,
            model_config,
            name + "_action_embed",
        )

    def forward(self, input_dict, state, seq_lens):
        """Return per-action outputs with illegal actions masked, plus state."""
        obs = input_dict["obs"]
        # Extract the available actions tensor from the observation.
        action_mask = obs["action_mask"]

        # Compute the predicted action embedding.
        action_logits, _ = self.action_embed_model({"obs": obs["observation"]})

        # log(1) == 0 leaves legal actions untouched; log(0) == -inf is clamped
        # to a large finite negative so the sum below stays finite.
        inf_mask = torch.clamp(torch.log(action_mask), -1e10, FLOAT_MAX)
        return action_logits + inf_mask, state

    def value_function(self):
        """Delegate to the wrapped fully connected network."""
        return self.action_embed_model.value_function()
if __name__ == "__main__":
    ray.init()

    alg_name = "DQN"
    ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)

    # function that outputs the environment you wish to register.
    def env_creator():
        env = leduc_holdem_v4.env()
        return env

    env_name = "leduc_holdem_v4"
    register_env(env_name, lambda config: PettingZooEnv(env_creator()))

    # Throwaway env used only to read the observation/action spaces shared by
    # both players. Close it once the spaces are captured so it does not linger.
    test_env = PettingZooEnv(env_creator())
    obs_space = test_env.observation_space
    act_space = test_env.action_space
    test_env.close()

    config = (
        DQNConfig()
        .environment(env=env_name)
        .rollouts(num_rollout_workers=1, rollout_fragment_length=30)
        .training(
            train_batch_size=200,
            # No extra Q-head layers and no dueling: the custom model's output
            # is used directly as the per-action Q-values.
            hiddens=[],
            dueling=False,
            model={"custom_model": "pa_model"},
        )
        .multi_agent(
            # One independent policy per seat, keyed by the agent id.
            policies={
                "player_0": (None, obs_space, act_space, {}),
                "player_1": (None, obs_space, act_space, {}),
            },
            policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        )
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .debugging(
            log_level="DEBUG"
        )  # TODO: change to ERROR to match pistonball example
        .framework(framework="torch")
        .exploration(
            exploration_config={
                # The Exploration class to use.
                "type": "EpsilonGreedy",
                # Config for the Exploration class' constructor:
                "initial_epsilon": 0.1,
                "final_epsilon": 0.0,
                "epsilon_timesteps": 100000,  # Timesteps over which to anneal epsilon.
            }
        )
    )

    tune.run(
        alg_name,
        name="DQN",
        # Short run on CI, full-length run otherwise.
        stop={"timesteps_total": 10000000 if not os.environ.get("CI") else 50000},
        checkpoint_freq=10,
        config=config.to_dict(),
    )
Watching the trained RL agent play
"""Uses Ray's RLlib to view trained agents playing Leduc Hold'em.
Author: Rohan (https://github.com/Rohan138)
"""
import argparse
import os
import numpy as np
import ray
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from rllib_leduc_holdem import TorchMaskedActions
from pettingzoo.classic import leduc_holdem_v4
# SDL's "dummy" video driver suppresses any on-screen window. Only force it on
# CI (the same CI check the training script uses) so local runs can actually
# show the game. NOTE(review): assumes the env's "human" render mode draws via
# pygame/SDL — confirm against the installed PettingZoo version.
if os.environ.get("CI"):
    os.environ["SDL_VIDEODRIVER"] = "dummy"

parser = argparse.ArgumentParser(
    description="Render pretrained policy loaded from checkpoint"
)
# `required=True` makes argparse print usage and exit with a non-zero status
# when the flag is missing (previously the script exited with status 0).
parser.add_argument(
    "--checkpoint-path",
    required=True,
    help="Path to the checkpoint. This path will likely be something like this: `~/ray_results/DQN/DQN_leduc_holdem_v4_<trial_id>/checkpoint_<N>`",
)
args = parser.parse_args()
checkpoint_path = os.path.expanduser(args.checkpoint_path)

# The checkpoint refers to the custom model by name, so register it again.
ModelCatalog.register_custom_model("pa_model", TorchMaskedActions)


# function that outputs the environment you wish to register.
def env_creator(render_mode=None):
    """Return a fresh Leduc Hold'em AEC env (optionally with a render mode)."""
    env = leduc_holdem_v4.env(render_mode=render_mode)
    return env


env_name = "leduc_holdem_v4"
# Registered env (used when restoring the algorithm) is created without rendering.
register_env(env_name, lambda config: PettingZooEnv(env_creator()))

# Separate env for watching the game; without a render mode env.render() only
# emits a warning and draws nothing.
env = env_creator(render_mode="human")

ray.init()
algo = Algorithm.from_checkpoint(checkpoint_path)

reward_sums = {a: 0 for a in env.possible_agents}
env.reset()

for agent in env.agent_iter():
    observation, reward, termination, truncation, _ = env.last()
    reward_sums[agent] += reward

    if termination or truncation:
        # AEC envs require a None action for agents that are already done.
        action = None
    else:
        # Each seat is served by the policy with the same id (see training).
        policy = algo.get_policy(agent)
        # Add a batch dimension of size 1 to both parts of the observation.
        batch_obs = {
            "obs": {
                "observation": np.expand_dims(observation["observation"], 0),
                "action_mask": np.expand_dims(observation["action_mask"], 0),
            }
        }
        # explore=False: act greedily so we watch the learned policy rather
        # than epsilon-greedy exploration.
        batched_action, _, _ = policy.compute_actions_from_input_dict(
            batch_obs, explore=False
        )
        action = batched_action[0]

    env.step(action)
    env.render()

print("rewards:")
print(reward_sums)

env.close()
ray.shutdown()