August 28, 2019

Interpretability of Deep Learning Models with Tensorflow 2.0

An introduction to interpretability methods to ease neural network training monitoring.
Grad CAM method on ‘deer’ ImageNet class (Original photo by Asa Rodger on Unsplash)

This article dives into the tf-explain library. It provides explanations on interpretability methods, such as Grad CAM, with Tensorflow 2.0.

Have you ever waited long hours for a training run that turned out to be unsuccessful? I have, and not just once. So I began looking into tools that could help me anticipate problems and debug the neural networks that come out of training.

This article is a summary of what I found:

  • Visual logging to check pipeline integrity

  • Interpretability methods to gain insight into the inner workings of neural networks

All the methods presented in this article are implemented in tf-explain, a library built for interpretability with TensorFlow 2.0. Check out the introduction!

Let’s now look at those implementations.

Look at what goes into the network

The first cause of failed training runs is simply not giving the network what you want it to have. Visualizing the inputs is crucial, and for that, I use the VisualLogging library. Dropping sample images as logs at key moments of your pipeline (loading, resizing, data augmentation) can help you catch any undesired effects.

A Visual Logging setup

Logs are directly dropped into an HTML page, which you can scroll and inspect.

Example of a possible Data Augmentation Pipeline

Monitor Convolutional Kernels

Let’s now dive into interpretability methods. Kernels are the key components of convolutional networks. What matters is what each kernel learns and how its output contributes to the final classification.

Outputs of ResNet50’s activation_1 layer for a sample cat

Intermediate Layers Visualization

A first step is to simply visualize what comes out of the activation layers. Does the output still look relevant? Or does it look like random noise? By examining how the image is transformed as it passes through the network, you can check that the network focuses on the right regions.

Subgraph of VGG16 to observe activations

Extracting the output of an intermediate layer with TensorFlow is fairly easy. You start from your whole model and extract a subpart of the graph. The code below shows how to obtain the outputs of the activation_1 layer from a ResNet50 model.
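As a minimal sketch of the idea (not the original gist), the snippet below builds a sub-model that shares the input of the full model and stops at a named layer. To stay self-contained it uses a tiny stand-in network, which is an assumption of this example. With a real ResNet50 you would load tf.keras.applications.ResNet50 instead and check model.summary() for the exact layer name, since names such as activation_1 can differ between versions.

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in network (hypothetical), used here instead of ResNet50
# so the example runs without downloading pretrained weights.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, padding="same", name="conv_1")(inputs)
x = tf.keras.layers.Activation("relu", name="activation_1")(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# Sub-model: same input as the full model, output taken at the target layer.
activation_model = tf.keras.Model(
    inputs=model.input,
    outputs=model.get_layer("activation_1").output,
)

# Random images stand in for real samples.
images = np.random.rand(2, 32, 32, 3).astype("float32")
activations = activation_model.predict(images, verbose=0)
print(activations.shape)  # (batch, height, width, filters)
```

Each slice activations[i, :, :, k] is the response of filter k to image i and can be displayed as a grayscale image.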

Intermediate Layers Visualization

Kernel Inspection

Seeing what is coming out of a layer is great, but what if we could understand what makes a kernel activate?

Visualization of VGG Filters

The idea behind this visualization is to generate an input to the network that maximizes the response of a given kernel filter. To do so, we create a sub-model which stops at the target layer. The loss function we seek to maximize is the mean of this activation layer’s output. The starting point is some random noise used as input. Then, we backpropagate the gradients to perform gradient ascent on the noise. Iteratively, we build an input that makes the filter’s response stronger and stronger.

Gradient Ascent on Input to Visualize Kernels

With TensorFlow, implementing this method takes only four steps:

  • perform the initial subgraph creation (same as before)

  • use the GradientTape object to capture the gradients on the input

  • get the gradients with tape.gradient

  • perform the gradient ascent with assign_add on the initial variable.
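The four steps above can be sketched as follows. This is a hedged illustration rather than the author’s code: the small model, the layer name target_conv, the filter index, the step size and the gradient normalization are all assumptions made for this example. The target layer is kept linear so that the toy gradient never vanishes.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Tiny stand-in model (hypothetical); "target_conv" is an illustrative name.
inputs = tf.keras.Input(shape=(32, 32, 3))
features = tf.keras.layers.Conv2D(4, 3, name="target_conv")(inputs)
pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
outputs = tf.keras.layers.Dense(2)(pooled)
model = tf.keras.Model(inputs, outputs)

# Step 1: sub-model that stops at the target layer.
submodel = tf.keras.Model(model.input, model.get_layer("target_conv").output)

filter_index = 0   # filter to visualize (assumption)
step_size = 1.0    # gradient ascent step (assumption)

# Start from random noise stored in a variable so it can be updated in place.
image = tf.Variable(tf.random.uniform((1, 32, 32, 3)))

losses = []
for _ in range(20):
    # Step 2: record operations on the input with GradientTape.
    with tf.GradientTape() as tape:
        activation = submodel(tf.convert_to_tensor(image))
        loss = tf.reduce_mean(activation[..., filter_index])
    # Step 3: gradients of the loss with respect to the input.
    grads = tape.gradient(loss, image)
    # Normalizing keeps the step size stable (a common trick, not from the article).
    grads = grads / (tf.norm(grads) + 1e-8)
    # Step 4: gradient ascent on the input.
    image.assign_add(step_size * grads)
    losses.append(float(loss))

print(losses[0], losses[-1])
```

The recorded loss should grow over the iterations, which means the input is drifting towards whatever pattern the chosen filter responds to.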

The example here is minimal to keep the code simple. Many techniques exist to improve those kernel visualizations (regularization, upscaling). If you are interested in this subject, I strongly encourage you to read this blog post on Feature Visualization.

What Makes the Neural Network’s Decision

Visualizing the kernels and the intermediate layers can help detect weird behaviors. However, it does not give any insight into why a neural network makes a specific decision. The next few methods are ways to visualize which parts of the input influence the output value.

Occlusion Sensitivity

The idea behind Occlusion Sensitivity is to hide parts of the image and see the impact on the neural network’s decision for a specific class.

In the animation below, we run a blue patch over a cat image and extract the confidence at each step. When the patch goes over the cat, confidence drops, so we can identify the region behind the patch as hot. When the patch does not occlude the cat, the confidence stays level or may even go up. This happens because the patch may hide elements that degrade the performance.

Heatmap generation process for class Cat

The generated heatmap answers the question “Does this part of the image help to improve confidence?”. Here, the resolution is pretty poor. You can improve it by varying the patch size to capture influences from micro to macro zones of the image.

The process to generate the heatmap breaks down into simple steps:

  • Create a batch of images with patches applied

  • Run predictions

  • Save confidence for the target class

  • Regroup confidences in the resulting map

Occlusion Sensitivity Implementation

Note: This code translates the algorithm logic, but should be optimized by first generating all the patched images and then running the predictions in batches.
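The batched variant suggested in the note can be sketched as follows. It is a minimal illustration under stated assumptions: occlusion_sensitivity and predict_fn are hypothetical names, predict_fn stands for any callable that maps a batch of images to class probabilities (a Keras model.predict would fit), and the toy classifier at the bottom only exists to make the example runnable.

```python
import numpy as np

def occlusion_sensitivity(image, predict_fn, class_index, patch_size, fill_value=0.0):
    """Return a grid of confidences, one per occluded patch position."""
    height, width = image.shape[:2]
    rows = range(0, height, patch_size)
    cols = range(0, width, patch_size)

    # 1. Create a batch of images with patches applied.
    patched = []
    for top in rows:
        for left in cols:
            occluded = image.copy()
            occluded[top:top + patch_size, left:left + patch_size] = fill_value
            patched.append(occluded)

    # 2. Run predictions on the whole batch at once.
    probabilities = predict_fn(np.stack(patched))

    # 3. Keep the confidence for the target class.
    confidences = probabilities[:, class_index]

    # 4. Regroup confidences in the resulting map.
    return confidences.reshape(len(rows), len(cols))

# Toy classifier (hypothetical): class 1 confidence is the mean intensity
# of the top-left quadrant, so only that region should matter.
def toy_predict(batch):
    score = batch[:, :4, :4, :].mean(axis=(1, 2, 3))
    return np.stack([1.0 - score, score], axis=1)

image = np.ones((8, 8, 1), dtype="float32")
confidence_map = occlusion_sensitivity(image, toy_predict, class_index=1, patch_size=4)
print(confidence_map)
```

Low values in the map mark the patches whose occlusion hurts the target class the most. Taking 1 - confidence_map turns it into a heatmap where important regions are hot.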

Occlusion Sensitivity Map

Class Activation Maps

Another family of methods uses the gradients directly to determine the relevant zones. Class Activation Maps (CAM), and more specifically the Grad-CAM method (which is implemented below), check the importance of the output filters (see section Intermediate Layers Visualization above) towards the final decision.

Given the output maps of those convolutional filters (of shape WxHxN), we compute the gradients of the class score with respect to them (same shape WxHxN). To establish the importance of each filter in the decision, we average its gradients over the spatial dimensions, which yields one weight per filter (shape 1x1xN), and multiply each map by its corresponding weight. Then, we sum all those weighted maps into a final heatmap. If an activation map lit up during the forward pass and its gradients are large, the activated region has a large impact on the decision.


The implementation follows this idea and does not differ much from the previous algorithms presented.
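Here is a hedged sketch of that computation, without the guided step discussed next. The tiny model, the layer name last_conv and the class index are assumptions made for illustration. The final ReLU comes from the Grad-CAM paper rather than from the description above.

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in classifier (hypothetical); "last_conv" is an illustrative name.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, activation="relu", name="last_conv")(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# Model returning both the target feature maps and the predictions.
grad_model = tf.keras.Model(
    model.input, [model.get_layer("last_conv").output, model.output]
)

images = tf.convert_to_tensor(np.random.rand(1, 32, 32, 3).astype("float32"))
class_index = 2  # class to explain (assumption)

with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(images)
    class_score = predictions[:, class_index]

# Gradients of the class score with respect to the feature maps: (1, H, W, N).
grads = tape.gradient(class_score, conv_outputs)

# One weight per filter: average of the gradients over H and W -> (1, 1, 1, N).
weights = tf.reduce_mean(grads, axis=(1, 2), keepdims=True)

# Weighted sum of the feature maps -> (1, H, W), then keep positive influence only.
heatmap = tf.nn.relu(tf.reduce_sum(weights * conv_outputs, axis=-1))
print(heatmap.shape)
```

To overlay the result on the input, the heatmap is typically resized to the image resolution and normalized to the [0, 1] range.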

However, the Grad CAM paper uses a subtlety called Guided Backpropagation. It consists in eliminating elements that act negatively towards the decision, by zeroing out the negative gradients or gradients associated with a negative value of the filter.

(left) Grad CAM, (right) Grad CAM + Guided Backpropagation

TensorFlow offers tf.RegisterGradient to define a new gradient function, which, combined with gradient_override_map, lets us switch the backward behavior of our ReLU layers.

Unfortunately, if you try to run this operation, TensorFlow informs you that it is not supported in version 2.0: “tf.GradientTape.gradients() does not support graph control flow operations like tf.cond or tf.while at this time.”

As we use the guide only at inference time, we can apply it after the gradient computation rather than during it. This only requires a small change after the GradientTape call.
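One way to write that change, assuming conv_outputs and grads are the feature maps and their gradients obtained from the GradientTape call, is to mask the gradients with two boolean conditions. The constant tensors below are made-up values used only to show the effect. This sketch follows the description above and does not reproduce the linked gist.

```python
import tensorflow as tf

# Made-up feature map values and gradients, standing in for the tensors
# returned by the model and by tape.gradient.
conv_outputs = tf.constant([[1.0, -2.0], [3.0, 0.5]])
grads = tf.constant([[0.4, 0.3], [-0.2, 0.1]])

# Keep a gradient only where both the filter output and the gradient are positive.
guided_grads = (
    tf.cast(conv_outputs > 0, "float32")
    * tf.cast(grads > 0, "float32")
    * grads
)
print(guided_grads.numpy())
```

The guided gradients then replace the raw gradients when computing the per-filter weights of the heatmap.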

Full gist for the Grad CAM implementation (with guided backpropagation) is available here.

Grad CAM applied on a cat video

Those methods are all implemented in tf-explain, which you can use on your trained models or in Keras callbacks. I’ll write more articles with additional methods as we integrate them into tf-explain. Follow me on Twitter to get notified!

Thanks to Emna Kamoun, Bastien Ponchon, Juliep, Jeremy Joslove, and Quentin Febvre. 
