Building a Reverse Image Search AI using PyTorch

Implementing the Deep Visual-Semantic Embedding model in PyTorch, trained to identify visual objects using both labelled image data and semantic information gleaned from unannotated text.

This project is inspired by the fast.ai lecture on DeViSE.

Introduction

DeViSE (Deep Visual-Semantic Embedding Model) is a paper from Google that introduces the idea of training an image recognition model with the word vectors of the labels as target values, instead of the labels themselves. This allows the model to learn the semantic meaning of the labels.

Jeremy Howard trained a model on the entire ImageNet dataset, which I cannot afford to do on my machine. So I first tried the Tiny ImageNet dataset, which Stanford made for its CS231n course. It didn't give me great results because many of the labels were compound words. Using just a single word from each label gave an okay result, and averaging the word vectors of compound-word labels resulted in a poor model. After a while, I found the Caltech256 dataset, which has 256 classes and only a few compound-word labels, so I decided to use it instead. I also decided not to train a model completely from scratch. Pretrained ImageNet models are already available, so why not just fine-tune one?

If you're only interested in the results, skip to the Results section.

First, import all the libraries that we will need for implementing DeViSE.

Declare all paths we will need later.

I used fasttext for the word vectors, but you can use anything that gives you trained word vectors (word2vec, GloVe, ...). I already had the Common Crawl version of the English fasttext vectors, so I used those instead of downloading the 9 GB Wikipedia version. You can choose whichever you want.

Load them into memory.

Let's view some word vectors.

Read all the words available in the fasttext file.

Next, I stored the lowercase versions of the 1 million most frequent words, along with their word vectors, for easy access. This is optional; you can access them directly from fasttext anyway.
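The loading and lookup steps above might look roughly like the sketch below. It assumes the `fasttext` Python package, whose `load_model`, `get_words` and `get_word_vector` calls appear in the comments. It also assumes that `get_words()` returns words from most to least frequent, so the first casing seen for a word is its most frequent one. `build_lowercase_lookup` is a hypothetical helper name.

```python
import numpy as np


def build_lowercase_lookup(words, get_vector, max_words=1_000_000):
    """Map lowercased words to vectors, keeping the most frequent casing.

    Assumes `words` is ordered from most to least frequent.
    """
    lookup = {}
    for word in words[:max_words]:
        key = word.lower()
        # The first time we meet a lowercased key is its most frequent variant
        if key not in lookup:
            lookup[key] = np.asarray(get_vector(word), dtype=np.float32)
    return lookup


# Real usage (needs the fasttext package and the downloaded .bin file):
#   import fasttext
#   ft = fasttext.load_model("data/cc.en.300.bin")
#   print(ft.get_word_vector("king")[:10])   # peek at a word vector
#   word_lookup = build_lowercase_lookup(ft.get_words(), ft.get_word_vector)
```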

After manually renaming some classes (like ice-cream -> icecream), there were 231 labels for which word vectors existed.
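Here is a sketch of the label filtering. It assumes Caltech256 folder names of the form `001.ak47`: a numeric prefix, a dot, then the class name. The `renames` dictionary stands in for the manual renaming. Its contents and both helper names are hypothetical.

```python
def caltech_label(folder_name, renames):
    """Turn a folder name like '001.ak47' into a word-vector lookup key."""
    # Drop the numeric prefix ('001.') and lowercase the class name
    name = folder_name.split(".", 1)[1].lower()
    # Apply a manual rename if one exists, e.g. 'ice-cream' -> 'icecream'
    return renames.get(name, name)


def usable_labels(folder_names, word_lookup, renames):
    """Keep only classes whose (possibly renamed) label has a word vector."""
    labels = {}
    for folder in folder_names:
        label = caltech_label(folder, renames)
        if label in word_lookup:
            labels[folder] = label
    return labels
```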

I held out all the images belonging to 10 particular classes as the test set, so the model never sees those classes during training.

Shuffle and split the remaining dataset into a training set (75%) and a validation set (25%).
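Both splits could be done with something like the following, assuming the data is a list of `(image_path, label)` pairs. The helper name, the fixed seed and the use of Python's `random` module are assumptions rather than details from the post.

```python
import random


def split_dataset(items, test_classes, valid_frac=0.25, seed=42):
    """Split (path, label) pairs into train, valid and test lists.

    Every image of a class in `test_classes` goes to the test set, so those
    classes are never seen in training. The rest is shuffled, then split.
    """
    test = [it for it in items if it[1] in test_classes]
    rest = [it for it in items if it[1] not in test_classes]
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    rng.shuffle(rest)
    n_valid = int(len(rest) * valid_frac)
    valid, train = rest[:n_valid], rest[n_valid:]
    return train, valid, test
```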

The dataset is straightforward: we already have the list of image paths and an array of the corresponding word vectors.
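A minimal `Dataset` along these lines would do. It assumes `paths` and `vectors` are aligned, so the i-th vector belongs to the i-th image's label. The class name is hypothetical. The conversion to RGB is a precaution in case some images have a single channel.

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class WordVecImageDataset(Dataset):
    """Pairs each image with the word vector of its label."""

    def __init__(self, paths, vectors, transform=None):
        assert len(paths) == len(vectors), "one vector per image"
        self.paths = list(paths)
        self.vectors = torch.as_tensor(np.asarray(vectors), dtype=torch.float32)
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Force 3 channels in case an image is grayscale
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, self.vectors[idx]
```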

Since I'm fine-tuning a model pretrained on ImageNet, the images need to be normalised with transforms, the same way as the images that model was trained on.

To look at the images in the dataset, we need to denormalise the data.
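Denormalising just reverses `Normalize` by computing `x * std + mean` per channel. The sketch below assumes the ImageNet channel statistics; the function name is hypothetical.

```python
import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo Normalize for a (C, H, W) or (N, C, H, W) tensor."""
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W
    mean = mean.view(-1, 1, 1)
    std = std.view(-1, 1, 1)
    # Clamp to [0, 1] so the result can be displayed directly
    return (x * std + mean).clamp(0, 1)
```

To display a single image with matplotlib, `plt.imshow(denormalize(img).permute(1, 2, 0))` moves the channels last, which is the layout `imshow` expects.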

The dataset is now ready to be loaded. To dispatch it in batches, we use the PyTorch DataLoader. I used batch_size = 64.
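Creating the loaders might look like this. `make_loaders` is a hypothetical helper. Shuffling only the training set and defaulting to `num_workers=0` are assumptions; more workers usually speed up image decoding.

```python
from torch.utils.data import DataLoader


def make_loaders(train_ds, valid_ds, batch_size=64, num_workers=0):
    """Wrap the datasets in DataLoaders; only the training set is shuffled."""
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers)
    valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers)
    return train_dl, valid_dl
```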

Let’s view a batch of images.

Now that we have the dataset set up, let’s create the model!

The loss function I used is 1 - F.cosine_similarity(out, targets).mean(), as Jeremy suggests in his lecture. This makes more sense: we want to maximize the similarity between the predicted and target vectors, rather than force the model to output exactly the same word vector for every image of a class, which would be no different from simply classifying the images.
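As a standalone function, the loss could be written as below; the name `cosine_loss` is hypothetical. `dim=1` compares each predicted vector with its own target across the embedding dimensions. Because cosine similarity ignores vector length, the loss depends only on direction: it is 0 for perfectly aligned vectors and 2 for opposite ones.

```python
import torch
import torch.nn.functional as F


def cosine_loss(out, targets):
    """1 - mean cosine similarity between predicted and target word vectors."""
    # out, targets: (batch, emb_dim); similarity is computed row by row
    return 1 - F.cosine_similarity(out, targets, dim=1).mean()
```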

You can use almost any model with almost any head. It shouldn't matter much, as long as the model outputs 300 values (the dimension of the fasttext word vectors).

Move the model and the data to the GPU.

Define the fit function.
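Here is a minimal `fit` sketch. It moves the model and each batch to the GPU when one is available, trains with the cosine loss, and reports the validation loss after each epoch. The signature and the returned history list are assumptions, not the post's original function.

```python
import torch
import torch.nn.functional as F


def fit(model, train_dl, valid_dl, epochs, opt, device=None):
    """Train with the cosine loss; return the mean validation loss per epoch."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # moves parameters in place, so `opt` stays valid
    history = []
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            loss = 1 - F.cosine_similarity(model(xb), yb, dim=1).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()

        # Validation pass: no gradients, weighted by batch size
        model.eval()
        total, n = 0.0, 0
        with torch.no_grad():
            for xb, yb in valid_dl:
                xb, yb = xb.to(device), yb.to(device)
                batch_loss = 1 - F.cosine_similarity(model(xb), yb, dim=1).mean()
                total += batch_loss.item() * len(xb)
                n += len(xb)
        history.append(total / n)
        print(f"epoch {epoch + 1}: valid loss {total / n:.4f}")
    return history
```

Training then amounts to something like `fit(model, train_dl, valid_dl, epochs=5, opt=torch.optim.Adam(params, lr=1e-3))`. The epoch count, optimizer and learning rate are placeholders.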

I froze all the pretrained layers except the newly added head. Since the objects in Caltech256 more or less all exist in ImageNet, a model trained on ImageNet can adapt easily without having to learn new features.
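Freezing can be done by turning off `requires_grad` for everything outside the head. This sketch assumes the head is the submodule named `fc`, as in torchvision ResNets; the function name is hypothetical. The returned parameters are the ones to hand to the optimizer.

```python
def freeze_all_but_head(model, head_name="fc"):
    """Freeze every parameter except those of the head submodule."""
    for name, param in model.named_parameters():
        # Parameter names look like 'layer1.0.conv1.weight' or 'fc.0.weight'
        param.requires_grad = name.startswith(head_name + ".")
    return [p for p in model.parameters() if p.requires_grad]
```

Note that `requires_grad = False` does not stop BatchNorm layers from updating their running statistics while the model is in training mode.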

Finally, we’re ready to train the model.

Save the model for later use.
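The PyTorch documentation recommends saving the `state_dict` rather than the whole pickled model. Here is a sketch with hypothetical helper names. To load the weights, you build the same architecture first.

```python
import torch


def save_weights(model, path):
    # Save only the parameters and buffers, not the pickled module
    torch.save(model.state_dict(), path)


def load_weights(model, path, device="cpu"):
    # `model` must already have the same architecture as the saved one
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model
```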

Results

First, I used the validation set to see how well the model performs on images it has never seen. To do that, I first obtained the predicted word vectors.
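Collecting the predictions could look like this. It assumes the validation DataLoader is not shuffled, so row i of the result lines up with image i. The function name is hypothetical.

```python
import torch


@torch.no_grad()
def predict_vectors(model, dl, device="cpu"):
    """Run the model over a DataLoader and stack the predicted word vectors."""
    model.eval()
    model.to(device)
    # Targets are ignored; only the images are needed for prediction
    preds = [model(xb.to(device)).cpu() for xb, _ in dl]
    return torch.cat(preds).numpy()
```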

Next, I used the awesome nmslib library to store the predicted word vectors in an angulardist (angular distance) space. For more details, see the nmslib manual: https://github.com/nmslib/nmslib/blob/master/manual/README.md

Once the word vectors are stored and indexed, we can run a k-nearest-neighbour search over the space.
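nmslib builds an approximate index; its calls are shown in the comments, and the manual explains the parameters. To make the idea concrete without that dependency, the function below is a swapped-in brute-force version that computes exact angular distances (the arccosine of the cosine similarity) with NumPy. It illustrates what the index answers and is not the post's code. The function name is hypothetical.

```python
import numpy as np

# With nmslib, the equivalent is roughly:
#   import nmslib
#   index = nmslib.init(method="hnsw", space="angulardist")
#   index.addDataPointBatch(pred_vecs)
#   index.createIndex({"post": 2}, print_progress=True)
#   ids, dists = index.knnQuery(query_vec, k=9)


def angular_knn(index_vecs, query, k=9):
    """Exact k-nearest neighbours of `query` under angular distance."""
    a = index_vecs / np.linalg.norm(index_vecs, axis=1, keepdims=True)
    q = query / np.linalg.norm(query)
    cos = np.clip(a @ q, -1.0, 1.0)  # guard arccos against rounding error
    dists = np.arccos(cos)           # 0 = same direction, pi = opposite
    ids = np.argsort(dists)[:k]
    return ids, dists[ids]
```

To search for a word such as “fireworks”, pass its fasttext vector as `query`. To search by image, pass the image's predicted vector instead.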

Now that we have everything in place we can start searching! To visualise the results, just plot them.

First, let’s search for something that exists in the training set labels.
How about some “fireworks”?

The model performed great! Out of 9 images, it retrieved 8 fireworks images and 1 lightning image. It's easy to see why the model mistook lightning for fireworks: both are bright bursts of light in the sky.

Now, let’s look for something that doesn’t exist in the labels.

There is no label called “sea” in the dataset. Let’s see what the model spits out.

Wow! You can see every image retrieved has the sea in some part of the image.

Let’s search for some food.

As expected, every retrieved image is of food.

Let’s now see if we can find similar images to the one we have. I googled for “elephant” and downloaded this image.

Let’s see if the model can find elephants from the dataset.
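Searching with a downloaded picture means predicting its word vector and then running the same nearest-neighbour search. Below is a sketch of the prediction step. It assumes `transform` is the validation transform, with no random augmentation. The function name and `elephant.jpg` are hypothetical.

```python
import torch


@torch.no_grad()
def image_to_vector(model, img, transform, device="cpu"):
    """Predict the word vector for a single PIL image."""
    model.eval()
    # transform -> (C, H, W); unsqueeze adds the batch dimension
    x = transform(img.convert("RGB")).unsqueeze(0).to(device)
    return model.to(device)(x)[0].cpu().numpy()


# Usage sketch:
#   from PIL import Image
#   vec = image_to_vector(model, Image.open("elephant.jpg"), valid_tfms)
#   then query the index with `vec`
```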

Looks like it can!

If the model can find similar images, it must first recognise what's in the image, so let's get that directly. Unfortunately, the fasttext library doesn't yet have a function that returns the words closest to a given word vector. Gensim to the rescue!

But before loading the word vectors through the gensim library, let's delete the previously loaded fasttext model from memory.

Loading the model.

Store the predicted vector.

Looking at the result.
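Gensim's `KeyedVectors.similar_by_vector` does this lookup. It appears in the comments along with one possible way to load the vectors. The NumPy function below is a swapped-in, brute-force equivalent that illustrates the idea by ranking vocabulary words by their cosine similarity to the predicted vector. The function name and its arguments are hypothetical.

```python
import numpy as np

# With gensim, one option is roughly:
#   from gensim.models.fasttext import load_facebook_vectors
#   wv = load_facebook_vectors("data/cc.en.300.bin")
#   wv.similar_by_vector(pred_vec, topn=10)


def nearest_words(vec, words, word_matrix, topn=10):
    """Return the `topn` (word, cosine similarity) pairs closest to `vec`."""
    m = word_matrix / np.linalg.norm(word_matrix, axis=1, keepdims=True)
    sims = m @ (vec / np.linalg.norm(vec))
    best = np.argsort(-sims)[:topn]  # highest similarity first
    return [(words[i], float(sims[i])) for i in best]
```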

The model seems obsessed with “monitor”, but it also predicted “computer” and other things related to a desktop. I don't know how a treadmill is related to a desktop, though. :D

So far, we have only searched data whose labels or images existed during training. Let's now try images and classes the model has never seen before.

These are the classes that exist in the test set.

Add the test set's predicted word vectors and create the index.

Let’s search for “bear”.

The model did great, given that it never saw a bear!

Let’s search for something creative.

I never expected the model to know what “nemo” means, but it surprised me: it found that an octopus is the closest thing in the test set to the term “nemo”.

I tried the same steps with resnet34 and resnet50. Even though the loss gets remarkably lower as the models get bigger, I don't see the point in training them, because they all seem to perform similarly most of the time. Sometimes the models predicted completely different labels from each other: for example, resnet50 mostly predicted “keyboards”, while, as we saw, resnet18 mostly predicted “monitor”.

References

  1. DeViSE: A Deep Visual-Semantic Embedding Model. Andrea Frome, Greg S. Corrado, Jon Shlens, Samy Bengio, Jeff Dean, Marc’Aurelio Ranzato, and Tomas Mikolov. NIPS, 2013.
  2. Lecture 11 of fast.ai, by Jeremy Howard and Rachel Thomas.
  3. Caltech-256 Object Category Dataset. Gregory Griffin, Alex Holub, and Pietro Perona. California Institute of Technology, 2007.
  4. Fasttext.
  5. nmslib.
  6. Jovian.ml, for hosting the notebook.

