Kaggle Competition: Self-Supervised Learning for Cassava Leaf Disease Classification

Hello everyone, in this thread I will detail my approach to the Cassava Leaf Disease Classification competition on Kaggle.

About The Competition

This is an image classification task in which the objective is to assign each image a class label. The goal is to help farmers quickly identify whether a leaf is affected by one of the diseases represented in the dataset or is healthy.

About the Dataset

The dataset is provided both as TFRecords and as images. There are a total of 21,367 training images across 5 classes. The images were collected with mobile cameras, so it would be interesting to see how this affects the image statistics.
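A minimal sketch for comparing those statistics: the hypothetical helper `channel_stats` below computes the per-channel mean and standard deviation over a collection of images, assuming they are already loaded as HxWx3 uint8 NumPy arrays (for example with PIL). The resulting values could then be compared with the ImageNet normalization constants that pre-trained models expect.

```python
import numpy as np


def channel_stats(images):
    """Per-channel mean and std over HxWx3 uint8 images (hypothetical helper).

    Images may differ in size; every pixel in the collection is weighted equally.
    """
    total = np.zeros(3, dtype=np.float64)
    total_sq = np.zeros(3, dtype=np.float64)
    count = 0
    for img in images:
        pixels = img.reshape(-1, 3).astype(np.float64) / 255.0  # scale to [0, 1]
        total += pixels.sum(axis=0)
        total_sq += (pixels ** 2).sum(axis=0)
        count += pixels.shape[0]
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; clip tiny negative values caused by rounding.
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean, std
```

Because only running sums are kept, the images can be streamed from disk one at a time instead of being held in memory together.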

The classes are as follows:

0: Cassava Bacterial Blight (CBB)
1: Cassava Brown Streak Disease (CBSD)
2: Cassava Green Mottle (CGM)
3: Cassava Mosaic Disease (CMD)
4: Healthy

Action Plan

I intend to use Remo to visualize the data, create train/validation splits and perform Exploratory Data Analysis, and to outline how it helps with each of these steps.

The initial plan for model exploration (subject to change as I iterate):

  1. Creating a supervised transfer learning baseline
  2. Creating a self-supervised transfer learning baseline
  3. Self-supervised learning on the training set, followed by fine-tuning.

I will update this thread as I progress, adding insights and work as I aim for a good position in the competition.


| Model | Type of Training | Training Accuracy | Validation Accuracy |
| --- | --- | --- | --- |
| ResNet-18 | Supervised | 85% | 80% |
| ResNet-34 | Supervised | 80% | 78% |
| ResNet-50 | Supervised | 90% | 86.8% |
| Vision Transformer | Supervised | 89% | 82% |
| BYOL (ResNet-50) | Self-Supervised | 90% | 81% |
| SwAV (ResNet-50) | Self-Supervised | 71% | 71% |

@Harsha Love this!

What do you think about setting up a repo to share the code throughout the competition?


That’s a great idea @andrea, I will make a repo for it soon!

Approach 1: Supervised Transfer Learning Baseline

In this approach I will be fine-tuning the following architectures, which have been pre-trained on the ImageNet dataset or, for the Vision Transformer, the JFT dataset. This process is called transfer learning: the pre-trained weights are either frozen, which means they are not modified at training time, or fine-tuned, and the last layer is replaced to match the number of classes and the objective of the downstream task.

In this case the task is image classification and the number of classes is 5, so the last layer is replaced with an MLP with output size 5.

The models being explored are as follows:

Figure 1: ResNet18 Architecture | Source

Figure 2: Vision Transformer | Source

Approach 2: Self-Supervised Transfer Learning Baseline

In supervised learning, a model learns a fixed mapping from data X to labels Y using labelled examples.
In an image classification task, X corresponds to an image and Y is the class label of that image.

Supervised learning paradigms require large amounts of annotated data, which is difficult and expensive to procure. While large datasets such as ImageNet and MS COCO are available, extending them with more labelled data becomes increasingly expensive, with diminishing returns.

Self-supervised learning is a family of learning algorithms that derive the training task (a "pretext task") from intrinsic properties of the data rather than from human labels. Examples include predicting a missing patch, predicting an image's rotation and predicting colour. Solving such tasks forces the algorithm to learn correlations that reflect the properties of the data distribution.
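To make the idea of a pretext task concrete, here is a minimal sketch of rotation prediction (the RotNet-style task), which is not one of the methods used later in this thread. The helper name `rotation_pretext_batch` is hypothetical; it builds inputs and labels on which any classifier with four outputs could be trained with an ordinary cross-entropy loss, without human labels.

```python
import torch


def rotation_pretext_batch(images):
    """Build a rotation-prediction batch from unlabelled images (hypothetical helper).

    images: (N, C, H, W) tensor with H == W. Each image is rotated by 0, 90, 180
    and 270 degrees, and the rotation index (0-3) becomes its label.
    """
    rotated, labels = [], []
    for k in range(4):
        # torch.rot90 rotates by k * 90 degrees in the plane of the (H, W) dims.
        rotated.append(torch.rot90(images, k, dims=(2, 3)))
        labels.append(torch.full((images.shape[0],), k, dtype=torch.long))
    return torch.cat(rotated), torch.cat(labels)


x, y = rotation_pretext_batch(torch.randn(8, 3, 32, 32))
print(x.shape, y.shape)  # torch.Size([32, 3, 32, 32]) torch.Size([32])
```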

There are several approaches to this, including SimCLR, MoCo and SwAV. SimCLR and MoCo are trained with contrastive objectives, while SwAV contrasts cluster assignments; all of them learn from different augmentations applied to the same images.
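As an illustration of a contrastive objective, below is a minimal sketch of the NT-Xent loss from SimCLR. It assumes `z1` and `z2` are projection-head outputs for two augmented views of the same batch; the function name and the temperature of 0.5 are assumptions. MoCo and SwAV use different objectives (a momentum encoder with a queue of negatives, and swapped cluster-assignment prediction, respectively).

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1, z2, temperature=0.5):
    """SimCLR-style NT-Xent loss for two batches of projections.

    z1[i] and z2[i] come from two augmentations of image i (the positive pair);
    all other 2N - 2 embeddings in the combined batch act as negatives.
    """
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2]), dim=1)  # (2N, D), unit length
    sim = z @ z.t() / temperature                # scaled cosine similarities
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))  # never match a sample with itself
    # Row i's positive is row i + N, and row i + N's positive is row i.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```

Each row is treated as a (2N - 1)-way classification problem whose correct answer is the other view of the same image, which is why larger batches (more negatives) tend to matter for this family of methods.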

The representations learned by self-supervised approaches can then be used for downstream tasks, similar to supervised pre-training: the trained network serves as a backbone, and its output is passed to a small fixed-size MLP that predicts the class.

In this approach, I will take self-supervised models pre-trained on large datasets such as ImageNet and ImageNet-21K, use them as a backbone, and feed their output into an MLP that predicts the classes of the input images; the model is then fine-tuned. PyTorch Lightning provides an excellent framework for extending such architectures.

The advantage of this approach is that it does not have the overhead of acquiring large amounts of annotated data, and it gives greater control over the training process through new augmentations, learning objectives, etc.

Figure 3: SwAV architecture | Source

Figure 4: SimCLR architecture | Source

Approach 3: Self-Supervised learning on the Cassava Training Set

This approach is similar to the previous one, except that the self-supervised method itself is trained from scratch on the Cassava dataset provided in the competition and is then used to classify the downstream classes.


Approach 1: Supervised Transfer Learning Baseline

In this run I worked with the ResNet-18 and ResNet-34 architectures. Both models were pre-trained on ImageNet and then fine-tuned on the Cassava dataset.

A stratified split by class is used to create train and validation data. This is then exported from Remo to a CSV for training.
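The split itself was made in Remo. As a plain-Python alternative (not the Remo workflow), a stratified split can be sketched with scikit-learn's `train_test_split`. The helper name `stratified_split`, the column names `image_id` and `label`, the 20% validation fraction and the output file names are all assumptions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def stratified_split(df, label_col="label", valid_frac=0.2, seed=42):
    """Split a labels table so each class keeps the same proportion in both parts."""
    train_df, valid_df = train_test_split(
        df, test_size=valid_frac, stratify=df[label_col], random_state=seed
    )
    return train_df.reset_index(drop=True), valid_df.reset_index(drop=True)


# Example usage (hypothetical file names):
# df = pd.read_csv("train.csv")  # assumed columns: image_id, label
# train_df, valid_df = stratified_split(df)
# train_df.to_csv("train_split.csv", index=False)
# valid_df.to_csv("valid_split.csv", index=False)
```

Stratifying matters here because the classes are not equally represented; a purely random split could leave the smaller classes under-represented in validation.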


The ResNet-18 model was written using PyTorch Lightning. The pre-trained model from torchvision was used as a backbone, and the last layer was replaced for fine-tuning. Lightning allows for flexible and maintainable code and easy access to methods like “ddp_sharded” (sharded data-parallel training), which allowed for larger batch sizes and faster training.

The training accuracy achieved after 108 epochs was ~85%.
The validation accuracy achieved after 108 epochs was ~80%.

It used the AdamW optimizer with a CosineAnnealingLR scheduler. The maximum learning rate was set to 0.001.
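A minimal sketch of this optimizer and scheduler setup in plain PyTorch. Stepping the scheduler once per epoch with `T_max` equal to the number of epochs is an assumption, since the post does not state how the schedule was configured; `make_optimizer` is a hypothetical helper and the `nn.Linear` model is a stand-in.

```python
import torch
import torch.nn as nn


def make_optimizer(model, max_lr=1e-3, epochs=108):
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
    # Cosine decay from max_lr down to 0 over `epochs` scheduler steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    return optimizer, scheduler


model = nn.Linear(512, 5)
opt, sched = make_optimizer(model)
lrs = []
for epoch in range(108):
    lrs.append(sched.get_last_lr()[0])
    # ... one epoch of training, calling opt.step() after each batch ...
    opt.step()
    sched.step()  # once per epoch
```

For the ResNet-34 run, the same helper would be called with `max_lr=0.01` and `epochs=50`.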


Figure 5 : ResNet-18 Training Accuracy Plot

Figure 6 : ResNet-18 Validation Accuracy Plot


The ResNet-34 model was similarly implemented using PyTorch Lightning and a pre-trained model from torchvision.

The ResNet-34 model was trained for 50 epochs, without the ‘ddp_sharded’ method.

It used the AdamW optimizer with a CosineAnnealingLR scheduler. The maximum learning rate was set to 0.01.

Figure 7 : ResNet-34 Training Accuracy Plot

Figure 8 : ResNet-34 Validation Accuracy Plot

The training accuracy achieved after 50 epochs was ~80%.
The validation accuracy achieved after 50 epochs was ~78%.


The ResNet-50 model was implemented using PyTorch Lightning and a pre-trained model from torchvision.

The ResNet-50 model was trained for 30 epochs, using the ‘ddp_sharded’ method.

It used the AdamW optimizer with a CosineAnnealingLR scheduler. The maximum learning rate was set to 0.001.

Figure 9 : ResNet-50 Training Accuracy Plot

Figure 10 : ResNet-50 Validation Accuracy Plot

The training accuracy achieved after 30 epochs was ~90%.
The validation accuracy achieved after 30 epochs was ~86.8%.

Improvement points and takeaways:

  1. Use newer forms of augmentation enabled by libraries such as Albumentations and Kornia. These enable strong data augmentation techniques that help the model generalize better.

  2. Iterate on optimizers. Currently AdamW, an improved variant of Adam with decoupled weight decay, is used with a specific learning rate and learning rate scheduler. Experimenting with other optimizers such as AdaDelta, other learning rates, and schedulers such as ReduceLROnPlateau would be another direction.

  3. Clean the dataset, for example by removing noisy labels identified through Exploratory Data Analysis, and try techniques that cope with label noise and improve robustness, such as label smoothing and test-time augmentation (TTA).
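Label smoothing and test-time augmentation can both be sketched in a few lines of PyTorch. This is a minimal sketch under assumptions: the smoothing factor of 0.1 and the use of flips as TTA views are illustrative choices, and `predict_with_tta` is a hypothetical helper. `label_smoothing` has been an argument of `nn.CrossEntropyLoss` since PyTorch 1.10.

```python
import torch
import torch.nn as nn

# Label smoothing: the target becomes 0.9 * one-hot + 0.1 spread uniformly over
# all 5 classes, which softens the penalty for confidently predicting a noisy label.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)


@torch.no_grad()
def predict_with_tta(model, images):
    """Average softmax probabilities over the original and flipped images (N, C, H, W)."""
    model.eval()
    views = [
        images,
        torch.flip(images, dims=[3]),  # horizontal flip
        torch.flip(images, dims=[2]),  # vertical flip
    ]
    probs = torch.stack([model(v).softmax(dim=1) for v in views])
    return probs.mean(dim=0)
```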


Vision Transformer

Using the model ported to PyTorch by Ross Wightman in pytorch-image-models and the structure provided by PyTorch Lightning, the next architecture selected for fine-tuning was the Vision Transformer model.

It was trained on a single Tesla T4 GPU with the following configuration:

Batch Size: 50
Optimizer: Adam
Learning Rate: 2e-05
Epochs: 5

This model outperformed the ResNet-18 and ResNet-34 runs in validation accuracy, as shown by the following plots:

Figure 11 : Vision Transformer Training Accuracy Plot

Figure 12 : Vision Transformer Validation Accuracy Plot

This suggests a stronger generalization capability in pre-trained Vision Transformers, which may be due in part to the Multi-Head Self-Attention used in Transformers.

Because the disease areas are localized within the image, an attention-based model may capture them better than a purely convolutional one.
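To make the attention argument more concrete, here is a minimal single-head sketch of scaled dot-product self-attention over patch tokens. It only illustrates the mechanism (every patch can draw information from every other patch, however far apart); it does not show that this is why the model generalized better. The function and weight names are hypothetical; a ViT-B/16 at 224x224 resolution sees 196 patch tokens plus a class token, which the usage lines mimic with random data.

```python
import torch


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention.

    x: (num_tokens, d_model) patch embeddings. Returns the attended values and
    the (num_tokens, num_tokens) attention weights.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.t() / k.shape[-1] ** 0.5  # scale by sqrt(d_k)
    attn = scores.softmax(dim=-1)            # each row is a distribution over tokens
    return attn @ v, attn


# Multi-head attention runs several such heads with separate weights and
# concatenates their outputs (torch.nn.MultiheadAttention implements this).
tokens = torch.randn(197, 64)                # 196 patches + 1 class token
w_q, w_k, w_v = (torch.randn(64, 16) / 8 for _ in range(3))
out, attn = self_attention(tokens, w_q, w_k, w_v)
```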

Improvement points and takeaways:

The points remain the same as those mentioned for the ResNet runs, but the main path of exploration would be introducing newer forms of augmentation while trying to improve the validation score of this model. There is also a good chance that this is the limit of standard transfer learning without changing properties of the dataset (i.e., removing mislabelled or noisy images).


Self-Supervised learning on the Cassava Training Set

Large-scale annotated datasets are few and far between: they are difficult to acquire and expensive to create. Adding more examples can imbalance the data distribution, and keeping it balanced becomes even more challenging.

One of the key approaches to overcoming the need for labelled data is the use of techniques which use self-supervision to learn directly from the data.

For the Cassava dataset, I experimented with two techniques that use self-supervised pretext tasks (comparing augmented views in BYOL, and contrasting cluster assignments in SwAV) to pre-train a network, which is then fine-tuned with a linear classifier.

BYOL: Bootstrap Your Own Latent

BYOL, a method by researchers from DeepMind and Imperial College London, was trained on the Cassava dataset using a variety of augmentations.
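To show what BYOL optimizes, here is a minimal sketch of its two core pieces: the normalized-MSE loss between the online network's prediction and the target network's projection, and the exponential moving average (EMA) update of the target weights. Tiny `nn.Linear` layers and random tensors stand in for the ResNet-50 encoder, projector, predictor and augmented views; the helper names are hypothetical, and `tau=0.996` is the base momentum reported in the BYOL paper.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def byol_loss(p, z):
    """Normalized MSE between online prediction p and target projection z.

    Equals 2 - 2 * cosine_similarity; no gradient flows into the target.
    """
    p = F.normalize(p, dim=-1)
    z = F.normalize(z.detach(), dim=-1)
    return (2 - 2 * (p * z).sum(dim=-1)).mean()


@torch.no_grad()
def ema_update(target, online, tau=0.996):
    """Move each target weight towards the matching online weight."""
    for t, o in zip(target.parameters(), online.parameters()):
        t.mul_(tau).add_(o, alpha=1 - tau)


online = nn.Linear(8, 4)        # stand-in for encoder + projector
predictor = nn.Linear(4, 4)     # predictor exists only on the online side
target = copy.deepcopy(online)  # target network, updated only by EMA
for param in target.parameters():
    param.requires_grad = False

view1, view2 = torch.randn(16, 8), torch.randn(16, 8)  # stand-ins for two augmented views
loss = (byol_loss(predictor(online(view1)), target(view2))
        + byol_loss(predictor(online(view2)), target(view1)))  # symmetrized
loss.backward()
# ... an optimizer step on online + predictor parameters would go here ...
ema_update(target, online)
```

Unlike SimCLR, no negatives are used; the predictor and the slowly moving target are what keep the representations from collapsing to a constant.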

These models are typically trained for ~1000 epochs with a large batch size (~4096), but given resource and time constraints, the first run was trained for 81 epochs with a batch size of 256.

The loss curve for this run was:

Figure 13: Loss graph of training BYOL on Cassava Data

Training was done using the BYOL-Pytorch repo in conjunction with PyTorch Lightning, with ResNet-50 as the model.

I then saved the model, replaced its final layer, and trained it to perform linear classification on the Cassava dataset.

The resulting plots were as follows:

Figure 14: ResNet50-BYOL Training Accuracy

Figure 15: ResNet50-BYOL Validation Accuracy

The training accuracy achieved after 30 epochs was ~90%.
The validation accuracy achieved after 30 epochs was ~81%.

SwAV: Unsupervised Learning of Visual Features by Contrasting Cluster Assignments

I similarly experimented with SwAV, a technique released by Facebook Research that learns visual features from images by contrasting cluster assignments.
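A minimal sketch of SwAV's core idea: soft cluster assignments (codes) are computed with a few Sinkhorn-Knopp iterations so that prototypes are used roughly equally within a batch, and each view is trained to predict the other view's code. This simplified version omits multi-crop and the feature queue; the helper names are hypothetical, and `eps=0.05` and `temperature=0.1` are assumed values.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def sinkhorn(scores, eps=0.05, n_iters=3):
    """Turn (B, K) prototype scores into soft codes with balanced cluster usage."""
    q = torch.exp(scores / eps).t()  # (K, B)
    q /= q.sum()
    k, b = q.shape
    for _ in range(n_iters):
        q /= q.sum(dim=1, keepdim=True) * k  # each prototype gets total mass 1/K
        q /= q.sum(dim=0, keepdim=True) * b  # each sample gets total mass 1/B
    return (q * b).t()  # (B, K); each sample's code sums to 1


def swav_loss(z1, z2, prototypes, temperature=0.1):
    """Swapped prediction: view 1 predicts view 2's code, and vice versa."""
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    c = F.normalize(prototypes, dim=1)
    s1, s2 = z1 @ c.t(), z2 @ c.t()  # (B, K) prototype scores
    q1, q2 = sinkhorn(s1), sinkhorn(s2)  # targets, no gradient
    loss12 = -(q2 * F.log_softmax(s1 / temperature, dim=1)).sum(dim=1).mean()
    loss21 = -(q1 * F.log_softmax(s2 / temperature, dim=1)).sum(dim=1).mean()
    return (loss12 + loss21) / 2
```

The equal-usage constraint in `sinkhorn` is what prevents every image from being assigned to the same prototype.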

This model was trained and subsequently fine-tuned on the dataset.

The resulting plots were:

Figure 16: ResNet-50-SwAV Training Accuracy

Validation Accuracy (2)
Figure 17: ResNet-50-SwAV Validation Accuracy

The training accuracy achieved after 10 epochs was ~71%.
The validation accuracy achieved after 10 epochs was ~71%.

Improvement points and takeaways:

  1. Reading the research on self-supervised learning highlighted the challenges the field currently faces: the requirement for large batch sizes and significant training time shows where improvements are needed.

  2. Since self-supervised learning relies heavily on in-distribution features, models pre-trained on ImageNet may not have transferred well to the Cassava dataset because of the marked domain shift between the two.

  3. Training accuracy improved once BYOL was trained from scratch on the Cassava dataset and then fine-tuned, which suggests this is a more favourable approach than direct fine-tuning.

  4. One way to match the large batch sizes used in the papers is to enable gradient accumulation in PyTorch Lightning (the “accumulate_grad_batches” Trainer argument). Gradients are then summed over several batches before each optimizer update, so each update takes longer but the effective batch size is closer to the papers' settings.

  5. The results suggest that, with the limited amount of data available here, supervised learning currently outperforms self-supervised learning when domain labels are present; self-supervised learning remains a very exciting prospect for truly unlabelled datasets.

  6. The augmentations are extremely important and introducing a wider variety has improved the performance of the methods.

  7. Experiment with different optimizers and learning rate schedules.
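Gradient accumulation (point 4) can be sketched in plain PyTorch as follows; in PyTorch Lightning the same effect comes from the `accumulate_grad_batches` argument of `Trainer`. The helper `accumulated_step` is hypothetical and assumes equally sized micro-batches.

```python
import torch
import torch.nn as nn


def accumulated_step(model, optimizer, batches, loss_fn):
    """One optimizer update built from several micro-batches.

    Each micro-batch loss is divided by the number of micro-batches, so the summed
    gradients equal the gradient of the mean loss over one large batch
    (assuming equally sized micro-batches).
    """
    optimizer.zero_grad()
    n = len(batches)
    for x, y in batches:
        (loss_fn(model(x), y) / n).backward()  # gradients accumulate in .grad
    optimizer.step()
```

Accumulation reproduces the large-batch gradient only for losses that decompose over samples. BatchNorm statistics and, for contrastive methods such as SimCLR, the set of in-batch negatives still come from each micro-batch, so it approximates rather than reproduces the papers' large-batch setting.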