Transfer Learning with TensorFlow using the CIFAR-10 dataset and Google Compute Engine

Part 1: Running a Jupyter notebook on Google Compute Engine

The first step is to get a Jupyter notebook up and running. Of course, it isn’t strictly necessary to use Jupyter to run the code, but I have a personal preference for the workflow it offers. There were two ways I was able to achieve this. The first was to create a new instance and install all the libraries I needed myself.

This is a time-consuming process, as there are sometimes compatibility issues between different versions of the libraries. I ran into this when trying to use TensorFlow after an Anaconda installation; the workaround was to create a separate Anaconda environment and install TensorFlow there.

To run the Jupyter notebook, I reserved a static IP address, changed the Jupyter configuration files and ran the notebook on the server. Note, however, that this makes the notebook publicly accessible. While Jupyter does provide a layer of security by requiring token authentication, exposing the notebook this way is not good practice.

The second way I tried was to create a VM instance from a machine image, which was a much faster way to get up and running. It involved first installing and authenticating the Google Cloud CLI on my local machine, and then creating an instance from the command line. There’s a really handy tutorial here that helped me get going pretty quickly. I used a ‘tf-latest-cpu’ image on an ‘n1-highmem-8’, an 8-core machine, with 200 GB of space. I’m also using a preemptible instance, which cuts costs, but the downside is that the machine only runs for 24 hours and can be interrupted at any time. However, Google does state that ‘preemption rates for smaller machine types with less than 32 cores are also historically lower than for larger machine types’, and I didn’t intend to run the machine for longer than 24 hours anyway.

Once the instance has been created, the next step is to connect to it via SSH. This can be done either through the web interface or through the command line. Note that you’re only connected to the instance for as long as the SSH connection persists.

connecting through command line
connecting from the browser

Now all I have to do is navigate to http://localhost:8080/tree?

Part 2: What is transfer learning?

Despite how accessible the computational power to train neural networks has become today, a significant caveat remains: obtaining enough training data for the model to learn from and produce satisfactory results.

Transfer learning is a process by which a pretrained model is reused for a similar classification task. For example, if we had a network already trained to classify types of buses, there’s a good chance we could reuse it to classify types of trucks. The process commonly involves freezing some of the lower layers, which means we preserve the previously learned parameters; as an added advantage, this reduces the time required to train the entire model.

Many popular model architectures, as well as their pretrained weights, are available in the tf.keras.applications module.
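As a rough illustration (the arguments here are my own choices, not necessarily the ones used later in this post), loading a pretrained backbone from that module looks something like this:

```python
import tensorflow as tf

# Load ResNet50 with weights pretrained on ImageNet.
# include_top=False drops the final classification layers so that we can
# attach our own head for a new task.
base_model = tf.keras.applications.ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=(224, 224, 3),
)

base_model.summary()  # inspect the layers of the pretrained network
```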

For the following task, I’ve experimented with two model architectures: VGG16 and ResNet50.

Part 3: CIFAR-10 with ResNet50

The CIFAR-10 dataset, available here, is a dataset of 32 x 32 colour images belonging to 10 classes. It can be downloaded directly via TensorFlow and is conveniently subdivided into training and testing sets.

Additionally, I created a validation dataset from a part of the training dataset. It’s important to have a validation set since it helps us determine whether the model has really learnt relevant features or simply memorized the training data. I also encoded the labels into a one-hot format to use with softmax later.

splitting the data
converting labels to one-hot format
using the summary() method to view the layers in a resnet50 architecture
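A minimal sketch of what those preparation steps might look like (the size of the validation split is my own choice, since the post doesn’t specify one):

```python
import tensorflow as tf

# Load CIFAR-10; it comes already split into training and test sets.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Hold out part of the training set as a validation set
# (5,000 images is an assumed split, not one stated in the post).
val_size = 5000
x_val, y_val = x_train[:val_size], y_train[:val_size]
x_train, y_train = x_train[val_size:], y_train[val_size:]

# One-hot encode the labels for use with a softmax output layer.
num_classes = 10
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_val = tf.keras.utils.to_categorical(y_val, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
```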

Next, I instantiated a ResNet50 architecture with pretrained weights from the ImageNet challenge (here’s the documentation). Since the default input shape is 224 x 224, I decided to resize my images before feeding them to the model. Initially, I intended to use skimage.transform.resize; however, resizing 50,000 images to 224 x 224 ended up taking too much space. I found an easier approach, mentioned in this super useful answer, which was to create a lambda layer.

A Lambda layer performs custom operations on the data and can be added to the pipeline when creating a Sequential model. So I created a Lambda layer that resizes my images with tf.image.resize.

Since the training process was going to take quite a while, it’s also useful to save the model between epochs so as not to have to train from the beginning if the process is interrupted. Here’s a useful Medium article illustrating how to do that.
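In Keras this is typically done with a ModelCheckpoint callback; the file path below is just an illustrative choice:

```python
import tensorflow as tf

# Save the model weights at the end of every epoch so that training can be
# resumed if the (preemptible) instance is interrupted.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/resnet50_cifar10_epoch_{epoch:02d}.h5",
    save_weights_only=True,
    save_freq="epoch",
)
```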

On top of the Lambda layer, I added my base model (ResNet50), an average pooling layer, a dense layer with 10 units (since CIFAR-10 has 10 classes) and a softmax layer to predict the final output. In this example, I’m not going to train the base model at all.
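Put together, a minimal sketch of that stack might look like the following (I’m assuming GlobalAveragePooling2D for the averaging step):

```python
import tensorflow as tf

# Pretrained ResNet50 backbone without its ImageNet classification head.
base_model = tf.keras.applications.ResNet50(
    weights="imagenet", include_top=False, input_shape=(224, 224, 3)
)
base_model.trainable = False  # freeze the base model entirely

model = tf.keras.Sequential([
    # Lambda layer: upscale 32x32 CIFAR-10 images to the 224x224
    # input size that ResNet50 expects.
    tf.keras.layers.Lambda(
        lambda imgs: tf.image.resize(imgs, (224, 224)),
        input_shape=(32, 32, 3),
    ),
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),  # averaging layer
    tf.keras.layers.Dense(10),                 # one unit per CIFAR-10 class
    tf.keras.layers.Softmax(),                 # final class probabilities
])
```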

All that’s left is to compile the model and fit it to the training dataset. I’m going to train for 10 epochs with a batch size of 64.
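A sketch of the compile-and-fit step, assuming the model and data are set up as above (the optimizer is my own choice, since the post doesn’t name one):

```python
# Categorical cross-entropy because the labels are one-hot encoded.
model.compile(
    optimizer="adam",  # assumed optimizer; not specified in the post
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=10,
    batch_size=64,
    callbacks=[checkpoint_cb],  # save progress between epochs
)
```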

After the 10th epoch, I notice that while the training accuracy has increased, the validation accuracy has converged to around 65%. This is a sign of overfitting.

Part 4: CIFAR-10 with VGG16

For this example, I followed the same steps as before, but instead of increasing the size of my images, I changed the shape of the input layer of the network. As a base model, I’m using VGG16. However, this time I’m going to train the entire network and see what kind of results I get.

The VGG16 architecture
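A minimal sketch of that setup, assuming the same classification head as in the ResNet50 example:

```python
import tensorflow as tf

# VGG16 backbone with the input layer reshaped to accept 32x32 CIFAR-10
# images directly, instead of resizing them up to 224x224.
base_model = tf.keras.applications.VGG16(
    weights="imagenet", include_top=False, input_shape=(32, 32, 3)
)
base_model.trainable = True  # this time the whole network is trained

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Softmax(),
])
```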

I’m also using a technique here called learning rate annealing, which reduces the learning rate by a factor if certain conditions are met. Since a high learning rate helps approach a minimum faster, it’s often good to start with a higher learning rate and reduce it if the validation accuracy stops improving.

Here, I’m going to reduce the learning rate by a factor of 0.01 if the validation accuracy doesn’t improve for three epochs.
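In Keras this can be done with the ReduceLROnPlateau callback; a sketch matching the settings described above:

```python
import tensorflow as tf

# Multiply the learning rate by 0.01 when validation accuracy has not
# improved for three consecutive epochs.
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_accuracy",
    factor=0.01,
    patience=3,
)

# Passed to model.fit alongside any other callbacks, e.g.:
# model.fit(..., callbacks=[lr_callback, checkpoint_cb])
```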

However, as you can see, the model seems to be overfitting once more.

Further improvements:

Since the results of the above two examples weren’t very good, in the next post I’m going to examine the reasons behind this and how to overcome them. I’m also going to try to create a binary classifier with a dataset of images I collect myself and see what happens.
