Your Own Few-Shot Classification Model Ready in 15 Minutes with PyTorch


Intrigued by few-shot learning? Always wanted to use state-of-the-art algorithms for your project, but never knew where to start? We got you.


I have been working on few-shot classification for a while now. And the more I talk about it, the more the people around me seem to feel that it's some kind of dark magic. Even sadder: I noticed that very few actually used it on their projects. I think that's too bad, so I decided to make a tutorial so you'll have no excuse to deprive yourself of the power of few-shot learning methods.

In 15 minutes and just a few lines of code, we are going to implement Prototypical Networks. It's the favorite method of many few-shot learning researchers (~2000 citations in 3 years), because 1) it works well, and 2) it's incredibly easy to grasp and to implement.

If you want to experiment by yourself, you can open our notebook in Colab. But if you only want to see the code and the results, just scroll down!

Getting started

First, let's install the tutorial GitHub repository and import some packages.

!pip install easyfsl

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm

from easyfsl.data_tools import TaskSampler
from easyfsl.utils import plot_images, sliding_average

Now, we need a dataset. I suggest we use Omniglot, a popular MNIST-like benchmark for few-shot classification. It contains 1623 characters from 50 different alphabets. Each character has been written by 20 different people.

Bonus: it's part of the torchvision package, so it's very easy to download and work with.

image_size = 28

train_set = Omniglot(
    root="./data",
    background=True,
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
test_set = Omniglot(
    root="./data",
    background=False,
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([
                int(image_size * 1.15), int(image_size * 1.15)
            ]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)

Two notes here:

  1. To build Omniglot, background=True selects the train set and background=False selects the test set. That's the nomenclature from the original paper; we just have to deal with it.
  2. Omniglot images have 1 channel, but our model will expect 3-channel images, thus the Grayscale transformation.

Discovering Prototypical Networks

The Few-Shot Classification problem

Let's take some time to grasp what few-shot classification is. Simply put, in a few-shot classification task, you have a labeled support set (which kind of acts like a catalog) and a query set. For each image of the query set, we want to predict a label from the labels present in the support set.

A few-shot classification model has to use the information from the support set in order to classify query images. We say few-shot when the support set contains very few images for each label (typically fewer than 10). The figure below shows a 3-way 2-shot classification task. "3-way" means "3 different classes" and "2-shot" means "2 examples per class". We expect a model that has never seen any Saint-Bernard, Pug or Labrador during its training to successfully predict the query labels. The support set is the only information that the model has regarding what a Saint-Bernard, a Pug or a Labrador can be.

A few-shot classification task, with a support set and a query set from the same classes.
A 3-way 2-shot classification task: Saint-Bernard, Pug or Labrador?
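
To make the vocabulary concrete, here is a minimal sketch of how one such task could be drawn from a labeled dataset. The function name `sample_task` and the toy breed list are made up for illustration; this is not the sampler used later in the tutorial.

```python
import random
from collections import defaultdict


def sample_task(labels, n_way, n_shot, n_query, rng):
    """Sample item indices for one n_way-way, n_shot-shot task.

    `labels[i]` is the class of item i. Returns two lists of
    (item_index, class) pairs: the support set and the query set.
    """
    # Group item indices by class
    indices_per_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_class[label].append(index)

    support, query = [], []
    # Pick n_way classes, then n_shot + n_query distinct items in each
    for label in rng.sample(sorted(indices_per_class), n_way):
        picked = rng.sample(indices_per_class[label], n_shot + n_query)
        support += [(i, label) for i in picked[:n_shot]]
        query += [(i, label) for i in picked[n_shot:]]
    return support, query


# Toy dataset (hypothetical): 4 dog breeds, 5 images each
breeds = ["pug", "labrador", "saint-bernard", "beagle"]
toy_labels = [breed for breed in breeds for _ in range(5)]

support, query = sample_task(
    toy_labels, n_way=3, n_shot=2, n_query=1, rng=random.Random(0)
)
print(len(support), len(query))  # 6 support items, 3 query items
```

Support and query items come from the same 3 classes but never share an image, which is exactly the setting described above.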

Prototypical Networks: A Metric Learning algorithm

Most few-shot classification methods are metric-based. They work in two phases: 1) they use a CNN to project both support and query images into a feature space, and 2) they classify query images by comparing them to support images. If, in the feature space, an image is closer to pugs than it is to labradors and Saint-Bernards, we will guess that it's a pug.

From there, we have two challenges:

  1. Find a good feature space. This is what convolutional networks (CNN) are for. A CNN is basically a function that takes an image as input and outputs a representation (or embedding) of this image in a given feature space. The challenge here is to have a CNN that will project images of the same class into representations that are close to each other, even if it has not been trained on objects of this class.
  2. Find a good way to compare the representations in the feature space. This is the job of Prototypical Networks.
Few-Shot Classification with Prototypes
In the feature space, x is closer to c2 than to the other prototypes, so we predict that it belongs to class 2.

From the support set, Prototypical Networks compute a prototype for each class, which is the mean of all embeddings of support images from this class. Then, each query is simply assigned the class of its nearest prototype in the feature space, with respect to Euclidean distance.
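
Here is a minimal sketch of those two steps on hand-made 2-D embeddings, so you can check the numbers in your head. The values are invented for illustration; in the real model the embeddings come out of the CNN.

```python
import torch

# Toy 2-D embeddings (made up): 2 classes, 2 support points each
z_support = torch.tensor([[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]])
support_labels = torch.tensor([0, 0, 1, 1])
z_query = torch.tensor([[1.0, 1.0], [4.0, 3.0]])

# One prototype per class: the mean of that class's support embeddings
z_proto = torch.stack(
    [z_support[support_labels == label].mean(dim=0) for label in range(2)]
)

# Euclidean distance from each query to each prototype, shape (n_query, n_way)
dists = torch.cdist(z_query, z_proto)

# Each query takes the label of its nearest prototype
predictions = dists.argmin(dim=1)

print(z_proto.tolist())      # [[0.0, 1.0], [5.0, 4.0]]
print(predictions.tolist())  # [0, 1]
```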

If you want to learn more about how this works, I explain it in this article about few-shot image classification. But now, let's get to coding.

Let’s get to coding

In the code below (modified from this), we simply define Prototypical Networks as a torch module, with a forward() method. You may notice 2 things.

  1. We initialize PrototypicalNetworks with a backbone. This is the feature extractor we were talking about. Here, we use as backbone a ResNet18 pretrained on ImageNet, with its head chopped off and replaced by a Flatten layer. Then, the output of the backbone, for an input image, will be a 512-dimensional feature vector.
  2. The forward method doesn't only take one input tensor, but 3: in order to predict the labels of query images, we also need support images and labels as inputs of the model.

class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """

        # Extract the features of support and query images
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer the number of classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all support feature vectors with label i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        # The closer the prototype, the higher the score
        scores = -dists
        return scores


# Note: on torchvision >= 0.13, `pretrained=True` is deprecated in favor of
# the `weights` argument
convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()

model = PrototypicalNetworks(convolutional_network).cuda()

Now we have a model! Note that we used a pretrained feature extractor, so our model should already be up and running. Let's see that.

Evaluating a Few-Shot Learning model

Loading few-shot classification tasks with PyTorch

We are going to create a dataloader that will feed few-shot classification tasks to our model. But a regular PyTorch dataloader will feed batches of images, with no consideration for their label or whether they are support or query. We need 2 specific features in our case.

  1. We need images evenly distributed between a given number of classes;
  2. and we need them split between support and query sets.

For the first point, I wrote a custom sampler: it first samples n_way classes from the dataset, then it samples n_shot + n_query images for each class (for a total of n_way * (n_shot + n_query) images in each batch). For the second point, I have a custom collate function to replace the built-in PyTorch collate_fn. This baby feeds each batch as the combination of 5 items:

  1. support images
  2. support labels between 0 and n_way - 1
  3. query images
  4. query labels between 0 and n_way - 1
  5. a mapping of each label in range(n_way) to its true class id in the dataset (it's not used by the model but it's very useful for us to know what the true class is)
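
To give an idea of what such a collate function has to do, here is a simplified pure-Python stand-in. It is not the easyfsl implementation: the name `episodic_collate`, the assumption that items arrive grouped by class, and the use of strings instead of image tensors are all mine.

```python
def episodic_collate(items, n_way, n_shot, n_query):
    """Split a flat list of (image, true_class_id) pairs into one task.

    Assumes `items` is grouped by class, with n_shot + n_query items per
    class, which is the order a task sampler like the one above would yield.
    """
    per_class = n_shot + n_query
    assert len(items) == n_way * per_class

    # Item 5: position in this list = task label, value = true class id
    class_ids = [items[k * per_class][1] for k in range(n_way)]

    support_images, support_labels = [], []
    query_images, query_labels = [], []
    for position, (image, true_class_id) in enumerate(items):
        task_label = class_ids.index(true_class_id)
        # Within each class, the first n_shot items go to the support set
        if position % per_class < n_shot:
            support_images.append(image)
            support_labels.append(task_label)
        else:
            query_images.append(image)
            query_labels.append(task_label)
    return support_images, support_labels, query_images, query_labels, class_ids


# Toy batch: strings stand in for images; 2-way, 1-shot, 2 queries per class
batch = [("a0", 742), ("a1", 742), ("a2", 742), ("b0", 13), ("b1", 13), ("b2", 13)]
s_img, s_lab, q_img, q_lab, class_ids = episodic_collate(
    batch, n_way=2, n_shot=1, n_query=2
)
print(s_img, s_lab)  # ['a0', 'b0'] [0, 1]
print(q_img, q_lab)  # ['a1', 'a2', 'b1', 'b2'] [0, 0, 1, 1]
print(class_ids)     # [742, 13]
```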

You can see that in PyTorch, a DataLoader is basically the combination of a sampler, a dataset and a collate function (and some multiprocessing voodoo): the sampler says which items to fetch, the dataset says how to fetch them, and the collate function says how to present these items together. If you want to dive into these custom objects, they're here.
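
Here is a toy illustration of that division of labor, with a hard-coded list of index batches playing the role of the sampler. The dataset, the batches and the `collate` function are all invented for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dataset: 6 one-dimensional "images" whose value equals their index
dataset = TensorDataset(torch.arange(6, dtype=torch.float32).unsqueeze(1))

# Sampler: any iterable of index lists; each list becomes one batch
batches_of_indices = [[0, 3], [5, 1, 2]]


def collate(items):
    # Receives [dataset[i] for i in indices] and decides the batch layout.
    # Each item is a 1-tuple here because TensorDataset wraps a single tensor.
    return torch.cat([item[0] for item in items])


loader = DataLoader(dataset, batch_sampler=batches_of_indices, collate_fn=collate)
collected = [batch.tolist() for batch in loader]
print(collected)  # [[0.0, 3.0], [5.0, 1.0, 2.0]]
```

Batches can have different sizes and any content you like, which is what lets a task sampler and an episodic collate function turn a regular dataset into a stream of few-shot tasks.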

N_WAY = 5 # Number of classes in a task
N_SHOT = 5 # Number of images per class in the support set
N_QUERY = 10 # Number of images per class in the query set
N_EVALUATION_TASKS = 100

The sampler needs a dataset with a "labels" field. Check the code if you have any doubt!

test_set.labels = [
    instance[1] for instance in test_set._flat_character_images
]

test_sampler = TaskSampler(
    test_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_EVALUATION_TASKS,
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Evaluating the model

A visual example

We created a dataloader that will feed us with 5-way 5-shot tasks (the most common setting in the few-shot literature). Now, as every data scientist should do before launching opaque training scripts, let's take a look at our dataset.

(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))

plot_images(example_support_images, "support images", images_per_row=N_SHOT)
plot_images(example_query_images, "query images", images_per_row=N_QUERY)

This gives us the following plots. For both support and query sets, we have one line for each class.

How does our model perform on this task?

model.eval()
example_scores = model(
    example_support_images.cuda(),
    example_support_labels.cuda(),
    example_query_images.cuda(),
).detach()

_, example_predicted_labels = torch.max(example_scores.data, 1)

print("Ground Truth / Predicted")
for i in range(len(example_query_labels)):
    print(
        f"{test_set._characters[example_class_ids[example_query_labels[i]]]} / {test_set._characters[example_class_ids[example_predicted_labels[i]]]}"
    )

The output doesn't look bad: keep in mind that the model was trained on very different images, and has only seen 5 examples for each class!

Ground Truth / Predicted 
Angelic/character18 / Angelic/character18
Angelic/character18 / Angelic/character18
Angelic/character18 / Atlantean/character11
Angelic/character18 / Angelic/character18
Angelic/character18 / Angelic/character18
Angelic/character18 / Angelic/character18
Angelic/character18 / Angelic/character18
Angelic/character18 / Angelic/character18
Angelic/character18 / Angelic/character18
Angelic/character18 / Angelic/character18
Kannada/character23 / Kannada/character23
Kannada/character23 / Kannada/character23
Kannada/character23 / Kannada/character23
Kannada/character23 / Kannada/character23
Kannada/character23 / Kannada/character23
Kannada/character23 / Kannada/character23
Kannada/character23 / Kannada/character23
Kannada/character23 / Tibetan/character40
Kannada/character23 / Kannada/character23
Kannada/character23 / Kannada/character23
Gurmukhi/character16 / Gurmukhi/character16
Gurmukhi/character16 / Tibetan/character40
Gurmukhi/character16 / Gurmukhi/character16
Gurmukhi/character16 / Atlantean/character11
Gurmukhi/character16 / Tibetan/character40
Gurmukhi/character16 / Gurmukhi/character16
Gurmukhi/character16 / Atlantean/character11
Gurmukhi/character16 / Gurmukhi/character16
Gurmukhi/character16 / Gurmukhi/character16
Gurmukhi/character16 / Gurmukhi/character16
Tibetan/character40 / Tibetan/character40
Tibetan/character40 / Tibetan/character40
Tibetan/character40 / Tibetan/character40
Tibetan/character40 / Kannada/character23
Tibetan/character40 / Atlantean/character11
Tibetan/character40 / Tibetan/character40
Tibetan/character40 / Atlantean/character11
Tibetan/character40 / Tibetan/character40
Tibetan/character40 / Tibetan/character40
Tibetan/character40 / Tibetan/character40
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11
Atlantean/character11 / Atlantean/character11

Evaluating on the test set

Now that we have a first idea, let's see more precisely how good our model is.

def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> tuple[int, int]:
    """
    Returns the number of correct predictions of query labels, and the total
    number of predictions.
    """
    return (
        torch.max(
            model(
                support_images.cuda(),
                support_labels.cuda(),
                query_images.cuda(),
            )
            .detach()
            .data,
            1,
        )[1]
        == query_labels.cuda()
    ).sum().item(), len(query_labels)


def evaluate(data_loader: DataLoader):
    # We'll count everything and compute the ratio at the end
    total_predictions = 0
    correct_predictions = 0

    # eval mode affects the behaviour of some layers (such as batch normalization or dropout)
    # no_grad() tells torch not to keep in memory the whole computational graph (it's more lightweight this way)
    model.eval()
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in tqdm(enumerate(data_loader), total=len(data_loader)):

            correct, total = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )

            total_predictions += total
            correct_predictions += correct

    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )


evaluate(test_loader)

With absolutely zero training on Omniglot images, and only 5 examples per class, we achieve around 86% accuracy! Isn't this a great start?

100%|██████████| 100/100 [00:06<00:00, 16.41it/s]
Model tested on 100 tasks. Accuracy: 86.32%

Now that you know how to make Prototypical Networks work, you can see what happens if you tweak them a little bit (change the backbone, use distances other than Euclidean...) or if you change the problem (more classes in each task, fewer or more examples in the support set, maybe even one example only, but keep in mind that in that case Prototypical Networks are just standard nearest neighbour).
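
As an example of such a tweak, here is one possible way to score queries with cosine similarity instead of negative Euclidean distance. This is a swapped-in variant for experimentation, not part of the original method, and the helper name `cosine_scores` is made up.

```python
import torch
import torch.nn.functional as F


def cosine_scores(z_query: torch.Tensor, z_proto: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between each query and each prototype.

    Returns a (n_query, n_way) tensor; higher means more similar, so it can
    stand in for `-dists` without changing the rest of the pipeline.
    """
    return F.normalize(z_query, dim=1) @ F.normalize(z_proto, dim=1).T


# Toy check: the first query points roughly the same way as prototype 1,
# the second one roughly the same way as prototype 0
z_proto = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
z_query = torch.tensor([[0.1, 5.0], [3.0, 0.2]])

scores = cosine_scores(z_query, z_proto)
print(scores.argmax(dim=1).tolist())  # [1, 0]
```

Note that cosine similarity ignores the norm of the embeddings, so it can rank prototypes differently from Euclidean distance.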

When you're done, you can scroll further down and learn how to meta-train this model, to get even better results.

Training a meta-learning algorithm

Preparing the training data

Let's use the "background" images of Omniglot as training set. Here we prepare a data loader of 40,000 few-shot classification tasks on which we will train our model. The alphabets used in the training set are entirely separated from those used in the testing set. This guarantees that at test time, the model will have to classify characters that were not seen during training.

Note that we don't set a validation set here to keep this notebook concise, but keep in mind that this is not good practice and you should always use validation when training a model for production.

N_TRAINING_EPISODES = 40000
N_VALIDATION_TASKS = 100

train_set.labels = [
    instance[1] for instance in train_set._flat_character_images
]

train_sampler = TaskSampler(
    train_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_TRAINING_EPISODES,
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

Episodic training

We will keep the same model. So our weights will be pre-trained on ImageNet. If you want to start a training from scratch, feel free to set pretrained=False in the definition of the ResNet.

Here we define our loss and our optimizer (cross entropy and Adam, pretty standard), and a fit method. This method takes a classification task as input (support set and query set). It predicts the labels of the query set based on the information from the support set; then it compares the predicted labels to ground truth query labels, and this gives us a loss value. Then it uses this loss to update the parameters of the model. This is a meta-training loop.

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    optimizer.zero_grad()
    classification_scores = model(
        support_images.cuda(), support_labels.cuda(), query_images.cuda()
    )

    loss = criterion(classification_scores, query_labels.cuda())
    loss.backward()
    optimizer.step()

    return loss.item()
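
If you wonder what cross entropy means when the scores are negative distances, here is a small sketch with made-up numbers. It shows that the loss is the average negative log-probability of the true class, where the probabilities come from a softmax over the negative distances.

```python
import torch
from torch import nn

# Made-up distances from 2 queries to 3 prototypes
dists = torch.tensor([[0.5, 2.0, 3.0], [4.0, 1.0, 0.2]])
scores = -dists                      # closer prototype -> higher score
query_labels = torch.tensor([0, 2])  # true class of each query

# CrossEntropyLoss applies a softmax over the scores, then averages the
# negative log-probability assigned to each query's true class
loss = nn.CrossEntropyLoss()(scores, query_labels)

# Same quantity, written out by hand
log_probs = torch.log_softmax(scores, dim=1)
manual_loss = -log_probs[torch.arange(2), query_labels].mean()

print(torch.isclose(loss, manual_loss).item())  # True
```

Minimizing this loss pulls each query embedding towards the prototype of its own class and pushes it away from the others.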

To train the model, we are just going to iterate over a large number of randomly generated few-shot classification tasks, and let the fit method update our model after each task. This is called episodic training.

log_update_frequency = 10

all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, (
        support_images,
        support_labels,
        query_images,
        query_labels,
        _,
    ) in tqdm_train:
        loss_value = fit(
            support_images,
            support_labels,
            query_images,
            query_labels,
        )
        all_loss.append(loss_value)

        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(
                loss=sliding_average(all_loss, log_update_frequency)
            )

Note that this took me 20 minutes on an RTX 2080.
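
The progress bar in the loop above displays a sliding average of the loss rather than the raw value of the last episode, which is very noisy. I haven't reproduced the easyfsl helper here; the function below is a hypothetical stand-in that shows the idea.

```python
def sliding_average_sketch(values, window):
    """Mean of the last `window` values (or of all of them if fewer exist)."""
    if not values:
        raise ValueError("Cannot average an empty list")
    recent = values[-window:]
    return sum(recent) / len(recent)


losses = [4.0, 3.0, 2.0, 1.0]
print(sliding_average_sketch(losses, window=2))   # 1.5
print(sliding_average_sketch(losses, window=10))  # 2.5
```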

Measuring our improvement

Now let's see if our model got better!

evaluate(test_loader)

Our model’s accuracy improved by about 12 percentage points!

100%|██████████| 100/100 [00:06<00:00, 16.08it/s]
Model tested on 100 tasks. Accuracy: 98.38%

It's not surprising that the model performs better after being further trained on Omniglot images than it was with its ImageNet-based parameters. However, we have to keep in mind that the classes on which we just evaluated our model were still not seen during training, so 98% (a 12-point improvement over the model trained on ImageNet) seems like a decent performance.
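
Since "improvement" can be measured in several ways, here is the arithmetic on the two accuracies reported above. The variable names are mine; only the two accuracy figures come from the runs.

```python
before, after = 86.32, 98.38  # accuracies reported above, in percent

absolute_gain = after - before                          # percentage points
relative_gain = 100 * absolute_gain / before            # percent of the starting accuracy
error_reduction = 100 * absolute_gain / (100 - before)  # share of errors removed

print(f"{absolute_gain:.2f} points")
print(f"{relative_gain:.1f}% relative")
print(f"{error_reduction:.1f}% fewer errors")
```

The gain in percentage points is the most common way to report this kind of result, and it is the one used in this tutorial.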


What have we learned?

  • What a Prototypical Network is and how to implement one in 15 lines of code.
  • What Omniglot is and how to evaluate few-shot models on it
  • How to use custom PyTorch objects to sample batches in the shape of few-shot classification tasks.
  • How to use meta-learning to train a few-shot algorithm.

What's next?

  • Take the notebook in your own hands, tweak everything that there is to tweak. It's the best way to understand what does what.
  • Implement other few-shot learning methods, such as Matching Networks, Relation Networks, MAML...
  • Try other ways of training. Episodic training is not the only way to train a model to generalize to new classes!
  • Experiment on other, more challenging few-shot learning benchmarks, such as CUB or Meta-Dataset.
  • If you liked this tutorial, feel free to ⭐ give us a star on GitHub
  • Contribute! The companion repository of this tutorial is meant to become a boilerplate, a source of useful code that newcomers can use to start their few-shot learning projects.

Thank you for your time and good luck on your few-shot learning journey!
