AgileRL: Implementing DQN - Curriculum Learning and Self-play#

../../../_images/connect_four_self_opp.gif

Fig1: Agent trained to play Connect Four through self-play#

This tutorial shows how to train a DQN agent on the connect four classic environment.

This tutorial focuses on two techniques used in reinforcement learning - curriculum learning and self-play. Curriculum learning refers to training an agent on tasks of increasing difficulty in separate ‘lessons’. Imagine you were trying to become a chess world champion. You would not decide to learn to play chess by immediately taking on a grand master - it would be too difficult. Instead, you would practice against people of the same ability as you, improve slowly, and increasingly play against harder opponents until you were ready to compete with the best. The same concept applies to reinforcement learning models. Sometimes, tasks are too difficult to learn in one go, and so we must create a curriculum to guide an agent and teach it to solve our ultimate hard environment.

This tutorial also uses self-play. Self-play is a technique used in competitive reinforcement learning environments. An agent trains by playing against a copy of itself - the opponent - and learns to beat this opponent. The opponent is then updated to a copy of this better version of the agent, and the agent must then learn to beat itself again. This is done repeatedly, and the agent iteratively improves by exploiting its own weaknesses and discovering new strategies.

In this tutorial, self-play is treated as the final lesson in the curriculum. However, these two techniques can be used independently of each other, and with unlimited resources, self-play can beat agents trained with human-crafted lessons through curriculum learning. The Bitter Lesson by Richard Sutton provides an interesting take on curriculum learning and is definitely worth consideration from any engineer undertaking such a task. However, unlike Sutton, we do not all have the resources available to us that Deepmind and top institutions provide, and so one must be pragmatic when deciding how they will solve their own reinforcement learning problem. If you would like to discuss this exciting area of research further, please join the AgileRL Discord server and let us know what you think!

What is DQN?#

DQN (Deep Q-Network) is an extension of Q-learning that makes use of a replay buffer and target network to improve learning stability. For further information on DQN, check out the AgileRL documentation.

Can I use it?#

Action Space

Observation Space

Discrete

✔️

✔️

Continuous

✔️

Environment Setup#

To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.

agilerl>=0.1.16
pettingzoo[classic,atari,mpe]>=1.23.1
SuperSuit>=3.9.0
torch>=2.0.1
numpy>=1.24.2
tqdm>=4.65.0
fastrand==1.3.0
gymnasium>=0.28.1
imageio>=2.31.1
Pillow>=9.5.0
PyYAML>=5.4.1
wandb>=0.13.10

Code#

Curriculum learning and self-play using DQN on Connect Four#

The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with AgileRL. If you have any questions, please feel free to ask in the Discord server.

This is a complicated tutorial, and so we will go through it in stages. The full code can be found at the end of this section. Although much of this tutorial contains content specific to the Connect Four environment, it serves to demonstrate how techniques can be applied more generally to other problems.

Imports#

Importing the following packages, functions and classes will enable us to run the tutorial.

Imports
import copy
import os
import random
from collections import deque
from datetime import datetime

import numpy as np
import torch
import wandb
import yaml
from agilerl.components.replay_buffer import ReplayBuffer
from agilerl.hpo.mutation import Mutations
from agilerl.hpo.tournament import TournamentSelection
from agilerl.utils.utils import initialPopulation
from tqdm import tqdm, trange

from pettingzoo.classic import connect_four_v3

Curriculum Learning#

First, we need to set up and modify our environment to enable curriculum learning. Curriculum learning is enabled by changing the environment that the agent trains in. This can be implemented by changing what happens when certain actions are taken - altering the next observation returned by the environment, or more simply by altering the reward. First, we will change the reward. By default, Connect Four uses the following rewards:

  • Win = +1

  • Lose = -1

  • Play continues = 0

To help guide our agent, we can introduce rewards for other outcomes in the environment, such as a small reward for placing 3 pieces in a row, or a small negative reward when the opponent manages the same feat. We can also use reward shaping to encourage our agent to explore more. In Connect Four, if playing against a random opponent, an easy way to win is to always play in the same column. An agent may find success doing this, and therefore not learn other, more sophisticated strategies that can help it win against better opponents. We may therefore elect to reward vertical wins slightly less than horizontal or diagonal wins, to encourage the agent to try winning in different ways. An example reward system could be defined as follows:

  • Win (horizontal or diagonal) = +1

  • Win (vertical) = +0.8

  • Three in a row = +0.05

  • Opponent three in a row = -0.05

  • Lose = -1

  • Play continues = 0

Config files#

It is best to use YAML config files to define the lessons in our curriculum and easily change and keep track of our settings. The first three lessons in our curriculum can be defined as follows:

Lesson 1
---
# Connect Four Lesson 1
# Train against random agent: 'random', weak opponent: 'weak', strong opponent: 'strong', or use self-play: 'self'
opponent: random
opponent_pool_size:       # Size of opponent pool for self-play
opponent_upgrade:       # Epoch frequency to update opponent pool
eval_opponent:       # 'random', 'weak' or 'strong'
pretrained_path:       # Path to pretrained model weights
save_path: models/DQN/lesson1_trained_agent.pt  # Path to save trained model
max_train_episodes: 0  # Maximum number of training episodes in environment

## Game specific:
buffer_warm_up: true  # Fill replay buffer with random experiences
warm_up_opponent: random  # Difficulty level of warm up experiences
agent_warm_up: 30000  # Number of epochs to warm up agent by training on random experiences
block_vert_coef: 4  # How many times more likely to block vertically
rewards:  # Rewards for different outcomes
    win: 1
    vertical_win: 0.7
    three_in_row: 0.05
    opp_three_in_row: -0.05
    lose: -1
    play_continues: 0
Lesson 2
---
# Connect Four Lesson 2
# Train against random agent: 'random', weak opponent: 'weak', strong opponent: 'strong', or use self-play: 'self'
opponent: weak
opponent_pool_size:       # Size of opponent pool for self-play
opponent_upgrade:       # Epoch frequency to update opponent pool
eval_opponent: weak  # 'random', 'weak' or 'strong'
pretrained_path: models/DQN/lesson1_trained_agent.pt  # Path to pretrained model weights
save_path: models/DQN/lesson2_trained_agent.pt  # Path to save trained model
max_train_episodes: 100000  # Maximum number of training episodes in environment

## Game specific:
buffer_warm_up: false  # Fill replay buffer with random experiences
warm_up_opponent:       # Difficulty level of warm up experiences
agent_warm_up: 0  # Number of epochs to warm up agent by training on random experiences
block_vert_coef: 1  # How many times more likely to block vertically
rewards:  # Rewards for different outcomes
    win: 1
    vertical_win: 1
    three_in_row: 0.02
    opp_three_in_row: -0.02
    lose: -1
    play_continues: 0
Lesson 3
---
# Connect Four Lesson 3
# Train against random agent: 'random', weak opponent: 'weak', strong opponent: 'strong', or use self-play: 'self'
opponent: strong
opponent_pool_size:       # Size of opponent pool for self-play
opponent_upgrade:       # Epoch frequency to update opponent pool
eval_opponent: strong  # 'random', 'weak' or 'strong'
pretrained_path: models/DQN/lesson2_trained_agent.pt   # Path to pretrained model weights
save_path: models/DQN/lesson3_trained_agent.pt  # Path to save trained model
max_train_episodes: 200000  # Maximum number of training episodes in environment

## Game specific:
buffer_warm_up: false  # Fill replay buffer with random experiences
warm_up_opponent:  # Difficulty level of warm up experiences
agent_warm_up: 0  # Number of epochs to warm up agent by training on random experiences
block_vert_coef: 1  # How many times more likely to block vertically
rewards:  # Rewards for different outcomes
    win: 1
    vertical_win: 1
    three_in_row: 0.02
    opp_three_in_row: -0.02
    lose: -1
    play_continues: 0

To implement our curriculum, we create a CurriculumEnv class that acts as a wrapper on top of our Connect Four environment and enables us to alter the reward to guide the training of our agent. This uses the configs that we set up to define the lesson.

CurriculumEnv
class CurriculumEnv:
   """Wrapper around environment to modify reward for curriculum learning.

   :param env: Environment to learn in
   :type env: PettingZoo-style environment
   :param lesson: Lesson settings for curriculum learning
   :type lesson: dict
   """

   def __init__(self, env, lesson):
      self.env = env
      self.lesson = lesson

   def fill_replay_buffer(self, memory, opponent):
      """Fill the replay buffer with experiences collected by taking random actions in the environment.

      :param memory: Experience replay buffer
      :type memory: AgileRL experience replay buffer
      """
      print("Filling replay buffer ...")

      pbar = tqdm(total=memory.memory_size)
      while len(memory) < memory.memory_size:
         # Randomly decide whether random player will go first or second
         if random.random() > 0.5:
               opponent_first = False
         else:
               opponent_first = True

         mem_full = len(memory)
         self.reset()  # Reset environment at start of episode
         observation, reward, done, truncation, _ = self.last()

         (
               p1_state,
               p1_state_flipped,
               p1_action,
               p1_next_state,
               p1_next_state_flipped,
         ) = (None, None, None, None, None)
         done, truncation = False, False

         while not (done or truncation):
               # Player 0's turn
               p0_action_mask = observation["action_mask"]
               p0_state = np.moveaxis(observation["observation"], [-1], [-3])
               p0_state_flipped = np.expand_dims(np.flip(p0_state, 2), 0)
               p0_state = np.expand_dims(p0_state, 0)
               if opponent_first:
                  p0_action = self.env.action_space("player_0").sample(p0_action_mask)
               else:
                  if self.lesson["warm_up_opponent"] == "random":
                     p0_action = opponent.getAction(
                           p0_action_mask, p1_action, self.lesson["block_vert_coef"]
                     )
                  else:
                     p0_action = opponent.getAction(player=0)
               self.step(p0_action)  # Act in environment
               observation, env_reward, done, truncation, _ = self.last()
               p0_next_state = np.moveaxis(observation["observation"], [-1], [-3])
               p0_next_state_flipped = np.expand_dims(np.flip(p0_next_state, 2), 0)
               p0_next_state = np.expand_dims(p0_next_state, 0)

               if done or truncation:
                  reward = self.reward(done=True, player=0)
                  memory.save2memoryVectEnvs(
                     np.concatenate(
                           (p0_state, p1_state, p0_state_flipped, p1_state_flipped)
                     ),
                     [p0_action, p1_action, 6 - p0_action, 6 - p1_action],
                     [
                           reward,
                           LESSON["rewards"]["lose"],
                           reward,
                           LESSON["rewards"]["lose"],
                     ],
                     np.concatenate(
                           (
                              p0_next_state,
                              p1_next_state,
                              p0_next_state_flipped,
                              p1_next_state_flipped,
                           )
                     ),
                     [done, done, done, done],
                  )
               else:  # Play continues
                  if p1_state is not None:
                     reward = self.reward(done=False, player=1)
                     memory.save2memoryVectEnvs(
                           np.concatenate((p1_state, p1_state_flipped)),
                           [p1_action, 6 - p1_action],
                           [reward, reward],
                           np.concatenate((p1_next_state, p1_next_state_flipped)),
                           [done, done],
                     )

                  # Player 1's turn
                  p1_action_mask = observation["action_mask"]
                  p1_state = np.moveaxis(observation["observation"], [-1], [-3])
                  p1_state[[0, 1], :, :] = p1_state[[0, 1], :, :]
                  p1_state_flipped = np.expand_dims(np.flip(p1_state, 2), 0)
                  p1_state = np.expand_dims(p1_state, 0)
                  if not opponent_first:
                     p1_action = self.env.action_space("player_1").sample(
                           p1_action_mask
                     )
                  else:
                     if self.lesson["warm_up_opponent"] == "random":
                           p1_action = opponent.getAction(
                              p1_action_mask, p0_action, LESSON["block_vert_coef"]
                           )
                     else:
                           p1_action = opponent.getAction(player=1)
                  self.step(p1_action)  # Act in environment
                  observation, env_reward, done, truncation, _ = self.last()
                  p1_next_state = np.moveaxis(observation["observation"], [-1], [-3])
                  p1_next_state[[0, 1], :, :] = p1_next_state[[0, 1], :, :]
                  p1_next_state_flipped = np.expand_dims(np.flip(p1_next_state, 2), 0)
                  p1_next_state = np.expand_dims(p1_next_state, 0)

                  if done or truncation:
                     reward = self.reward(done=True, player=1)
                     memory.save2memoryVectEnvs(
                           np.concatenate(
                              (p0_state, p1_state, p0_state_flipped, p1_state_flipped)
                           ),
                           [p0_action, p1_action, 6 - p0_action, 6 - p1_action],
                           [
                              LESSON["rewards"]["lose"],
                              reward,
                              LESSON["rewards"]["lose"],
                              reward,
                           ],
                           np.concatenate(
                              (
                                 p0_next_state,
                                 p1_next_state,
                                 p0_next_state_flipped,
                                 p1_next_state_flipped,
                              )
                           ),
                           [done, done, done, done],
                     )

                  else:  # Play continues
                     reward = self.reward(done=False, player=0)
                     memory.save2memoryVectEnvs(
                           np.concatenate((p0_state, p0_state_flipped)),
                           [p0_action, 6 - p0_action],
                           [reward, reward],
                           np.concatenate((p0_next_state, p0_next_state_flipped)),
                           [done, done],
                     )

         pbar.update(len(memory) - mem_full)
      pbar.close()
      print("Replay buffer warmed up.")
      return memory

   def check_winnable(self, lst, piece):
      """Checks if four pieces in a row represent a winnable opportunity, e.g. [1, 1, 1, 0] or [2, 0, 2, 2].

      :param lst: List of pieces in row
      :type lst: List
      :param piece: Player piece we are checking (1 or 2)
      :type piece: int
      """
      return lst.count(piece) == 3 and lst.count(0) == 1

   def check_vertical_win(self, player):
      """Checks if a win is vertical.

      :param player: Player who we are checking, 0 or 1
      :type player: int
      """
      board = np.array(self.env.env.board).reshape(6, 7)
      piece = player + 1

      column_count = 7
      row_count = 6

      # Check vertical locations for win
      for c in range(column_count):
         for r in range(row_count - 3):
               if (
                  board[r][c] == piece
                  and board[r + 1][c] == piece
                  and board[r + 2][c] == piece
                  and board[r + 3][c] == piece
               ):
                  return True
      return False

   def check_three_in_row(self, player):
      """Checks if there are three pieces in a row and a blank space next, or two pieces - blank - piece.

      :param player: Player who we are checking, 0 or 1
      :type player: int
      """
      board = np.array(self.env.env.board).reshape(6, 7)
      piece = player + 1

      # Check horizontal locations
      column_count = 7
      row_count = 6
      three_in_row_count = 0

      # Check vertical locations
      for c in range(column_count):
         for r in range(row_count - 3):
               if self.check_winnable(board[r : r + 4, c].tolist(), piece):
                  three_in_row_count += 1

      # Check horizontal locations
      for r in range(row_count):
         for c in range(column_count - 3):
               if self.check_winnable(board[r, c : c + 4].tolist(), piece):
                  three_in_row_count += 1

      # Check positively sloped diagonals
      for c in range(column_count - 3):
         for r in range(row_count - 3):
               if self.check_winnable(
                  [
                     board[r, c],
                     board[r + 1, c + 1],
                     board[r + 2, c + 2],
                     board[r + 3, c + 3],
                  ],
                  piece,
               ):
                  three_in_row_count += 1

      # Check negatively sloped diagonals
      for c in range(column_count - 3):
         for r in range(3, row_count):
               if self.check_winnable(
                  [
                     board[r, c],
                     board[r - 1, c + 1],
                     board[r - 2, c + 2],
                     board[r - 3, c + 3],
                  ],
                  piece,
               ):
                  three_in_row_count += 1

      return three_in_row_count

   def reward(self, done, player):
      """Processes and returns reward from environment according to lesson criteria.

      :param done: Environment has terminated
      :type done: bool
      :param player: Player who we are checking, 0 or 1
      :type player: int
      """
      if done:
         reward = (
               self.lesson["rewards"]["vertical_win"]
               if self.check_vertical_win(player)
               else self.lesson["rewards"]["win"]
         )
      else:
         agent_three_count = self.check_three_in_row(1 - player)
         opp_three_count = self.check_three_in_row(player)
         if (agent_three_count + opp_three_count) == 0:
               reward = self.lesson["rewards"]["play_continues"]
         else:
               reward = (
                  self.lesson["rewards"]["three_in_row"] * agent_three_count
                  + self.lesson["rewards"]["opp_three_in_row"] * opp_three_count
               )
      return reward

   def last(self):
      """Wrapper around PettingZoo env last method."""
      return self.env.last()

   def step(self, action):
      """Wrapper around PettingZoo env step method."""
      self.env.step(action)

   def reset(self):
      """Wrapper around PettingZoo env reset method."""
      self.env.reset()

When defining the different lessons in our curriculum, we can increase the difficulty of our task by modifying environment observations for our agent - in Connect Four, we can increase the skill level of our opponent. By progressively doing this, we can help our agent improve. We can change our rewards between lessons too; for example, we may wish to reward wins in all directions equally once we have learned to beat a random agent and now wish to train against a harder opponent. In this tutorial, an Opponent class is implemented to provide different levels of difficulty for training our agent.

Opponent
class Opponent:
   """Connect 4 opponent to train and/or evaluate against.

   :param env: Environment to learn in
   :type env: PettingZoo-style environment
   :param difficulty: Difficulty level of opponent, 'random', 'weak' or 'strong'
   :type difficulty: str
   """

   def __init__(self, env, difficulty):
      self.env = env.env
      self.difficulty = difficulty
      if self.difficulty == "random":
         self.getAction = self.random_opponent
      elif self.difficulty == "weak":
         self.getAction = self.weak_rule_based_opponent
      else:
         self.getAction = self.strong_rule_based_opponent
      self.num_cols = 7
      self.num_rows = 6
      self.length = 4
      self.top = [0] * self.num_cols

   def update_top(self):
      """Updates self.top, a list which tracks the row on top of the highest piece in each column."""
      board = np.array(self.env.env.board).reshape(self.num_rows, self.num_cols)
      non_zeros = np.where(board != 0)
      rows, cols = non_zeros
      top = np.zeros(board.shape[1], dtype=int)
      for col in range(board.shape[1]):
         column_pieces = rows[cols == col]
         if len(column_pieces) > 0:
               top[col] = np.min(column_pieces) - 1
         else:
               top[col] = 5
      full_columns = np.all(board != 0, axis=0)
      top[full_columns] = 6
      self.top = top

   def random_opponent(self, action_mask, last_opp_move=None, block_vert_coef=1):
      """Takes move for random opponent. If the lesson aims to randomly block vertical wins with a higher probability, this is done here too.

      :param action_mask: Mask of legal actions: 1=legal, 0=illegal
      :type action_mask: List
      :param last_opp_move: Most recent action taken by agent against this opponent
      :type last_opp_move: int
      :param block_vert_coef: How many times more likely to block vertically
      :type block_vert_coef: float
      """
      if last_opp_move is not None:
         action_mask[last_opp_move] *= block_vert_coef
      action = random.choices(list(range(self.num_cols)), action_mask)[0]
      return action

   def weak_rule_based_opponent(self, player):
      """Takes move for weak rule-based opponent.

      :param player: Player who we are checking, 0 or 1
      :type player: int
      """
      self.update_top()
      max_length = -1
      best_actions = []
      for action in range(self.num_cols):
         possible, reward, ended, lengths = self.outcome(
               action, player, return_length=True
         )
         if possible and lengths.sum() > max_length:
               best_actions = []
               max_length = lengths.sum()
         if possible and lengths.sum() == max_length:
               best_actions.append(action)
      best_action = random.choice(best_actions)
      return best_action

   def strong_rule_based_opponent(self, player):
      """Takes move for strong rule-based opponent.

      :param player: Player who we are checking, 0 or 1
      :type player: int
      """
      self.update_top()

      winning_actions = []
      for action in range(self.num_cols):
         possible, reward, ended = self.outcome(action, player)
         if possible and ended:
               winning_actions.append(action)
      if len(winning_actions) > 0:
         winning_action = random.choice(winning_actions)
         return winning_action

      opp = 1 if player == 0 else 0
      loss_avoiding_actions = []
      for action in range(self.num_cols):
         possible, reward, ended = self.outcome(action, opp)
         if possible and ended:
               loss_avoiding_actions.append(action)
      if len(loss_avoiding_actions) > 0:
         loss_avoiding_action = random.choice(loss_avoiding_actions)
         return loss_avoiding_action

      return self.weak_rule_based_opponent(player)  # take best possible move

   def outcome(self, action, player, return_length=False):
      """Takes move for weak rule-based opponent.

      :param action: Action to take in environment
      :type action: int
      :param player: Player who we are checking, 0 or 1
      :type player: int
      :param return_length: Return length of outcomes, defaults to False
      :type player: bool, optional
      """
      if not (self.top[action] < self.num_rows):  # action column is full
         return (False, None, None) + ((None,) if return_length else ())

      row, col = self.top[action], action
      piece = player + 1

      # down, up, left, right, down-left, up-right, down-right, up-left,
      directions = np.array(
         [
               [[-1, 0], [1, 0]],
               [[0, -1], [0, 1]],
               [[-1, -1], [1, 1]],
               [[-1, 1], [1, -1]],
         ]
      )  # |4x2x2|

      positions = np.array([row, col]).reshape(1, 1, 1, -1) + np.expand_dims(
         directions, -2
      ) * np.arange(1, self.length).reshape(
         1, 1, -1, 1
      )  # |4x2x3x2|
      valid_positions = np.logical_and(
         np.logical_and(
               positions[:, :, :, 0] >= 0, positions[:, :, :, 0] < self.num_rows
         ),
         np.logical_and(
               positions[:, :, :, 1] >= 0, positions[:, :, :, 1] < self.num_cols
         ),
      )  # |4x2x3|
      d0 = np.where(valid_positions, positions[:, :, :, 0], 0)
      d1 = np.where(valid_positions, positions[:, :, :, 1], 0)
      board = np.array(self.env.env.board).reshape(self.num_rows, self.num_cols)
      board_values = np.where(valid_positions, board[d0, d1], 0)
      a = (board_values == piece).astype(int)
      b = np.concatenate(
         (a, np.zeros_like(a[:, :, :1])), axis=-1
      )  # padding with zeros to compute length
      lengths = np.argmin(b, -1)

      ended = False
      # check if winnable in any direction
      for both_dir in board_values:
         # |2x3|
         line = np.concatenate((both_dir[0][::-1], [piece], both_dir[1]))
         if "".join(map(str, [piece] * self.length)) in "".join(map(str, line)):
               ended = True
               break

      # ended = np.any(np.greater_equal(np.sum(lengths, 1), self.length - 1))
      draw = True
      for c, v in enumerate(self.top):
         draw &= (v == self.num_rows) if c != col else (v == (self.num_rows - 1))
      ended |= draw
      reward = (-1) ** (player) if ended and not draw else 0

      return (True, reward, ended) + ((lengths,) if return_length else ())

General setup#

Before we go any further in this tutorial, it would be helpful to define and set up everything remaining we need for training.

Setup code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("===== AgileRL Curriculum Learning Demo =====")

lesson_number = 1

# Load lesson for curriculum
with open(f"./curriculums/connect_four/lesson{lesson_number}.yaml") as file:
   LESSON = yaml.safe_load(file)

# Define the network configuration
NET_CONFIG = {
   "arch": "cnn",  # Network architecture
   "h_size": [64, 64],  # Actor hidden size
   "c_size": [128],  # CNN channel size
   "k_size": [4],  # CNN kernel size
   "s_size": [1],  # CNN stride size
   "normalize": False,  # Normalize image from range [0,255] to [0,1]
}

# Define the initial hyperparameters
INIT_HP = {
   "POPULATION_SIZE": 6,
   # "ALGO": "Rainbow DQN",  # Algorithm
   "ALGO": "DQN",  # Algorithm
   "DOUBLE": True,
   # Swap image channels dimension from last to first [H, W, C] -> [C, H, W]
   "BATCH_SIZE": 256,  # Batch size
   "LR": 1e-4,  # Learning rate
   "GAMMA": 0.99,  # Discount factor
   "MEMORY_SIZE": 100000,  # Max memory buffer size
   "LEARN_STEP": 1,  # Learning frequency
   "N_STEP": 1,  # Step number to calculate td error
   "PER": False,  # Use prioritized experience replay buffer
   "ALPHA": 0.6,  # Prioritized replay buffer parameter
   "TAU": 0.01,  # For soft update of target parameters
   "BETA": 0.4,  # Importance sampling coefficient
   "PRIOR_EPS": 0.000001,  # Minimum priority for sampling
   "NUM_ATOMS": 51,  # Unit number of support
   "V_MIN": 0.0,  # Minimum value of support
   "V_MAX": 200.0,  # Maximum value of support
   "WANDB": False,  # Use Weights and Biases tracking
}

# Define the connect four environment
env = connect_four_v3.env()
env.reset()

# Configure the algo input arguments
state_dim = [
   env.observation_space(agent)["observation"].shape for agent in env.agents
]
one_hot = False
action_dim = [env.action_space(agent).n for agent in env.agents]
INIT_HP["DISCRETE_ACTIONS"] = True
INIT_HP["MAX_ACTION"] = None
INIT_HP["MIN_ACTION"] = None

# Warp the environment in the curriculum learning wrapper
env = CurriculumEnv(env, LESSON)

# Pre-process dimensions for PyTorch layers
# We only need to worry about the state dim of a single agent
# We flatten the 6x7x2 observation as input to the agent"s neural network
state_dim = np.moveaxis(np.zeros(state_dim[0]), [-1], [-3]).shape
action_dim = action_dim[0]

# Create a population ready for evolutionary hyper-parameter optimisation
pop = initialPopulation(
   INIT_HP["ALGO"],
   state_dim,
   action_dim,
   one_hot,
   NET_CONFIG,
   INIT_HP,
   population_size=INIT_HP["POPULATION_SIZE"],
   device=device,
)

# Configure the replay buffer
field_names = ["state", "action", "reward", "next_state", "done"]
memory = ReplayBuffer(
   action_dim=action_dim,  # Number of agent actions
   memory_size=INIT_HP["MEMORY_SIZE"],  # Max replay buffer size
   field_names=field_names,  # Field names to store in memory
   device=device,
)

# Instantiate a tournament selection object (used for HPO)
tournament = TournamentSelection(
   tournament_size=2,  # Tournament selection size
   elitism=True,  # Elitism in tournament selection
   population_size=INIT_HP["POPULATION_SIZE"],  # Population size
   evo_step=1,
)  # Evaluate using last N fitness scores

# Instantiate a mutations object (used for HPO)
mutations = Mutations(
   algo=INIT_HP["ALGO"],
   no_mutation=0.2,  # Probability of no mutation
   architecture=0,  # Probability of architecture mutation
   new_layer_prob=0.2,  # Probability of new layer mutation
   parameters=0.2,  # Probability of parameter mutation
   activation=0,  # Probability of activation function mutation
   rl_hp=0.2,  # Probability of RL hyperparameter mutation
   rl_hp_selection=[
         "lr",
         "learn_step",
         "batch_size",
   ],  # RL hyperparams selected for mutation
   mutation_sd=0.1,  # Mutation strength
   # Define search space for each hyperparameter
   min_lr=0.0001,
   max_lr=0.01,
   min_learn_step=1,
   max_learn_step=120,
   min_batch_size=8,
   max_batch_size=64,
   arch=NET_CONFIG["arch"],  # MLP or CNN
   rand_seed=1,
   device=device,
)

# Define training loop parameters
episodes_per_epoch = 10
max_episodes = LESSON["max_train_episodes"]  # Total episodes
max_steps = 500  # Maximum steps to take in each episode
evo_epochs = 20  # Evolution frequency
evo_loop = 50  # Number of evaluation episodes
elite = pop[0]  # Assign a placeholder "elite" agent
epsilon = 1.0  # Starting epsilon value
eps_end = 0.1  # Final epsilon value
eps_decay = 0.9998  # Epsilon decays
opp_update_counter = 0
wb = INIT_HP["WANDB"]

As part of the curriculum, we may also choose to fill the replay buffer with random experiences, and also train on these offline.

Fill replay buffer
# Perform buffer and agent warmups if desired
if LESSON["buffer_warm_up"]:
   warm_up_opponent = Opponent(env, difficulty=LESSON["warm_up_opponent"])
   memory = env.fill_replay_buffer(
         memory, warm_up_opponent
   )  # Fill replay buffer with transitions
   if LESSON["agent_warm_up"] > 0:
         print("Warming up agents ...")
         agent = pop[0]
         # Train on randomly collected samples
         for epoch in trange(LESSON["agent_warm_up"]):
            experiences = memory.sample(agent.batch_size)
            agent.learn(experiences)
         pop = [agent.clone() for _ in pop]
         elite = agent
         print("Agent population warmed up.")

Self-play#

In this tutorial, we use self-play as the final lesson in our curriculum. By iteratively improving our agent and making it learn to win against itself, we can allow it to discover new strategies and achieve higher performance. The weights of our pretrained agent from an earlier lesson can be loaded to the population as follows:

Load pretrained weights
if LESSON["pretrained_path"] is not None:
   for agent in pop:
         # Load pretrained checkpoint
         agent.loadCheckpoint(LESSON["pretrained_path"])
         # Reinit optimizer for new task
         agent.lr = INIT_HP["LR"]
         agent.optimizer = torch.optim.Adam(
            agent.actor.parameters(), lr=agent.lr
         )

To train against an old version of our agent, we create a pool of opponents. At training time, we randomly select an opponent from this pool. At regular intervals, we update the opponent pool by removing the oldest opponent and adding a copy of the latest version of our agent. This provides a balance between training against an increasingly difficult opponent and providing variety in the moves an opponent might make.

Create opponent pool
if LESSON["opponent"] == "self":
   # Create initial pool of opponents
   opponent_pool = deque(maxlen=LESSON["opponent_pool_size"])
   for _ in range(LESSON["opponent_pool_size"]):
         opp = copy.deepcopy(pop[0])
         opp.actor.load_state_dict(pop[0].actor.state_dict())
         opp.actor.eval()
         opponent_pool.append(opp)

A sample lesson config for self-play training could be defined as follows:

Lesson 4
---
# Connect Four Lesson 4
# Train against random agent: 'random', weak opponent: 'weak', strong opponent: 'strong', or use self-play: 'self'
opponent: self
opponent_pool_size: 6  # Size of opponent pool for self-play
opponent_upgrade: 6000  # Epoch frequency to update opponent pool
eval_opponent: strong  # 'random', 'weak' or 'strong'
pretrained_path: models/DQN/lesson3_trained_agent.pt  # Path to pretrained model weights
save_path: models/DQN/lesson4_trained_agent.pt  # Path to save trained model
max_train_episodes: 600000  # Maximum number of training episodes in environment

## Game specific:
buffer_warm_up: false  # Fill replay buffer with random experiences
warm_up_opponent:       # Difficulty level of warm up experiences
agent_warm_up: 0  # Number of epochs to warm up agent by training on random experiences
block_vert_coef: 1  # How many times more likely to block vertically if playing random opponent
rewards:  # Rewards for different outcomes
    win: 1
    vertical_win: 1
    three_in_row: 0.01
    opp_three_in_row: -0.01
    lose: -1
    play_continues: 0

It could also be possible to train an agent through self-play only, without using any previous lessons in the curriculum. This would require significant training time, but could ultimately result in better performance than other methods, and could avoid some of the mistakes discussed in The Bitter Lesson.

Training loop#

The Connect Four training loop must take into account that the agent only takes an action every other interaction with the environment (the opponent takes alternating turns). This must be considered when saving transitions to the replay buffer. Equally, we must wait for the outcome of the next player’s turn before determining what the reward should be for a transition. This is not a true Markov Decision Process for this reason, but we can still train a reinforcement learning agent reasonably successfully in these non-stationary conditions.

At regular intervals, we evaluate the performance, or ‘fitness’, of the agents in our population, and do an evolutionary step. Those which perform best are more likely to become members of the next generation, and the hyperparameters and neural architectures of agents in the population are mutated. This evolution allows us to optimize hyperparameters and maximise the performance of our agents in a single training run.

Training loop
if max_episodes > 0:
   if wb:
      wandb.init(
            # set the wandb project where this run will be logged
            project="AgileRL",
            name="{}-EvoHPO-{}-{}Opposition-CNN-{}".format(
               "connect_four_v3",
               INIT_HP["ALGO"],
               LESSON["opponent"],
               datetime.now().strftime("%m%d%Y%H%M%S"),
            ),
            # track hyperparameters and run metadata
            config={
               "algo": "Evo HPO Rainbow DQN",
               "env": "connect_four_v3",
               "INIT_HP": INIT_HP,
               "lesson": LESSON,
            },
      )

total_steps = 0
total_episodes = 0
pbar = trange(int(max_episodes / episodes_per_epoch))

# Training loop
for idx_epi in pbar:
   turns_per_episode = []
   train_actions_hist = [0] * action_dim
   for agent in pop:  # Loop through population
         for episode in range(episodes_per_epoch):
            env.reset()  # Reset environment at start of episode
            observation, env_reward, done, truncation, _ = env.last()

            (
               p1_state,
               p1_state_flipped,
               p1_action,
               p1_next_state,
               p1_next_state_flipped,
            ) = (None, None, None, None, None)

            if LESSON["opponent"] == "self":
               # Randomly choose opponent from opponent pool if using self-play
               opponent = random.choice(opponent_pool)
            else:
               # Create opponent of desired difficulty
               opponent = Opponent(env, difficulty=LESSON["opponent"])

            # Randomly decide whether agent will go first or second
            if random.random() > 0.5:
               opponent_first = False
            else:
               opponent_first = True

            score = 0
            turns = 0  # Number of turns counter

            for idx_step in range(max_steps):
               # Player 0"s turn
               p0_action_mask = observation["action_mask"]
               p0_state = np.moveaxis(observation["observation"], [-1], [-3])
               p0_state_flipped = np.expand_dims(np.flip(p0_state, 2), 0)
               p0_state = np.expand_dims(p0_state, 0)

               if opponent_first:
                     if LESSON["opponent"] == "self":
                        p0_action = opponent.getAction(
                           p0_state, 0, p0_action_mask
                        )[0]
                     elif LESSON["opponent"] == "random":
                        p0_action = opponent.getAction(
                           p0_action_mask, p1_action, LESSON["block_vert_coef"]
                        )
                     else:
                        p0_action = opponent.getAction(player=0)
               else:
                     p0_action = agent.getAction(
                        p0_state, epsilon, p0_action_mask
                     )[
                        0
                     ]  # Get next action from agent
                     train_actions_hist[p0_action] += 1

               env.step(p0_action)  # Act in environment
               observation, env_reward, done, truncation, _ = env.last()
               p0_next_state = np.moveaxis(
                     observation["observation"], [-1], [-3]
               )
               p0_next_state_flipped = np.expand_dims(
                     np.flip(p0_next_state, 2), 0
               )
               p0_next_state = np.expand_dims(p0_next_state, 0)

               if not opponent_first:
                     score += env_reward
               turns += 1

               # Check if game is over (Player 0 win)
               if done or truncation:
                     reward = env.reward(done=True, player=0)
                     memory.save2memoryVectEnvs(
                        np.concatenate(
                           (
                                 p0_state,
                                 p1_state,
                                 p0_state_flipped,
                                 p1_state_flipped,
                           )
                        ),
                        [p0_action, p1_action, 6 - p0_action, 6 - p1_action],
                        [
                           reward,
                           LESSON["rewards"]["lose"],
                           reward,
                           LESSON["rewards"]["lose"],
                        ],
                        np.concatenate(
                           (
                                 p0_next_state,
                                 p1_next_state,
                                 p0_next_state_flipped,
                                 p1_next_state_flipped,
                           )
                        ),
                        [done, done, done, done],
                     )
               else:  # Play continues
                     if p1_state is not None:
                        reward = env.reward(done=False, player=1)
                        memory.save2memoryVectEnvs(
                           np.concatenate((p1_state, p1_state_flipped)),
                           [p1_action, 6 - p1_action],
                           [reward, reward],
                           np.concatenate(
                                 (p1_next_state, p1_next_state_flipped)
                           ),
                           [done, done],
                        )

                     # Player 1"s turn
                     p1_action_mask = observation["action_mask"]
                     p1_state = np.moveaxis(
                        observation["observation"], [-1], [-3]
                     )
                     # Swap pieces so that the agent always sees the board from the same perspective
                     p1_state[[0, 1], :, :] = p1_state[[0, 1], :, :]
                     p1_state_flipped = np.expand_dims(np.flip(p1_state, 2), 0)
                     p1_state = np.expand_dims(p1_state, 0)

                     if not opponent_first:
                        if LESSON["opponent"] == "self":
                           p1_action = opponent.getAction(
                                 p1_state, 0, p1_action_mask
                           )[0]
                        elif LESSON["opponent"] == "random":
                           p1_action = opponent.getAction(
                                 p1_action_mask,
                                 p0_action,
                                 LESSON["block_vert_coef"],
                           )
                        else:
                           p1_action = opponent.getAction(player=1)
                     else:
                        p1_action = agent.getAction(
                           p1_state, epsilon, p1_action_mask
                        )[
                           0
                        ]  # Get next action from agent
                        train_actions_hist[p1_action] += 1

                     env.step(p1_action)  # Act in environment
                     observation, env_reward, done, truncation, _ = env.last()
                     p1_next_state = np.moveaxis(
                        observation["observation"], [-1], [-3]
                     )
                     p1_next_state[[0, 1], :, :] = p1_next_state[[0, 1], :, :]
                     p1_next_state_flipped = np.expand_dims(
                        np.flip(p1_next_state, 2), 0
                     )
                     p1_next_state = np.expand_dims(p1_next_state, 0)

                     if opponent_first:
                        score += env_reward
                     turns += 1

                     # Check if game is over (Player 1 win)
                     if done or truncation:
                        reward = env.reward(done=True, player=1)
                        memory.save2memoryVectEnvs(
                           np.concatenate(
                                 (
                                    p0_state,
                                    p1_state,
                                    p0_state_flipped,
                                    p1_state_flipped,
                                 )
                           ),
                           [
                                 p0_action,
                                 p1_action,
                                 6 - p0_action,
                                 6 - p1_action,
                           ],
                           [
                                 LESSON["rewards"]["lose"],
                                 reward,
                                 LESSON["rewards"]["lose"],
                                 reward,
                           ],
                           np.concatenate(
                                 (
                                    p0_next_state,
                                    p1_next_state,
                                    p0_next_state_flipped,
                                    p1_next_state_flipped,
                                 )
                           ),
                           [done, done, done, done],
                        )

                     else:  # Play continues
                        reward = env.reward(done=False, player=0)
                        memory.save2memoryVectEnvs(
                           np.concatenate((p0_state, p0_state_flipped)),
                           [p0_action, 6 - p0_action],
                           [reward, reward],
                           np.concatenate(
                                 (p0_next_state, p0_next_state_flipped)
                           ),
                           [done, done],
                        )

               # Learn according to learning frequency
               if (memory.counter % agent.learn_step == 0) and (
                     len(memory) >= agent.batch_size
               ):
                     # Sample replay buffer
                     # Learn according to agent"s RL algorithm
                     experiences = memory.sample(agent.batch_size)
                     agent.learn(experiences)

               # Stop episode if any agents have terminated
               if done or truncation:
                     break

            total_steps += idx_step + 1
            total_episodes += 1
            turns_per_episode.append(turns)
            # Save the total episode reward
            agent.scores.append(score)

            if LESSON["opponent"] == "self":
               if (total_episodes % LESSON["opponent_upgrade"] == 0) and (
                     (idx_epi + 1) > evo_epochs
               ):
                     elite_opp, _, _ = tournament._elitism(pop)
                     elite_opp.actor.eval()
                     opponent_pool.append(elite_opp)
                     opp_update_counter += 1

         # Update epsilon for exploration
         epsilon = max(eps_end, epsilon * eps_decay)

   mean_turns = np.mean(turns_per_episode)

   # Now evolve population if necessary
   if (idx_epi + 1) % evo_epochs == 0:
         # Evaluate population vs random actions
         fitnesses = []
         win_rates = []
         eval_actions_hist = [0] * action_dim  # Eval actions histogram
         eval_turns = 0  # Eval turns counter
         for agent in pop:
            with torch.no_grad():
               rewards = []
               for i in range(evo_loop):
                     env.reset()  # Reset environment at start of episode
                     observation, reward, done, truncation, _ = env.last()

                     player = -1  # Tracker for which player"s turn it is

                     # Create opponent of desired difficulty
                     opponent = Opponent(env, difficulty=LESSON["eval_opponent"])

                     # Randomly decide whether agent will go first or second
                     if random.random() > 0.5:
                        opponent_first = False
                     else:
                        opponent_first = True

                     score = 0

                     for idx_step in range(max_steps):
                        action_mask = observation["action_mask"]
                        if player < 0:
                           if opponent_first:
                                 if LESSON["eval_opponent"] == "random":
                                    action = opponent.getAction(action_mask)
                                 else:
                                    action = opponent.getAction(player=0)
                           else:
                                 state = np.moveaxis(
                                    observation["observation"], [-1], [-3]
                                 )
                                 state = np.expand_dims(state, 0)
                                 action = agent.getAction(state, 0, action_mask)[
                                    0
                                 ]  # Get next action from agent
                                 eval_actions_hist[action] += 1
                        if player > 0:
                           if not opponent_first:
                                 if LESSON["eval_opponent"] == "random":
                                    action = opponent.getAction(action_mask)
                                 else:
                                    action = opponent.getAction(player=1)
                           else:
                                 state = np.moveaxis(
                                    observation["observation"], [-1], [-3]
                                 )
                                 state[[0, 1], :, :] = state[[0, 1], :, :]
                                 state = np.expand_dims(state, 0)
                                 action = agent.getAction(state, 0, action_mask)[
                                    0
                                 ]  # Get next action from agent
                                 eval_actions_hist[action] += 1

                        env.step(action)  # Act in environment
                        observation, reward, done, truncation, _ = env.last()

                        if (player > 0 and opponent_first) or (
                           player < 0 and not opponent_first
                        ):
                           score += reward

                        eval_turns += 1

                        if done or truncation:
                           break

                        player *= -1

                     rewards.append(score)
            mean_fit = np.mean(rewards)
            agent.fitness.append(mean_fit)
            fitnesses.append(mean_fit)

         eval_turns = eval_turns / len(pop) / evo_loop

         pbar.set_postfix_str(
            f"    Train Mean Score: {np.mean(agent.scores[-episodes_per_epoch:])}   Train Mean Turns: {mean_turns}   Eval Mean Fitness: {np.mean(fitnesses)}   Eval Best Fitness: {np.max(fitnesses)}   Eval Mean Turns: {eval_turns}   Total Steps: {total_steps}"
         )
         pbar.update(0)

         # Format action histograms for visualisation
         train_actions_hist = [
            freq / sum(train_actions_hist) for freq in train_actions_hist
         ]
         eval_actions_hist = [
            freq / sum(eval_actions_hist) for freq in eval_actions_hist
         ]
         train_actions_dict = {
            f"train/action_{index}": action
            for index, action in enumerate(train_actions_hist)
         }
         eval_actions_dict = {
            f"eval/action_{index}": action
            for index, action in enumerate(eval_actions_hist)
         }

         if wb:
            wandb_dict = {
               "global_step": total_steps,
               "train/mean_score": np.mean(agent.scores[-episodes_per_epoch:]),
               "train/mean_turns_per_game": mean_turns,
               "train/epsilon": epsilon,
               "train/opponent_updates": opp_update_counter,
               "eval/mean_fitness": np.mean(fitnesses),
               "eval/best_fitness": np.max(fitnesses),
               "eval/mean_turns_per_game": eval_turns,
            }
            wandb_dict.update(train_actions_dict)
            wandb_dict.update(eval_actions_dict)
            wandb.log(wandb_dict)

         # Tournament selection and population mutation
         elite, pop = tournament.select(pop)
         pop = mutations.mutation(pop)

if max_episodes > 0:
   if wb:
      wandb.finish()

# Save the trained agent
save_path = LESSON["save_path"]
os.makedirs(os.path.dirname(save_path), exist_ok=True)
elite.saveCheckpoint(save_path)
print(f"Elite agent saved to '{save_path}'.")

Trained model weights#

Trained model weights are provided at PettingZoo/tutorials/AgileRL/models. Take a look, train against these models, and see if you can beat them!

Watch the trained agents play#

The following code allows you to load your saved DQN agent from the previous training block, test the agent’s performance, and then visualise a number of episodes as a gif.

Render trained agents
import os

import imageio
import numpy as np
import torch
from agilerl.algorithms.dqn import DQN
from agilerl_dqn_curriculum import Opponent
from PIL import Image, ImageDraw, ImageFont

from pettingzoo.classic import connect_four_v3


# Define function to return image
def _label_with_episode_number(frame, episode_num, frame_no, p):
    im = Image.fromarray(frame)
    drawer = ImageDraw.Draw(im)
    text_color = (255, 255, 255)
    font = ImageFont.truetype("arial.ttf", size=45)
    drawer.text(
        (100, 5),
        f"Episode: {episode_num+1}     Frame: {frame_no}",
        fill=text_color,
        font=font,
    )
    if p == 1:
        player = "Player 1"
        color = (255, 0, 0)
    if p == 2:
        player = "Player 2"
        color = (100, 255, 150)
    if p is None:
        player = "Self-play"
        color = (255, 255, 255)
    drawer.text((700, 5), f"Agent: {player}", fill=color, font=font)
    return im


# Resizes frames to make file size smaller
def resize_frames(frames, fraction):
    resized_frames = []
    for img in frames:
        new_width = int(img.width * fraction)
        new_height = int(img.height * fraction)
        img_resized = img.resize((new_width, new_height))
        resized_frames.append(np.array(img_resized))

    return resized_frames


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    path = "./models/DQN/lesson3_trained_agent.pt"  # Path to saved agent checkpoint

    env = connect_four_v3.env(render_mode="rgb_array")
    env.reset()

    # Configure the algo input arguments
    state_dim = [
        env.observation_space(agent)["observation"].shape for agent in env.agents
    ]
    one_hot = False
    action_dim = [env.action_space(agent).n for agent in env.agents]

    # Pre-process dimensions for pytorch layers
    # We will use self-play, so we only need to worry about the state dim of a single agent
    # We flatten the 6x7x2 observation as input to the agent's neural network
    state_dim = np.zeros(state_dim[0]).flatten().shape
    action_dim = action_dim[0]

    # Instantiate an DQN object
    dqn = DQN(
        state_dim,
        action_dim,
        one_hot,
        device=device,
    )

    # Load the saved algorithm into the DQN object
    dqn.loadCheckpoint(path)

    for opponent_difficulty in ["random", "weak", "strong", "self"]:
        # Create opponent
        if opponent_difficulty == "self":
            opponent = dqn
        else:
            opponent = Opponent(env, opponent_difficulty)

        # Define test loop parameters
        episodes = 2  # Number of episodes to test agent on
        max_steps = (
            500  # Max number of steps to take in the environment in each episode
        )

        rewards = []  # List to collect total episodic reward
        frames = []  # List to collect frames

        print("============================================")
        print(f"Agent: {path}")
        print(f"Opponent: {opponent_difficulty}")

        # Test loop for inference
        for ep in range(episodes):
            if ep / episodes < 0.5:
                opponent_first = False
                p = 1
            else:
                opponent_first = True
                p = 2
            if opponent_difficulty == "self":
                p = None
            env.reset()  # Reset environment at start of episode
            frame = env.render()
            frames.append(
                _label_with_episode_number(frame, episode_num=ep, frame_no=0, p=p)
            )
            observation, reward, done, truncation, _ = env.last()
            player = -1  # Tracker for which player's turn it is
            score = 0
            for idx_step in range(max_steps):
                action_mask = observation["action_mask"]
                if player < 0:
                    state = np.moveaxis(observation["observation"], [-1], [-3])
                    state = np.expand_dims(state, 0)
                    if opponent_first:
                        if opponent_difficulty == "self":
                            action = opponent.getAction(
                                state, epsilon=0, action_mask=action_mask
                            )[0]
                        elif opponent_difficulty == "random":
                            action = opponent.getAction(action_mask)
                        else:
                            action = opponent.getAction(player=0)
                    else:
                        action = dqn.getAction(
                            state, epsilon=0, action_mask=action_mask
                        )[
                            0
                        ]  # Get next action from agent
                if player > 0:
                    state = np.moveaxis(observation["observation"], [-1], [-3])
                    state[[0, 1], :, :] = state[[0, 1], :, :]
                    state = np.expand_dims(state, 0)
                    if not opponent_first:
                        if opponent_difficulty == "self":
                            action = opponent.getAction(
                                state, epsilon=0, action_mask=action_mask
                            )[0]
                        elif opponent_difficulty == "random":
                            action = opponent.getAction(action_mask)
                        else:
                            action = opponent.getAction(player=1)
                    else:
                        action = dqn.getAction(
                            state, epsilon=0, action_mask=action_mask
                        )[
                            0
                        ]  # Get next action from agent
                env.step(action)  # Act in environment
                observation, reward, termination, truncation, _ = env.last()
                # Save the frame for this step and append to frames list
                frame = env.render()
                frames.append(
                    _label_with_episode_number(
                        frame, episode_num=ep, frame_no=idx_step, p=p
                    )
                )

                if (player > 0 and opponent_first) or (
                    player < 0 and not opponent_first
                ):
                    score += reward
                else:
                    score -= reward

                # Stop episode if any agents have terminated
                if truncation or termination:
                    break

                player *= -1

            print("-" * 15, f"Episode: {ep+1}", "-" * 15)
            print(f"Episode length: {idx_step}")
            print(f"Score: {score}")

        print("============================================")

        frames = resize_frames(frames, 0.5)

        # Save the gif to specified path
        gif_path = "./videos/"
        os.makedirs(gif_path, exist_ok=True)
        imageio.mimwrite(
            os.path.join("./videos/", f"connect_four_{opponent_difficulty}_opp.gif"),
            frames,
            duration=400,
            loop=True,
        )

    env.close()

Full training code#

Full training code

Please note that on line 612 max_episodes is set to 10 to allow fast testing of this tutorial code. This line can be deleted, and the line below it uncommented, to use the number of episodes set in the config files.

"""This tutorial shows how to train a DQN agent on the connect four environment, using curriculum learning and self play.

Author: Nick (https://github.com/nicku-a)
"""
import copy
import os
import random
from collections import deque
from datetime import datetime

import numpy as np
import torch
import wandb
import yaml
from agilerl.components.replay_buffer import ReplayBuffer
from agilerl.hpo.mutation import Mutations
from agilerl.hpo.tournament import TournamentSelection
from agilerl.utils.utils import initialPopulation
from tqdm import tqdm, trange

from pettingzoo.classic import connect_four_v3


class CurriculumEnv:
    """Wrapper around environment to modify reward for curriculum learning.

    :param env: Environment to learn in
    :type env: PettingZoo-style environment
    :param lesson: Lesson settings for curriculum learning
    :type lesson: dict
    """

    def __init__(self, env, lesson):
        self.env = env
        self.lesson = lesson

    def fill_replay_buffer(self, memory, opponent):
        """Fill the replay buffer with experiences collected by taking random actions in the environment.

        :param memory: Experience replay buffer
        :type memory: AgileRL experience replay buffer
        """
        print("Filling replay buffer ...")

        pbar = tqdm(total=memory.memory_size)
        while len(memory) < memory.memory_size:
            # Randomly decide whether random player will go first or second
            if random.random() > 0.5:
                opponent_first = False
            else:
                opponent_first = True

            mem_full = len(memory)
            self.reset()  # Reset environment at start of episode
            observation, reward, done, truncation, _ = self.last()

            (
                p1_state,
                p1_state_flipped,
                p1_action,
                p1_next_state,
                p1_next_state_flipped,
            ) = (None, None, None, None, None)
            done, truncation = False, False

            while not (done or truncation):
                # Player 0's turn
                p0_action_mask = observation["action_mask"]
                p0_state = np.moveaxis(observation["observation"], [-1], [-3])
                p0_state_flipped = np.expand_dims(np.flip(p0_state, 2), 0)
                p0_state = np.expand_dims(p0_state, 0)
                if opponent_first:
                    p0_action = self.env.action_space("player_0").sample(p0_action_mask)
                else:
                    if self.lesson["warm_up_opponent"] == "random":
                        p0_action = opponent.getAction(
                            p0_action_mask, p1_action, self.lesson["block_vert_coef"]
                        )
                    else:
                        p0_action = opponent.getAction(player=0)
                self.step(p0_action)  # Act in environment
                observation, env_reward, done, truncation, _ = self.last()
                p0_next_state = np.moveaxis(observation["observation"], [-1], [-3])
                p0_next_state_flipped = np.expand_dims(np.flip(p0_next_state, 2), 0)
                p0_next_state = np.expand_dims(p0_next_state, 0)

                if done or truncation:
                    reward = self.reward(done=True, player=0)
                    memory.save2memoryVectEnvs(
                        np.concatenate(
                            (p0_state, p1_state, p0_state_flipped, p1_state_flipped)
                        ),
                        [p0_action, p1_action, 6 - p0_action, 6 - p1_action],
                        [
                            reward,
                            LESSON["rewards"]["lose"],
                            reward,
                            LESSON["rewards"]["lose"],
                        ],
                        np.concatenate(
                            (
                                p0_next_state,
                                p1_next_state,
                                p0_next_state_flipped,
                                p1_next_state_flipped,
                            )
                        ),
                        [done, done, done, done],
                    )
                else:  # Play continues
                    if p1_state is not None:
                        reward = self.reward(done=False, player=1)
                        memory.save2memoryVectEnvs(
                            np.concatenate((p1_state, p1_state_flipped)),
                            [p1_action, 6 - p1_action],
                            [reward, reward],
                            np.concatenate((p1_next_state, p1_next_state_flipped)),
                            [done, done],
                        )

                    # Player 1's turn
                    p1_action_mask = observation["action_mask"]
                    p1_state = np.moveaxis(observation["observation"], [-1], [-3])
                    p1_state[[0, 1], :, :] = p1_state[[0, 1], :, :]
                    p1_state_flipped = np.expand_dims(np.flip(p1_state, 2), 0)
                    p1_state = np.expand_dims(p1_state, 0)
                    if not opponent_first:
                        p1_action = self.env.action_space("player_1").sample(
                            p1_action_mask
                        )
                    else:
                        if self.lesson["warm_up_opponent"] == "random":
                            p1_action = opponent.getAction(
                                p1_action_mask, p0_action, LESSON["block_vert_coef"]
                            )
                        else:
                            p1_action = opponent.getAction(player=1)
                    self.step(p1_action)  # Act in environment
                    observation, env_reward, done, truncation, _ = self.last()
                    p1_next_state = np.moveaxis(observation["observation"], [-1], [-3])
                    p1_next_state[[0, 1], :, :] = p1_next_state[[0, 1], :, :]
                    p1_next_state_flipped = np.expand_dims(np.flip(p1_next_state, 2), 0)
                    p1_next_state = np.expand_dims(p1_next_state, 0)

                    if done or truncation:
                        reward = self.reward(done=True, player=1)
                        memory.save2memoryVectEnvs(
                            np.concatenate(
                                (p0_state, p1_state, p0_state_flipped, p1_state_flipped)
                            ),
                            [p0_action, p1_action, 6 - p0_action, 6 - p1_action],
                            [
                                LESSON["rewards"]["lose"],
                                reward,
                                LESSON["rewards"]["lose"],
                                reward,
                            ],
                            np.concatenate(
                                (
                                    p0_next_state,
                                    p1_next_state,
                                    p0_next_state_flipped,
                                    p1_next_state_flipped,
                                )
                            ),
                            [done, done, done, done],
                        )

                    else:  # Play continues
                        reward = self.reward(done=False, player=0)
                        memory.save2memoryVectEnvs(
                            np.concatenate((p0_state, p0_state_flipped)),
                            [p0_action, 6 - p0_action],
                            [reward, reward],
                            np.concatenate((p0_next_state, p0_next_state_flipped)),
                            [done, done],
                        )

            pbar.update(len(memory) - mem_full)
        pbar.close()
        print("Replay buffer warmed up.")
        return memory

    def check_winnable(self, lst, piece):
        """Checks if four pieces in a row represent a winnable opportunity, e.g. [1, 1, 1, 0] or [2, 0, 2, 2].

        :param lst: List of pieces in row
        :type lst: List
        :param piece: Player piece we are checking (1 or 2)
        :type piece: int
        """
        return lst.count(piece) == 3 and lst.count(0) == 1

    def check_vertical_win(self, player):
        """Checks if a win is vertical.

        :param player: Player who we are checking, 0 or 1
        :type player: int
        """
        board = np.array(self.env.env.board).reshape(6, 7)
        piece = player + 1

        column_count = 7
        row_count = 6

        # Check vertical locations for win
        for c in range(column_count):
            for r in range(row_count - 3):
                if (
                    board[r][c] == piece
                    and board[r + 1][c] == piece
                    and board[r + 2][c] == piece
                    and board[r + 3][c] == piece
                ):
                    return True
        return False

    def check_three_in_row(self, player):
        """Checks if there are three pieces in a row and a blank space next, or two pieces - blank - piece.

        :param player: Player who we are checking, 0 or 1
        :type player: int
        """
        board = np.array(self.env.env.board).reshape(6, 7)
        piece = player + 1

        # Check horizontal locations
        column_count = 7
        row_count = 6
        three_in_row_count = 0

        # Check vertical locations
        for c in range(column_count):
            for r in range(row_count - 3):
                if self.check_winnable(board[r : r + 4, c].tolist(), piece):
                    three_in_row_count += 1

        # Check horizontal locations
        for r in range(row_count):
            for c in range(column_count - 3):
                if self.check_winnable(board[r, c : c + 4].tolist(), piece):
                    three_in_row_count += 1

        # Check positively sloped diagonals
        for c in range(column_count - 3):
            for r in range(row_count - 3):
                if self.check_winnable(
                    [
                        board[r, c],
                        board[r + 1, c + 1],
                        board[r + 2, c + 2],
                        board[r + 3, c + 3],
                    ],
                    piece,
                ):
                    three_in_row_count += 1

        # Check negatively sloped diagonals
        for c in range(column_count - 3):
            for r in range(3, row_count):
                if self.check_winnable(
                    [
                        board[r, c],
                        board[r - 1, c + 1],
                        board[r - 2, c + 2],
                        board[r - 3, c + 3],
                    ],
                    piece,
                ):
                    three_in_row_count += 1

        return three_in_row_count

    def reward(self, done, player):
        """Processes and returns reward from environment according to lesson criteria.

        :param done: Environment has terminated
        :type done: bool
        :param player: Player who we are checking, 0 or 1
        :type player: int
        """
        if done:
            reward = (
                self.lesson["rewards"]["vertical_win"]
                if self.check_vertical_win(player)
                else self.lesson["rewards"]["win"]
            )
        else:
            agent_three_count = self.check_three_in_row(1 - player)
            opp_three_count = self.check_three_in_row(player)
            if (agent_three_count + opp_three_count) == 0:
                reward = self.lesson["rewards"]["play_continues"]
            else:
                reward = (
                    self.lesson["rewards"]["three_in_row"] * agent_three_count
                    + self.lesson["rewards"]["opp_three_in_row"] * opp_three_count
                )
        return reward

    def last(self):
        """Wrapper around PettingZoo env last method."""
        return self.env.last()

    def step(self, action):
        """Wrapper around PettingZoo env step method."""
        self.env.step(action)

    def reset(self):
        """Wrapper around PettingZoo env reset method."""
        self.env.reset()


class Opponent:
    """Connect 4 opponent to train and/or evaluate against.

    :param env: Environment to learn in
    :type env: PettingZoo-style environment
    :param difficulty: Difficulty level of opponent, 'random', 'weak' or 'strong'
    :type difficulty: str
    """

    def __init__(self, env, difficulty):
        self.env = env.env
        self.difficulty = difficulty
        if self.difficulty == "random":
            self.getAction = self.random_opponent
        elif self.difficulty == "weak":
            self.getAction = self.weak_rule_based_opponent
        else:
            self.getAction = self.strong_rule_based_opponent
        self.num_cols = 7
        self.num_rows = 6
        self.length = 4
        self.top = [0] * self.num_cols

    def update_top(self):
        """Updates self.top, a list which tracks the row on top of the highest piece in each column."""
        board = np.array(self.env.env.board).reshape(self.num_rows, self.num_cols)
        non_zeros = np.where(board != 0)
        rows, cols = non_zeros
        top = np.zeros(board.shape[1], dtype=int)
        for col in range(board.shape[1]):
            column_pieces = rows[cols == col]
            if len(column_pieces) > 0:
                top[col] = np.min(column_pieces) - 1
            else:
                top[col] = 5
        full_columns = np.all(board != 0, axis=0)
        top[full_columns] = 6
        self.top = top

    def random_opponent(self, action_mask, last_opp_move=None, block_vert_coef=1):
        """Takes move for random opponent. If the lesson aims to randomly block vertical wins with a higher probability, this is done here too.

        :param action_mask: Mask of legal actions: 1=legal, 0=illegal
        :type action_mask: List
        :param last_opp_move: Most recent action taken by agent against this opponent
        :type last_opp_move: int
        :param block_vert_coef: How many times more likely to block vertically
        :type block_vert_coef: float
        """
        if last_opp_move is not None:
            action_mask[last_opp_move] *= block_vert_coef
        action = random.choices(list(range(self.num_cols)), action_mask)[0]
        return action

    def weak_rule_based_opponent(self, player):
        """Takes move for weak rule-based opponent.

        :param player: Player who we are checking, 0 or 1
        :type player: int
        """
        self.update_top()
        max_length = -1
        best_actions = []
        for action in range(self.num_cols):
            possible, reward, ended, lengths = self.outcome(
                action, player, return_length=True
            )
            if possible and lengths.sum() > max_length:
                best_actions = []
                max_length = lengths.sum()
            if possible and lengths.sum() == max_length:
                best_actions.append(action)
        best_action = random.choice(best_actions)
        return best_action

    def strong_rule_based_opponent(self, player):
        """Takes move for strong rule-based opponent.

        :param player: Player who we are checking, 0 or 1
        :type player: int
        """
        self.update_top()

        winning_actions = []
        for action in range(self.num_cols):
            possible, reward, ended = self.outcome(action, player)
            if possible and ended:
                winning_actions.append(action)
        if len(winning_actions) > 0:
            winning_action = random.choice(winning_actions)
            return winning_action

        opp = 1 if player == 0 else 0
        loss_avoiding_actions = []
        for action in range(self.num_cols):
            possible, reward, ended = self.outcome(action, opp)
            if possible and ended:
                loss_avoiding_actions.append(action)
        if len(loss_avoiding_actions) > 0:
            loss_avoiding_action = random.choice(loss_avoiding_actions)
            return loss_avoiding_action

        return self.weak_rule_based_opponent(player)  # take best possible move

    def outcome(self, action, player, return_length=False):
        """Takes move for weak rule-based opponent.

        :param action: Action to take in environment
        :type action: int
        :param player: Player who we are checking, 0 or 1
        :type player: int
        :param return_length: Return length of outcomes, defaults to False
        :type player: bool, optional
        """
        if not (self.top[action] < self.num_rows):  # action column is full
            return (False, None, None) + ((None,) if return_length else ())

        row, col = self.top[action], action
        piece = player + 1

        # down, up, left, right, down-left, up-right, down-right, up-left,
        directions = np.array(
            [
                [[-1, 0], [1, 0]],
                [[0, -1], [0, 1]],
                [[-1, -1], [1, 1]],
                [[-1, 1], [1, -1]],
            ]
        )  # |4x2x2|

        positions = np.array([row, col]).reshape(1, 1, 1, -1) + np.expand_dims(
            directions, -2
        ) * np.arange(1, self.length).reshape(
            1, 1, -1, 1
        )  # |4x2x3x2|
        valid_positions = np.logical_and(
            np.logical_and(
                positions[:, :, :, 0] >= 0, positions[:, :, :, 0] < self.num_rows
            ),
            np.logical_and(
                positions[:, :, :, 1] >= 0, positions[:, :, :, 1] < self.num_cols
            ),
        )  # |4x2x3|
        d0 = np.where(valid_positions, positions[:, :, :, 0], 0)
        d1 = np.where(valid_positions, positions[:, :, :, 1], 0)
        board = np.array(self.env.env.board).reshape(self.num_rows, self.num_cols)
        board_values = np.where(valid_positions, board[d0, d1], 0)
        a = (board_values == piece).astype(int)
        b = np.concatenate(
            (a, np.zeros_like(a[:, :, :1])), axis=-1
        )  # padding with zeros to compute length
        lengths = np.argmin(b, -1)

        ended = False
        # check if winnable in any direction
        for both_dir in board_values:
            # |2x3|
            line = np.concatenate((both_dir[0][::-1], [piece], both_dir[1]))
            if "".join(map(str, [piece] * self.length)) in "".join(map(str, line)):
                ended = True
                break

        # ended = np.any(np.greater_equal(np.sum(lengths, 1), self.length - 1))
        draw = True
        for c, v in enumerate(self.top):
            draw &= (v == self.num_rows) if c != col else (v == (self.num_rows - 1))
        ended |= draw
        reward = (-1) ** (player) if ended and not draw else 0

        return (True, reward, ended) + ((lengths,) if return_length else ())


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("===== AgileRL Curriculum Learning Demo =====")

    for lesson_number in range(1, 5):
        # Load lesson for curriculum
        with open(f"./curriculums/connect_four/lesson{lesson_number}.yaml") as file:
            LESSON = yaml.safe_load(file)

        # Define the network configuration
        NET_CONFIG = {
            "arch": "cnn",  # Network architecture
            "h_size": [64, 64],  # Actor hidden size
            "c_size": [128],  # CNN channel size
            "k_size": [4],  # CNN kernel size
            "s_size": [1],  # CNN stride size
            "normalize": False,  # Normalize image from range [0,255] to [0,1]
        }

        # Define the initial hyperparameters
        INIT_HP = {
            "POPULATION_SIZE": 6,
            # "ALGO": "Rainbow DQN",  # Algorithm
            "ALGO": "DQN",  # Algorithm
            "DOUBLE": True,
            # Swap image channels dimension from last to first [H, W, C] -> [C, H, W]
            "BATCH_SIZE": 256,  # Batch size
            "LR": 1e-4,  # Learning rate
            "GAMMA": 0.99,  # Discount factor
            "MEMORY_SIZE": 100000,  # Max memory buffer size
            "LEARN_STEP": 1,  # Learning frequency
            "N_STEP": 1,  # Step number to calculate td error
            "PER": False,  # Use prioritized experience replay buffer
            "ALPHA": 0.6,  # Prioritized replay buffer parameter
            "TAU": 0.01,  # For soft update of target parameters
            "BETA": 0.4,  # Importance sampling coefficient
            "PRIOR_EPS": 0.000001,  # Minimum priority for sampling
            "NUM_ATOMS": 51,  # Unit number of support
            "V_MIN": 0.0,  # Minimum value of support
            "V_MAX": 200.0,  # Maximum value of support
            "WANDB": False,  # Use Weights and Biases tracking
        }

        # Define the connect four environment
        env = connect_four_v3.env()
        env.reset()

        # Configure the algo input arguments
        state_dim = [
            env.observation_space(agent)["observation"].shape for agent in env.agents
        ]
        one_hot = False
        action_dim = [env.action_space(agent).n for agent in env.agents]
        INIT_HP["DISCRETE_ACTIONS"] = True
        INIT_HP["MAX_ACTION"] = None
        INIT_HP["MIN_ACTION"] = None

        # Warp the environment in the curriculum learning wrapper
        env = CurriculumEnv(env, LESSON)

        # Pre-process dimensions for PyTorch layers
        # We only need to worry about the state dim of a single agent
        # We flatten the 6x7x2 observation as input to the agent"s neural network
        state_dim = np.moveaxis(np.zeros(state_dim[0]), [-1], [-3]).shape
        action_dim = action_dim[0]

        # Create a population ready for evolutionary hyper-parameter optimisation
        pop = initialPopulation(
            INIT_HP["ALGO"],
            state_dim,
            action_dim,
            one_hot,
            NET_CONFIG,
            INIT_HP,
            population_size=INIT_HP["POPULATION_SIZE"],
            device=device,
        )

        # Configure the replay buffer
        field_names = ["state", "action", "reward", "next_state", "done"]
        memory = ReplayBuffer(
            action_dim=action_dim,  # Number of agent actions
            memory_size=INIT_HP["MEMORY_SIZE"],  # Max replay buffer size
            field_names=field_names,  # Field names to store in memory
            device=device,
        )

        # Instantiate a tournament selection object (used for HPO)
        tournament = TournamentSelection(
            tournament_size=2,  # Tournament selection size
            elitism=True,  # Elitism in tournament selection
            population_size=INIT_HP["POPULATION_SIZE"],  # Population size
            evo_step=1,
        )  # Evaluate using last N fitness scores

        # Instantiate a mutations object (used for HPO)
        mutations = Mutations(
            algo=INIT_HP["ALGO"],
            no_mutation=0.2,  # Probability of no mutation
            architecture=0,  # Probability of architecture mutation
            new_layer_prob=0.2,  # Probability of new layer mutation
            parameters=0.2,  # Probability of parameter mutation
            activation=0,  # Probability of activation function mutation
            rl_hp=0.2,  # Probability of RL hyperparameter mutation
            rl_hp_selection=[
                "lr",
                "learn_step",
                "batch_size",
            ],  # RL hyperparams selected for mutation
            mutation_sd=0.1,  # Mutation strength
            # Define search space for each hyperparameter
            min_lr=0.0001,
            max_lr=0.01,
            min_learn_step=1,
            max_learn_step=120,
            min_batch_size=8,
            max_batch_size=64,
            arch=NET_CONFIG["arch"],  # MLP or CNN
            rand_seed=1,
            device=device,
        )

        # Define training loop parameters
        episodes_per_epoch = 10

        # ! NOTE: Uncomment the max_episodes line below to change the number of training episodes. ! #
        # It is deliberately set low to allow testing to ensure this tutorial is sound.
        max_episodes = 10
        # max_episodes = LESSON["max_train_episodes"]  # Total episodes

        max_steps = 500  # Maximum steps to take in each episode
        evo_epochs = 20  # Evolution frequency
        evo_loop = 50  # Number of evaluation episodes
        elite = pop[0]  # Assign a placeholder "elite" agent
        epsilon = 1.0  # Starting epsilon value
        eps_end = 0.1  # Final epsilon value
        eps_decay = 0.9998  # Epsilon decays
        opp_update_counter = 0
        wb = INIT_HP["WANDB"]

        if LESSON["pretrained_path"] is not None:
            for agent in pop:
                # Load pretrained checkpoint
                agent.loadCheckpoint(LESSON["pretrained_path"])
                # Reinit optimizer for new task
                agent.lr = INIT_HP["LR"]
                agent.optimizer = torch.optim.Adam(
                    agent.actor.parameters(), lr=agent.lr
                )

        if LESSON["opponent"] == "self":
            # Create initial pool of opponents
            opponent_pool = deque(maxlen=LESSON["opponent_pool_size"])
            for _ in range(LESSON["opponent_pool_size"]):
                opp = copy.deepcopy(pop[0])
                opp.actor.load_state_dict(pop[0].actor.state_dict())
                opp.actor.eval()
                opponent_pool.append(opp)

        # Perform buffer and agent warmups if desired
        if LESSON["buffer_warm_up"]:
            warm_up_opponent = Opponent(env, difficulty=LESSON["warm_up_opponent"])
            memory = env.fill_replay_buffer(
                memory, warm_up_opponent
            )  # Fill replay buffer with transitions
            if LESSON["agent_warm_up"] > 0:
                print("Warming up agents ...")
                agent = pop[0]
                # Train on randomly collected samples
                for epoch in trange(LESSON["agent_warm_up"]):
                    experiences = memory.sample(agent.batch_size)
                    agent.learn(experiences)
                pop = [agent.clone() for _ in pop]
                elite = agent
                print("Agent population warmed up.")

        if max_episodes > 0:
            if wb:
                wandb.init(
                    # set the wandb project where this run will be logged
                    project="AgileRL",
                    name="{}-EvoHPO-{}-{}Opposition-CNN-{}".format(
                        "connect_four_v3",
                        INIT_HP["ALGO"],
                        LESSON["opponent"],
                        datetime.now().strftime("%m%d%Y%H%M%S"),
                    ),
                    # track hyperparameters and run metadata
                    config={
                        "algo": "Evo HPO Rainbow DQN",
                        "env": "connect_four_v3",
                        "INIT_HP": INIT_HP,
                        "lesson": LESSON,
                    },
                )

        total_steps = 0
        total_episodes = 0
        pbar = trange(int(max_episodes / episodes_per_epoch))

        # Training loop
        for idx_epi in pbar:
            turns_per_episode = []
            train_actions_hist = [0] * action_dim
            for agent in pop:  # Loop through population
                for episode in range(episodes_per_epoch):
                    env.reset()  # Reset environment at start of episode
                    observation, env_reward, done, truncation, _ = env.last()

                    (
                        p1_state,
                        p1_state_flipped,
                        p1_action,
                        p1_next_state,
                        p1_next_state_flipped,
                    ) = (None, None, None, None, None)

                    if LESSON["opponent"] == "self":
                        # Randomly choose opponent from opponent pool if using self-play
                        opponent = random.choice(opponent_pool)
                    else:
                        # Create opponent of desired difficulty
                        opponent = Opponent(env, difficulty=LESSON["opponent"])

                    # Randomly decide whether agent will go first or second
                    if random.random() > 0.5:
                        opponent_first = False
                    else:
                        opponent_first = True

                    score = 0
                    turns = 0  # Number of turns counter

                    for idx_step in range(max_steps):
                        # Player 0"s turn
                        p0_action_mask = observation["action_mask"]
                        p0_state = np.moveaxis(observation["observation"], [-1], [-3])
                        p0_state_flipped = np.expand_dims(np.flip(p0_state, 2), 0)
                        p0_state = np.expand_dims(p0_state, 0)

                        if opponent_first:
                            if LESSON["opponent"] == "self":
                                p0_action = opponent.getAction(
                                    p0_state, 0, p0_action_mask
                                )[0]
                            elif LESSON["opponent"] == "random":
                                p0_action = opponent.getAction(
                                    p0_action_mask, p1_action, LESSON["block_vert_coef"]
                                )
                            else:
                                p0_action = opponent.getAction(player=0)
                        else:
                            p0_action = agent.getAction(
                                p0_state, epsilon, p0_action_mask
                            )[
                                0
                            ]  # Get next action from agent
                            train_actions_hist[p0_action] += 1

                        env.step(p0_action)  # Act in environment
                        observation, env_reward, done, truncation, _ = env.last()
                        p0_next_state = np.moveaxis(
                            observation["observation"], [-1], [-3]
                        )
                        p0_next_state_flipped = np.expand_dims(
                            np.flip(p0_next_state, 2), 0
                        )
                        p0_next_state = np.expand_dims(p0_next_state, 0)

                        if not opponent_first:
                            score += env_reward
                        turns += 1

                        # Check if game is over (Player 0 win)
                        if done or truncation:
                            reward = env.reward(done=True, player=0)
                            memory.save2memoryVectEnvs(
                                np.concatenate(
                                    (
                                        p0_state,
                                        p1_state,
                                        p0_state_flipped,
                                        p1_state_flipped,
                                    )
                                ),
                                [p0_action, p1_action, 6 - p0_action, 6 - p1_action],
                                [
                                    reward,
                                    LESSON["rewards"]["lose"],
                                    reward,
                                    LESSON["rewards"]["lose"],
                                ],
                                np.concatenate(
                                    (
                                        p0_next_state,
                                        p1_next_state,
                                        p0_next_state_flipped,
                                        p1_next_state_flipped,
                                    )
                                ),
                                [done, done, done, done],
                            )
                        else:  # Play continues
                            if p1_state is not None:
                                reward = env.reward(done=False, player=1)
                                memory.save2memoryVectEnvs(
                                    np.concatenate((p1_state, p1_state_flipped)),
                                    [p1_action, 6 - p1_action],
                                    [reward, reward],
                                    np.concatenate(
                                        (p1_next_state, p1_next_state_flipped)
                                    ),
                                    [done, done],
                                )

                            # Player 1"s turn
                            p1_action_mask = observation["action_mask"]
                            p1_state = np.moveaxis(
                                observation["observation"], [-1], [-3]
                            )
                            # Swap pieces so that the agent always sees the board from the same perspective
                            p1_state[[0, 1], :, :] = p1_state[[0, 1], :, :]
                            p1_state_flipped = np.expand_dims(np.flip(p1_state, 2), 0)
                            p1_state = np.expand_dims(p1_state, 0)

                            if not opponent_first:
                                if LESSON["opponent"] == "self":
                                    p1_action = opponent.getAction(
                                        p1_state, 0, p1_action_mask
                                    )[0]
                                elif LESSON["opponent"] == "random":
                                    p1_action = opponent.getAction(
                                        p1_action_mask,
                                        p0_action,
                                        LESSON["block_vert_coef"],
                                    )
                                else:
                                    p1_action = opponent.getAction(player=1)
                            else:
                                p1_action = agent.getAction(
                                    p1_state, epsilon, p1_action_mask
                                )[
                                    0
                                ]  # Get next action from agent
                                train_actions_hist[p1_action] += 1

                            env.step(p1_action)  # Act in environment
                            observation, env_reward, done, truncation, _ = env.last()
                            p1_next_state = np.moveaxis(
                                observation["observation"], [-1], [-3]
                            )
                            p1_next_state[[0, 1], :, :] = p1_next_state[[0, 1], :, :]
                            p1_next_state_flipped = np.expand_dims(
                                np.flip(p1_next_state, 2), 0
                            )
                            p1_next_state = np.expand_dims(p1_next_state, 0)

                            if opponent_first:
                                score += env_reward
                            turns += 1

                            # Check if game is over (Player 1 win)
                            if done or truncation:
                                reward = env.reward(done=True, player=1)
                                memory.save2memoryVectEnvs(
                                    np.concatenate(
                                        (
                                            p0_state,
                                            p1_state,
                                            p0_state_flipped,
                                            p1_state_flipped,
                                        )
                                    ),
                                    [
                                        p0_action,
                                        p1_action,
                                        6 - p0_action,
                                        6 - p1_action,
                                    ],
                                    [
                                        LESSON["rewards"]["lose"],
                                        reward,
                                        LESSON["rewards"]["lose"],
                                        reward,
                                    ],
                                    np.concatenate(
                                        (
                                            p0_next_state,
                                            p1_next_state,
                                            p0_next_state_flipped,
                                            p1_next_state_flipped,
                                        )
                                    ),
                                    [done, done, done, done],
                                )

                            else:  # Play continues
                                reward = env.reward(done=False, player=0)
                                memory.save2memoryVectEnvs(
                                    np.concatenate((p0_state, p0_state_flipped)),
                                    [p0_action, 6 - p0_action],
                                    [reward, reward],
                                    np.concatenate(
                                        (p0_next_state, p0_next_state_flipped)
                                    ),
                                    [done, done],
                                )

                        # Learn according to learning frequency
                        if (memory.counter % agent.learn_step == 0) and (
                            len(memory) >= agent.batch_size
                        ):
                            # Sample replay buffer
                            # Learn according to agent"s RL algorithm
                            experiences = memory.sample(agent.batch_size)
                            agent.learn(experiences)

                        # Stop episode if any agents have terminated
                        if done or truncation:
                            break

                    total_steps += idx_step + 1
                    total_episodes += 1
                    turns_per_episode.append(turns)
                    # Save the total episode reward
                    agent.scores.append(score)

                    if LESSON["opponent"] == "self":
                        if (total_episodes % LESSON["opponent_upgrade"] == 0) and (
                            (idx_epi + 1) > evo_epochs
                        ):
                            elite_opp, _, _ = tournament._elitism(pop)
                            elite_opp.actor.eval()
                            opponent_pool.append(elite_opp)
                            opp_update_counter += 1

                # Update epsilon for exploration
                epsilon = max(eps_end, epsilon * eps_decay)

            mean_turns = np.mean(turns_per_episode)

            # Now evolve population if necessary
            if (idx_epi + 1) % evo_epochs == 0:
                # Evaluate population vs random actions
                fitnesses = []
                win_rates = []
                eval_actions_hist = [0] * action_dim  # Eval actions histogram
                eval_turns = 0  # Eval turns counter
                for agent in pop:
                    with torch.no_grad():
                        rewards = []
                        for i in range(evo_loop):
                            env.reset()  # Reset environment at start of episode
                            observation, reward, done, truncation, _ = env.last()

                            player = -1  # Tracker for which player"s turn it is

                            # Create opponent of desired difficulty
                            opponent = Opponent(env, difficulty=LESSON["eval_opponent"])

                            # Randomly decide whether agent will go first or second
                            if random.random() > 0.5:
                                opponent_first = False
                            else:
                                opponent_first = True

                            score = 0

                            for idx_step in range(max_steps):
                                action_mask = observation["action_mask"]
                                if player < 0:
                                    if opponent_first:
                                        if LESSON["eval_opponent"] == "random":
                                            action = opponent.getAction(action_mask)
                                        else:
                                            action = opponent.getAction(player=0)
                                    else:
                                        state = np.moveaxis(
                                            observation["observation"], [-1], [-3]
                                        )
                                        state = np.expand_dims(state, 0)
                                        action = agent.getAction(state, 0, action_mask)[
                                            0
                                        ]  # Get next action from agent
                                        eval_actions_hist[action] += 1
                                if player > 0:
                                    if not opponent_first:
                                        if LESSON["eval_opponent"] == "random":
                                            action = opponent.getAction(action_mask)
                                        else:
                                            action = opponent.getAction(player=1)
                                    else:
                                        state = np.moveaxis(
                                            observation["observation"], [-1], [-3]
                                        )
                                        state[[0, 1], :, :] = state[[0, 1], :, :]
                                        state = np.expand_dims(state, 0)
                                        action = agent.getAction(state, 0, action_mask)[
                                            0
                                        ]  # Get next action from agent
                                        eval_actions_hist[action] += 1

                                env.step(action)  # Act in environment
                                observation, reward, done, truncation, _ = env.last()

                                if (player > 0 and opponent_first) or (
                                    player < 0 and not opponent_first
                                ):
                                    score += reward

                                eval_turns += 1

                                if done or truncation:
                                    break

                                player *= -1

                            rewards.append(score)
                    mean_fit = np.mean(rewards)
                    agent.fitness.append(mean_fit)
                    fitnesses.append(mean_fit)

                eval_turns = eval_turns / len(pop) / evo_loop

                pbar.set_postfix_str(
                    f"    Train Mean Score: {np.mean(agent.scores[-episodes_per_epoch:])}   Train Mean Turns: {mean_turns}   Eval Mean Fitness: {np.mean(fitnesses)}   Eval Best Fitness: {np.max(fitnesses)}   Eval Mean Turns: {eval_turns}   Total Steps: {total_steps}"
                )
                pbar.update(0)

                # Format action histograms for visualisation
                train_actions_hist = [
                    freq / sum(train_actions_hist) for freq in train_actions_hist
                ]
                eval_actions_hist = [
                    freq / sum(eval_actions_hist) for freq in eval_actions_hist
                ]
                train_actions_dict = {
                    f"train/action_{index}": action
                    for index, action in enumerate(train_actions_hist)
                }
                eval_actions_dict = {
                    f"eval/action_{index}": action
                    for index, action in enumerate(eval_actions_hist)
                }

                if wb:
                    wandb_dict = {
                        "global_step": total_steps,
                        "train/mean_score": np.mean(agent.scores[-episodes_per_epoch:]),
                        "train/mean_turns_per_game": mean_turns,
                        "train/epsilon": epsilon,
                        "train/opponent_updates": opp_update_counter,
                        "eval/mean_fitness": np.mean(fitnesses),
                        "eval/best_fitness": np.max(fitnesses),
                        "eval/mean_turns_per_game": eval_turns,
                    }
                    wandb_dict.update(train_actions_dict)
                    wandb_dict.update(eval_actions_dict)
                    wandb.log(wandb_dict)

                # Tournament selection and population mutation
                elite, pop = tournament.select(pop)
                pop = mutations.mutation(pop)

        if max_episodes > 0:
            if wb:
                wandb.finish()

        # Save the trained agent
        save_path = LESSON["save_path"]
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        elite.saveCheckpoint(save_path)
        print(f"Elite agent saved to '{save_path}'.")