Saliency Generation with Deep Reinforcement Learning

Introduction

This notebook shows usage of the xaitk-saliency API to gain insight into the behavior of a trained deep reinforcement learning agent in an Atari 2600 environment.

This example is based on this paper [1] and the corresponding code in this GitHub repository. The authors use the Asynchronous Advantage Actor-Critic (A3C) algorithm with an LSTM-CNN policy network to train agents for automated gameplay in several Atari 2600 games. They also implement a method for generating saliency maps using image perturbation. Here, we try to reproduce their results using the xaitk-saliency API, focusing on Breakout: the pretrained agent comes from the authors' Breakout-v0 models, and the game is run in Gym's ALE/Breakout-v5 environment.


References

  1. Greydanus, Samuel, et al. “Visualizing and Understanding Atari Agents.” International Conference on Machine Learning. PMLR, 2018.


To run this notebook in Colab, use the link below:

Open In Colab

Set Up Environment

Note for Colab users: after setting up the environment, you may need to “Restart Runtime” in order to resolve package version conflicts (see the README for more info).

import sys  # noqa

!{sys.executable} -m pip install -qU pip
!{sys.executable} -m pip install -q xaitk-saliency
!{sys.executable} -m pip install -q "torch>=1.9.0"
!{sys.executable} -m pip install -q "torchvision>=0.10.0"
!{sys.executable} -m pip install -q "gym[atari,accept-rom-license]==0.26.2"
# Install the libGL-less version of OpenCV.
!{sys.executable} -m pip uninstall -qy "opencv-python" "opencv-python-headless"  # make sure they're both gone.
!{sys.executable} -m pip install -q "opencv-python-headless"

# Force reinstallation to resolve incompatibilities between dependencies
!{sys.executable} -m pip install -q --force-reinstall --no-cache-dir \
    "numpy<2.0" "packaging<24.0" \
    opencv-python-headless scikit-image

Create Atari Environment

Here we create the Breakout environment for our agent using Gym.

Our agent has 4 different actions to choose from:
    1. Do nothing
    2. Fire
    3. Move right
    4. Move left

import gym

env_name = "ALE/Breakout-v5"
env = gym.make(env_name)
# Set seed for reproducibility
env.seed(1)

action_space = env.unwrapped.get_action_meanings()
print(f"Action space: {action_space}")
Action space: ['NOOP', 'FIRE', 'RIGHT', 'LEFT']

Define Policy Network

This policy network implementation is taken directly from the authors’ original implementation. It consists of four convolutional layers, an LSTM cell, and two separate fully-connected layers for the value (critic) and policy (actor) predictions.

import glob
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as functional


class NNPolicy(torch.nn.Module):
    """an actor-critic neural network."""

    def __init__(self, channels: int, num_actions: int) -> None:
        """Initialize model."""
        super().__init__()
        self.conv1 = nn.Conv2d(channels, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.lstm = nn.LSTMCell(32 * 5 * 5, 256)
        self.critic_linear, self.actor_linear = nn.Linear(256, 1), nn.Linear(256, num_actions)

    def forward(
        self,
        inputs: tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run one step of the network: return critic value, actor logits, and the new LSTM state."""
        inputs, (hx, cx) = inputs
        x = functional.elu(self.conv1(inputs))
        x = functional.elu(self.conv2(x))
        x = functional.elu(self.conv3(x))
        x = functional.elu(self.conv4(x))
        x = x.view(-1, 32 * 5 * 5)
        hx, cx = self.lstm(x, (hx, cx))
        return self.critic_linear(hx), self.actor_linear(hx), (hx, cx)

    def try_load(self, save_dir: str, checkpoint: str = "*.tar") -> int:
        """Load the checkpoint in save_dir with the highest step number; return that step (0 if none)."""
        paths = glob.glob(save_dir + checkpoint)
        step = 0
        if len(paths) > 0:
            # Checkpoint names look like "<name>.<step>.tar"; pick the largest step.
            ckpts = [int(s.split(".")[-2]) for s in paths]
            ix = np.argmax(ckpts)
            step = ckpts[ix]
            self.load_state_dict(torch.load(paths[ix]))
        if step == 0:
            print("\tno saved models")
        else:
            print(f"\tloaded model: {paths[ix]}")
        return step
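The `32 * 5 * 5` input size of the LSTM cell follows from the standard convolution output-size formula, out = floor((n + 2p - k) / s) + 1. A minimal sketch of that arithmetic for the four 3x3, stride-2, padding-1 layers applied to an 80x80 frame (the helper names `conv_out` and `trace_sizes` are ours, for illustration):

```python
def conv_out(n: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of a 2D convolution along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_sizes(n: int, num_layers: int = 4) -> list[int]:
    """Spatial size before and after each of the stacked conv layers, starting from n."""
    sizes = [n]
    for _ in range(num_layers):
        sizes.append(conv_out(sizes[-1]))
    return sizes


sizes = trace_sizes(80)
print(sizes)                 # [80, 40, 20, 10, 5]
print(32 * sizes[-1] ** 2)   # 800 features (32 channels x 5 x 5) fed to the LSTM cell
```

Each stride-2 layer halves the spatial size, which is why the flattened feature vector has exactly 800 entries.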

Download Pretrained Model

The authors provide pretrained agents here for the different environments they used (both “strong” and “overfit” variants; see the original paper for details). We will use the “strong” Breakout agent.

For convenience, the Breakout checkpoint is hosted separately and downloaded below. The authors’ link above provides a zip file of all pretrained agents, including both the “strong” and “overfit” variants.

import os
import urllib.request

root_dir = "data/pretrained_agents/breakout-v0/"
os.makedirs(root_dir, exist_ok=True)
model_checkpoint = os.path.join(root_dir, "strong.40.tar")

_ = urllib.request.urlretrieve(
    "https://data.kitware.com/api/v1/item/62325a1d4acac99f426f21e9/download",
    model_checkpoint,
)

Load Pretrained Model

Here we load the downloaded model into an instance of the policy function class.

model = NNPolicy(channels=1, num_actions=env.action_space.n)
model.try_load(root_dir, checkpoint="*.tar")
# Set seed for reproducibility
_ = torch.manual_seed(1)
	loaded model: data/pretrained_agents/breakout-v0/strong.40.tar
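`try_load` relies on a filename convention of the form `<name>.<step>.tar` (such as `strong.40.tar`) and loads the checkpoint with the largest step. A standalone sketch of just that selection logic; the function name and the file paths are hypothetical, for illustration only:

```python
from typing import Optional


def pick_latest_checkpoint(paths: list[str]) -> tuple[Optional[str], int]:
    """Return (path, step) for the checkpoint with the largest step, or (None, 0) if none."""
    if not paths:
        return None, 0
    # "<dir>/<name>.<step>.tar" -> the second-to-last dot-separated field is the step
    steps = [int(p.split(".")[-2]) for p in paths]
    best = max(range(len(paths)), key=lambda i: steps[i])
    return paths[best], steps[best]


# Hypothetical file names, for illustration only
candidates = ["ckpts/strong.8.tar", "ckpts/strong.40.tar", "ckpts/strong.12.tar"]
print(pick_latest_checkpoint(candidates))  # ('ckpts/strong.40.tar', 40)
```

As with `np.argmax` in `try_load`, ties go to the first matching path.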

Define Rollout Function

This function runs the pretrained agent’s policy in our Breakout environment for up to a given number of frames. At each step, the current preprocessed game frame and the LSTM state are run through the policy network, and the agent takes the action with the highest predicted probability. After each step, the function stores the raw game frame, the LSTM hidden and cell states, the actor logits, the critic value, and the action probabilities.
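The action choice in `rollout` is greedy: the logits are turned into probabilities with a softmax, and the index of the largest probability is taken (`prob.max(1)[1]`). A NumPy-only sketch of that step, mapping the index to the Breakout action meanings printed earlier; the logit values are made up for illustration:

```python
import numpy as np

ACTION_MEANINGS = ["NOOP", "FIRE", "RIGHT", "LEFT"]  # as printed by env.unwrapped.get_action_meanings()


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def greedy_action(logits: np.ndarray) -> int:
    """Index of the most probable action."""
    return int(np.argmax(softmax(logits)))


logits = np.array([0.1, -1.2, 2.3, 0.4])  # hypothetical actor output for one frame
probs = softmax(logits)
action = greedy_action(logits)
print(ACTION_MEANINGS[action], probs.round(3))
```

Because softmax is monotonic, taking the argmax of the raw logits would pick the same action; the probabilities only matter if actions are sampled instead of chosen greedily.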

import cv2


def prepro(img: np.ndarray) -> np.ndarray:
    """Standard frame preprocessing from Greydanus et al., 2018: crop, grayscale, resize to 80x80, scale to [0, 1]."""
    return cv2.resize(src=img[35:195].mean(2), dsize=(80, 80)).astype(np.float32) / 255.0


def rollout(model: torch.nn.Module, env: gym.core.Env, max_ep_len: int) -> dict[str, Any]:
    """Play the game and return the game history."""
    history = {"ins": [], "logits": [], "values": [], "outs": [], "hx": [], "cx": []}

    state = torch.Tensor(prepro(env.reset()[0]))  # get first state
    episode_length, epr, _eploss, done = 0, 0, 0, False  # bookkeeping
    hx, cx = torch.zeros(1, 256), torch.zeros(1, 256)

    # iterate through each frame in episode
    while not done and episode_length < max_ep_len:
        episode_length += 1

        # get game state
        model_inp = (state.view(1, 1, 80, 80), (hx, cx))

        # run through model
        value, logit, (hx, cx) = model(model_inp)
        hx, cx = hx.data, cx.data

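        # action probabilities
        prob = functional.softmax(logit, dim=-1)

        # best action (greedy: index of the highest probability)
        action = prob.max(1)[1].data

        # take best action; Gym 0.26 step() returns (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, _info = env.step(action.numpy()[0])
        done = terminated or truncated

        # update state and cumulative reward
        state = torch.Tensor(prepro(obs))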
        epr += reward

        # save state
        history["ins"].append(obs)  # game state after taking action
        history["hx"].append(hx.squeeze(0).data.numpy())  # LSTM hx output
        history["cx"].append(cx.squeeze(0).data.numpy())  # LSTM cx output
        history["logits"].append(logit.data.numpy()[0])  # actor output
        history["values"].append(value.data.numpy()[0])  # critic output
        history["outs"].append(prob.data.numpy()[0])  # action probabilities
        print(f"\tstep # {episode_length}, reward {epr:.0f}", end="\r")

    return history
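`prepro` crops the 210x160 RGB frame to rows 35 through 194, averages the color channels to grayscale, resizes the 160x160 result to 80x80, and scales it to [0, 1]. A NumPy-only sketch of the same pipeline; note that it swaps `cv2.resize`'s default bilinear interpolation for 2x2 block averaging (an assumption made to avoid the OpenCV dependency), so pixel values will differ slightly from `prepro`:

```python
import numpy as np


def prepro_numpy(img: np.ndarray) -> np.ndarray:
    """Crop, grayscale, downsample 160x160 -> 80x80 by 2x2 averaging, scale to [0, 1]."""
    gray = img[35:195].mean(axis=2)  # (160, 160) grayscale, values in [0, 255]
    # 2x2 block average: split each axis into 80 blocks of 2 pixels and average them
    small = gray.reshape(80, 2, 80, 2).mean(axis=(1, 3))
    return small.astype(np.float32) / 255.0


# A synthetic Atari-sized frame (210 x 160 x 3, uint8) stands in for an env observation
rng = np.random.default_rng(0)
frame = rng.integers(0, 256, size=(210, 160, 3), dtype=np.uint8)
out = prepro_numpy(frame)
print(out.shape, out.dtype, out.min() >= 0.0, out.max() <= 1.0)
```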

Play Breakout

Our pretrained agent will now play the game for up to 3,000 frames. We then write frames 638 through 1238 to a short video clip so we can see the agent in action. The agent shows a “tunneling” behavior, also noted by the original authors in the paper referenced above: it breaks a hole through the wall of blocks so the ball can get behind them.

import warnings

import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

# Play game
print("Rolling out policy...")
history = rollout(model, env, max_ep_len=int(3e3))

# Create video from frames
print("\nCreating video...")
fps = 30
w = history["ins"][0].shape[1]
h = history["ins"][0].shape[0]
out = cv2.VideoWriter("data/breakout.mp4", cv2.VideoWriter_fourcc(*"vp09"), fps, (w, h))

start_frame = 638
end_frame = 1238
for i in range(start_frame, end_frame + 1):
    frame = cv2.cvtColor(history["ins"][i], cv2.COLOR_RGB2BGR)  # Convert to BGR, the channel order OpenCV expects
    out.write(frame)
out.release()

print("Done")
Rolling out policy...
	step # 2179, reward 258
Creating video...
Done

Here we show a single frame from the game.

f = plt.figure(figsize=[8, 12])
frame_ix = 1070
plt.imshow(history["ins"][frame_ix])
for a in f.axes:
    a.get_xaxis().set_visible(False)
    a.get_yaxis().set_visible(False)
plt.show()
[Image: Breakout game frame 1070]

Here we show a recorded video from the game. This video is saved as data/breakout.mp4 relative to the current directory.

Note that this file will not exist, and the video will not be visible, until the notebook is run.

%%HTML
<div align="center">
<video width="50%" controls>
      <source src="data/breakout.mp4" type="video/mp4">
</video></div>

Defining the Application

Our saliency application has four parameters:
    start_frame - the first frame to perform saliency generation for
    end_frame - the last frame to perform saliency generation for
    perturber - the PerturbImage implementation to use
    saliency_gen - the GenerateClassifierConfidenceSaliency implementation to use

The application creates saliency maps for both the actor (policy function) and the critic (value function) for each frame from start_frame to end_frame, using the image perturber and saliency generator passed to it. Salient parts of each frame are highlighted in blue for the actor and in red for the critic, and a video is created from these highlighted game frames. This gives a single view per frame of where the actor and critic are “looking” when making their predictions.
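The overlay in `visualize_saliency` uses `cv2.addWeighted`, which computes saturate(frame * 0.4 + overlay * 0.6 + 0) per pixel, with the actor map in the blue channel and the critic map in the red channel of an RGB overlay. A NumPy-only sketch of that blend for maps already resized to 160x160 (the name `blend_overlay` is hypothetical, and its rounding may differ from OpenCV's by one intensity level):

```python
import numpy as np


def blend_overlay(frame: np.ndarray, actor_map: np.ndarray, critic_map: np.ndarray,
                  alpha: float = 0.4, beta: float = 0.6) -> np.ndarray:
    """Blend an RGB uint8 frame with actor (blue) and critic (red) saliency maps in [0, 1]."""
    overlay = np.zeros_like(frame)
    # Same placement as visualize_saliency: rows 35:195 of the 210x160 frame
    overlay[35:195, :, 2] = (255.0 * actor_map).astype(np.uint8)   # blue channel (RGB order)
    overlay[35:195, :, 0] = (255.0 * critic_map).astype(np.uint8)  # red channel (RGB order)
    mixed = alpha * frame.astype(np.float64) + beta * overlay.astype(np.float64)
    # Round and saturate to the uint8 range, as addWeighted does for uint8 inputs
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)


frame = np.full((210, 160, 3), 100, dtype=np.uint8)
actor = np.zeros((160, 160))   # no actor saliency anywhere
critic = np.ones((160, 160))   # maximal critic saliency everywhere
viz = blend_overlay(frame, actor, critic)
print(viz[100, 0], viz[0, 0])  # inside the band the red channel dominates; outside it is just the dimmed frame
```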

To speed up this process, the application spreads the per-frame work across a pool of threads, one worker per CPU core. This can use a substantial amount of system memory, so be cautious on machines with limited RAM.
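The threading pattern in `app` is a fixed-size pool that writes each frame's result into a pre-allocated list at the frame's offset, so results stay in frame order no matter which thread finishes first. A minimal sketch of that pattern with a stand-in workload; `run_indexed` and `fake_saliency` are hypothetical names, and `os.cpu_count()` is used in place of `multiprocessing.cpu_count()`:

```python
import concurrent.futures
import os
from typing import Any, Callable


def run_indexed(start: int, end: int, work: Callable[[int], Any]) -> list:
    """Apply work(idx) for idx in [start, end] on a thread pool; store results by offset."""
    results: list = [None] * (end - start + 1)

    def task(idx: int) -> None:
        # Each task writes only its own slot, so the list needs no lock
        results[idx - start] = work(idx)

    with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as pool:
        # Consuming the iterator re-raises any exception raised inside a task
        list(pool.map(task, range(start, end + 1)))
    return results


def fake_saliency(idx: int) -> int:
    """Stand-in for gen_sal_maps: any per-frame computation."""
    return idx * idx


print(run_indexed(638, 642, fake_saliency))
```

Exiting the `with` block waits for all tasks, which is equivalent to the explicit `tpe.shutdown(wait=True)` call in `app`.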

import concurrent.futures
import multiprocessing
import threading

from scipy.ndimage import gaussian_filter

from xaitk_saliency import GenerateClassifierConfidenceSaliency, PerturbImage
from xaitk_saliency.utils.masking import occlude_image_batch

actor_sal_maps = []
critic_sal_maps = []

print_lock = threading.Lock()  # Lock to control thread printing


def print_helper(to_print: str, verbose: int) -> None:
    """Print to_print (serialized by a lock) if verbose is set."""
    if verbose:
        with print_lock:
            print(to_print)


def gen_sal_maps(
    img_idx: int,
    sal_idx: int,
    perturber: PerturbImage,
    saliency_gen: GenerateClassifierConfidenceSaliency,
    verbose: int = 0,
) -> None:
    """Generates actor and critic saliency maps on a per-frame basis given an image
    perturber and a saliency map generation implementation.

    img_idx: index of input video frame
    sal_idx: index into corresponding global saliency map arrays
    perturber: PerturbImage implementation
    saliency_gen: GenerateClassifierConfidenceSaliency implementation
    verbose: whether or not to print debug statements (default: 0)
    """
    global actor_sal_maps
    global critic_sal_maps

    # Score reference frame
    print_helper(f"[{img_idx}]Scoring frame", verbose)

    ref_img = history["ins"][img_idx]
    ref_img_proc = prepro(ref_img)
    hx = torch.tensor(history["hx"][img_idx - 1]).view(1, -1)
    cx = torch.tensor(history["cx"][img_idx - 1]).view(1, -1)

    ref_img_tensor = torch.tensor(ref_img_proc.reshape(1, 1, 80, 80))
    model_inp = (ref_img_tensor, (hx, cx))
    ref_value, ref_logit, _ = model(model_inp)
    ref_value = ref_value.detach().numpy()[0]
    ref_logit = ref_logit.detach().numpy()[0]

    # Get image perturbations
    print_helper(f"[{img_idx}]Perturbing image", verbose)

    pert_masks = perturber(ref_img_proc)
    blurred_img = gaussian_filter(ref_img_proc, sigma=3)
    pert_imgs = occlude_image_batch(ref_img_proc, pert_masks, fill=blurred_img)

    pert_values = []
    pert_logits = []

    # Score perturbed frames
    print_helper(f"[{img_idx}]Scoring perturbations", verbose)

    for pert_img in pert_imgs:
        pert_img_tensor = torch.tensor(pert_img.reshape(1, 1, 80, 80))
        model_inp = (pert_img_tensor, (hx, cx))
        pert_value, pert_logit, _ = model(model_inp)
        pert_values.append(pert_value.detach().numpy()[0])
        pert_logits.append(pert_logit.detach().numpy()[0])

    # Generate actor saliency maps
    print_helper(f"[{img_idx}]Generating actor saliency maps", verbose)

    actor_sal_maps[sal_idx] = saliency_gen(ref_logit, np.asarray(pert_logits), pert_masks)

    # Generate critic saliency maps
    print_helper(f"[{img_idx}]Generating critic saliency maps", verbose)

    critic_sal_maps[sal_idx] = saliency_gen(ref_value, np.asarray(pert_values), pert_masks)


def visualize_saliency(
    frame: np.ndarray,
    actor_sal_map: np.ndarray,
    critic_sal_map: np.ndarray,
) -> np.ndarray:
    """Creates visualization of saliency by alpha-blending the provided frame and
    saliency maps. The actor and critic saliency maps are scaled to [0,255] to
    match the range of the frame and then blended as the blue and red channels
    respectively.
    """
    # Scale saliency maps to [0, 255] in respective color channel
    sal_colors = np.zeros((210, 160, 3), dtype=np.uint8)
    sal_colors[35:195, :, 2] = (255.0 * cv2.resize(src=actor_sal_map, dsize=(160, 160))).astype("uint8")
    sal_colors[35:195, :, 0] = (255.0 * cv2.resize(src=critic_sal_map, dsize=(160, 160))).astype("uint8")

    return cv2.addWeighted(frame, 0.4, sal_colors, 0.6, 0)


def app(
    start_frame: int,
    end_frame: int,
    perturber: PerturbImage,
    saliency_gen: GenerateClassifierConfidenceSaliency,
) -> None:
    """App for running Breakout and performing saliency generation."""
    global actor_sal_maps
    global critic_sal_maps

    # Initialize map arrays to correct size
    actor_sal_maps = [None] * (end_frame - start_frame + 1)
    critic_sal_maps = [None] * (end_frame - start_frame + 1)

    # Use threaded computation
    def call_wrapper(idx: int) -> None:
        """Wrapper for gen_sal_maps."""
        return gen_sal_maps(idx, idx - start_frame, perturber, saliency_gen)

    with concurrent.futures.ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as tpe:
        list(tpe.map(call_wrapper, range(start_frame, end_frame + 1)))

    # Write out videos
    print("Creating video")
    fps = 30
    vid_writer = cv2.VideoWriter("data/breakout_saliency.mp4", cv2.VideoWriter_fourcc(*"vp09"), fps, (w, h))

    for img_idx in range(start_frame, end_frame + 1):
        sal_idx = img_idx - start_frame

        frame = visualize_saliency(history["ins"][img_idx], actor_sal_maps[sal_idx], critic_sal_maps[sal_idx])

        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # Convert to BGR to meet cv2 convention

        vid_writer.write(frame)

    vid_writer.release()

    print("Done.")

Perturbation and Saliency Implementations

For this example we use the SlidingRadial perturbation implementation with a radius of (1, 1), a stride of (5, 5), and mask blurring with sigma (5, 5). The fill image is a Gaussian-blurred copy (sigma=3) of the original frame, so perturbed regions are replaced by their blurred counterparts. We try to match the parameters of the original implementation.
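For reference, Greydanus et al. [1] perturb a frame I around pixel (i, j) as I' = I * (1 - M) + A(I, sigma_A) * M (elementwise), where M is a Gaussian mask centered at (i, j) and A(I, sigma_A) is a Gaussian-blurred copy of I, which corresponds to `blurred_img` above. A NumPy-only sketch of one such perturbation; the helper names are ours, the mask is an analytic Gaussian with peak value 1, and the blurred image is supplied by the caller rather than computed here:

```python
import numpy as np


def gaussian_mask(shape: tuple[int, int], center: tuple[int, int], sigma: float) -> np.ndarray:
    """2D Gaussian with peak value 1 at `center` (row, col)."""
    rows, cols = np.ogrid[:shape[0], :shape[1]]
    d2 = (rows - center[0]) ** 2 + (cols - center[1]) ** 2
    return np.exp(-d2 / (2.0 * sigma ** 2))


def perturb(img: np.ndarray, blurred: np.ndarray, center: tuple[int, int], sigma: float = 5.0) -> np.ndarray:
    """I' = I * (1 - M) + A(I) * M, with M a Gaussian mask centered at `center`."""
    m = gaussian_mask(img.shape, center, sigma)
    return img * (1.0 - m) + blurred * m


img = np.zeros((80, 80), dtype=np.float32)
blurred = np.full((80, 80), 0.5, dtype=np.float32)  # stand-in for a blurred frame
out = perturb(img, blurred, center=(40, 40))
print(out[40, 40], out[0, 0])  # fully replaced at the center, essentially untouched far away
```

If centers are placed every 5 pixels, an 80x80 frame yields a 16x16 grid, or 256 perturbed frames to score per input frame.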

The SquaredDifferenceScoring heatmap generation implementation is used here, following the authors’ approach in the paper. This will give an absolute sense of saliency for both the actor and critic models without distinguishing positive vs. negative.
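The paper scores a perturbation by half the squared Euclidean distance between the model outputs on the original and perturbed frames, S = 0.5 * ||L(I) - L(I')||^2, where L is either the actor logits or the critic value. A NumPy sketch of that score, plus one plausible way to turn per-perturbation scores into a heatmap (weighting each perturbed region by its score and scaling to [0, 1]); the aggregation step is our assumption and is not necessarily how `SquaredDifferenceScoring` works internally:

```python
import numpy as np


def squared_difference_scores(ref: np.ndarray, perturbed: np.ndarray) -> np.ndarray:
    """S_k = 0.5 * ||ref - perturbed_k||^2 for each row k of `perturbed` (N x C)."""
    diff = perturbed - ref[np.newaxis, :]
    return 0.5 * np.sum(diff ** 2, axis=1)


def scores_to_heatmap(scores: np.ndarray, regions: np.ndarray) -> np.ndarray:
    """Hypothetical aggregation: regions (N x H x W, 1 = perturbed) weighted by score, scaled to [0, 1]."""
    heat = np.tensordot(scores, regions, axes=1)  # contracts the N axis -> (H, W)
    peak = heat.max()
    return heat / peak if peak > 0 else heat


ref = np.array([0.0, 0.0])                                   # e.g. reference actor logits
perturbed = np.array([[1.0, 1.0], [2.0, 0.0], [0.0, 0.0]])  # logits for three perturbed frames
scores = squared_difference_scores(ref, perturbed)           # [1.0, 2.0, 0.0]

# Three toy "regions" on a 2x2 image, each covering a single pixel
regions = np.zeros((3, 2, 2))
regions[0, 0, 0] = regions[1, 0, 1] = regions[2, 1, 1] = 1.0
heatmap = scores_to_heatmap(scores, regions)                 # [[0.5, 1.0], [0.0, 0.0]]
```

For the critic, L is a one-element vector, so the score reduces to half the squared change in predicted value; since scores are never negative, the map carries no sign, matching the "absolute sense of saliency" described above.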

from xaitk_saliency.impls.gen_classifier_conf_sal.squared_difference_scoring import SquaredDifferenceScoring
from xaitk_saliency.impls.perturb_image.sliding_radial import SlidingRadial

window_perturber = SlidingRadial(radius=(1, 1), stride=(5, 5), sigma=(5, 5))
sal_gen = SquaredDifferenceScoring()

Calling the Application

The same frames used for the video clip above (638 through 1238) are chosen for saliency generation using our application, and the resulting video is displayed below.

In the video, the area above the paddle and around the ball appears to be very important to both the value and policy functions. The value function also appears to be strongly affected by holes in the blocks. This suggests that the agent has learned to create these holes so that it can get the ball behind the blocks and earn a high reward quickly.

app(start_frame=638, end_frame=1238, perturber=window_perturber, saliency_gen=sal_gen)
Creating video
Done.

Here we show saliency (actor in blue, critic in red) overlaid on a single frame from the game.

f = plt.figure(figsize=[8, 12])
frame_ix = 1070

img = history["ins"][frame_ix]

sal_start = 638  # must match the start_frame passed to app() above
viz = visualize_saliency(img, actor_sal_maps[frame_ix - sal_start], critic_sal_maps[frame_ix - sal_start])

plt.imshow(viz)

for a in f.axes:
    a.get_xaxis().set_visible(False)
    a.get_yaxis().set_visible(False)
plt.show()
[Image: saliency overlay on Breakout frame 1070 (actor in blue, critic in red)]

Finally, we show a recorded video from the game with overlaid saliency (actor in blue, critic in red). This video is saved as data/breakout_saliency.mp4 relative to the current directory.

Note that this file will not exist, and the video will not be visible, until the notebook is run.

%%HTML
<div align="center">
<video width="50%" controls>
    <source src="data/breakout_saliency.mp4" type="video/mp4">
</video>
</div>