What if you don’t have a dataset of 300M images to train your vision transformer on? Get some help from the good ol’ CNNs via distillation!
It seems that the title of the Transformer architecture paper has been resonating more and more in our minds lately. Is attention really all we need? For some years now, the NLP community has clearly seemed to think so, with transformers at the core of SotA architectures for language modeling, translation, and summarization.
A few months ago, a transformer-based architecture for image classification, proposed in this paper by Google, outperformed SotA CNN-based architectures. One drawback of this architecture is the (understandably) large number of parameters and the amount of resources required to train it.
In a recent paper, Facebook has shown that a large number of parameters and obnoxious training resources (the private JFT-300M dataset and the TPU infrastructure used to train ViT) are sufficient, but not necessary, for good performance.
In short, the paper aims to reduce the number of parameters and the training resources required, compensating with an additional distillation token and a better training strategy. DeiT-B has the same number of parameters as ViT-B, but the two other variants go toward fewer parameters (S for small, Ti for tiny), as opposed to the larger, BERT-inspired ViT variants (L for large, H for huge). Pre-training is done on 224×224×3 RGB images, and fine-tuning on 384×384×3 RGB images.
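To make the size difference concrete, here is a rough parameter count for the three DeiT variants. It is a sketch under stated assumptions: a standard ViT-style layout with 12 encoder layers, 16×16 patches on 224×224×3 inputs, an MLP ratio of 4, biases in every linear layer, LayerNorms with weight and bias, a 1000-class head, and embedding widths of 192, 384 and 768 for Ti, S and B. The distillation token and its extra head are left out, and the helper `vit_param_count` is hypothetical, not part of any library.

```python
def vit_param_count(dim, depth=12, patch=16, img=224, channels=3,
                    num_classes=1000, mlp_ratio=4):
    """Approximate parameter count of a plain ViT classifier (assumed layout,
    no distillation token or distillation head)."""
    hidden = mlp_ratio * dim
    num_patches = (img // patch) ** 2

    patch_embed = channels * patch * patch * dim + dim  # patch projection + bias
    cls_token = dim
    pos_embed = (num_patches + 1) * dim                  # one position per patch + class token

    per_layer = (
        2 * dim                      # LayerNorm before attention
        + dim * 3 * dim + 3 * dim    # fused query/key/value projection
        + dim * dim + dim            # attention output projection
        + 2 * dim                    # LayerNorm before the MLP
        + dim * hidden + hidden      # MLP expansion layer
        + hidden * dim + dim         # MLP contraction layer
    )
    final_norm = 2 * dim
    head = dim * num_classes + num_classes
    return patch_embed + cls_token + pos_embed + depth * per_layer + final_norm + head


# Embedding widths of the three variants (the number of heads does not change the count).
for name, dim in [("DeiT-Ti", 192), ("DeiT-S", 384), ("DeiT-B", 768)]:
    print(f"{name}: {vit_param_count(dim) / 1e6:.1f}M parameters")
```

Under these assumptions it prints about 5.7M, 22.1M and 86.6M, in line with the 5M, 22M and 86M reported for DeiT-Ti, DeiT-S and DeiT-B.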
The architecture of DeiT is identical to that of ViT: an image is first split into patches and projected into patch tokens, which are then run through a stack of transformer encoder layers (essentially alternating self-attention and MLP blocks, as depicted in both papers). The only difference is the newly introduced distillation token, which adds a second term to the global loss function used to update the weights.
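The tokenization step can be sketched in plain Python. This is an illustration under simplifying assumptions: the image is a nested H × W × C list, and the learned linear projection applied to each flattened patch is omitted; real implementations work on tensors, typically with a strided convolution. The function name `patchify` is hypothetical.

```python
def patchify(image, patch=16):
    """Split an H x W x C image (nested lists) into non-overlapping patch x patch
    patches, each flattened row-major into a vector of patch * patch * C values."""
    height, width = len(image), len(image[0])
    if height % patch or width % patch:
        raise ValueError("image size must be divisible by the patch size")
    patches = []
    for top in range(0, height, patch):          # scan patch rows top to bottom
        for left in range(0, width, patch):      # and patch columns left to right
            patches.append([
                value
                for row in image[top:top + patch]
                for pixel in row[left:left + patch]
                for value in pixel
            ])
    return patches


# A blank 224x224 RGB image.
image = [[[0.0, 0.0, 0.0] for _ in range(224)] for _ in range(224)]
patches = patchify(image)
num_tokens = len(patches) + 2  # plus the class token and the distillation token

print(len(patches), len(patches[0]))  # 196 768
print(num_tokens)                     # 198
```

Each 768-value patch would then be projected to the model's embedding width before the class and distillation tokens are added to the sequence.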
The distillation token is trained to bring the predictions of the ViT (the student network) closer to those of a SotA CNN (the teacher network). As the authors note, using a convolutional network as the teacher helps make up for the transformer's lack of inductive bias.
Compensating for this missing bias matters because transformers lack intuitive priors about the structure of the data. By contrast, convolutional and recurrent networks have a strong inductive bias: they are built around an “ordering rule” within the input tensor or sequence (local neighborhoods for convolutions, step-by-step order for recurrence).
The lack of inductive bias also means that a transformer has to see more data to learn these regularities on its own. Transferring part of the bias from a CNN teacher to a transformer student should therefore, in theory, reduce the required training time, especially for pre-training.
The global loss of DeiT is a weighted sum of the usual cross-entropy loss between the student's prediction and the ground-truth label, and a second loss computed from the distillation token's prediction.
The first variant of this loss is inspired by the original formulation from the knowledge-distillation paper and uses the Kullback-Leibler divergence to penalize the difference between the class distributions predicted by the student and the teacher.
The other variant, proposed by the authors, is a simplified version (hard distillation): the teacher's most likely class is treated like a ground-truth label, and a cross-entropy loss is used instead of the KL divergence. Moreover, both terms are given equal weight, encouraging the student to learn equally from both sources.
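Written out as in the DeiT paper, with Z_s and Z_t the logits of the student and the teacher, y the ground-truth label and y_t the teacher's hard decision, the soft and hard objectives read as follows. In DeiT, the term involving y is computed from the class token's output, while the term involving the teacher is computed from the distillation token's output.

```latex
\mathcal{L}_{\mathrm{global}} = (1-\lambda)\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s), y\big)
  + \lambda\,\tau^{2}\,\mathrm{KL}\big(\psi(Z_s/\tau), \psi(Z_t/\tau)\big)

\mathcal{L}_{\mathrm{global}}^{\mathrm{hardDistill}} = \tfrac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s), y\big)
  + \tfrac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s), y_t\big),
  \qquad y_t = \arg\max_c Z_t(c)
```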
- λ — Balancing coefficient. Controls the relative weight of the two loss terms in the global loss.
- τ — Distillation “temperature”. Controls how soft the teacher's labels are: a higher value yields a softer (flatter) probability distribution, and a value of 1 leaves the softmax unchanged. Hard distillation does not use τ at all; it keeps only the teacher's top class.
- L_CE and KL — The cross-entropy loss and the Kullback-Leibler divergence.
- ψ — The softmax function.
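Both objectives described above can be sketched for a single example in plain Python. This is a minimal, framework-free illustration, not the authors' code. It assumes the student outputs two sets of logits (one from the class token, one from the distillation token), computes KL as KL(teacher ‖ student), the direction commonly used in knowledge-distillation code, and uses λ = 0.1 and τ = 3.0 purely as illustrative defaults.

```python
import math


def softmax(logits, tau=1.0):
    """psi(z / tau): temperature-scaled softmax, shifted by the max for stability."""
    scaled = [z / tau for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    """L_CE for one example: negative log-probability of the integer target class."""
    return -math.log(probs[label])


def kl_divergence(p, q):
    """KL(p || q) between two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def soft_distillation_loss(cls_logits, dist_logits, teacher_logits, label, lam=0.1, tau=3.0):
    """(1 - lambda) * CE(class token, y) + lambda * tau^2 * KL at temperature tau."""
    ce = cross_entropy(softmax(cls_logits), label)
    # Assumed direction: KL(teacher || student), both softened by tau.
    kl = kl_divergence(softmax(teacher_logits, tau), softmax(dist_logits, tau))
    return (1 - lam) * ce + lam * tau ** 2 * kl


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, label):
    """1/2 * CE(class token, y) + 1/2 * CE(distillation token, y_t)."""
    teacher_label = max(range(len(teacher_logits)), key=lambda c: teacher_logits[c])
    return (0.5 * cross_entropy(softmax(cls_logits), label)
            + 0.5 * cross_entropy(softmax(dist_logits), teacher_label))


student_cls = [2.0, 0.5, -1.0]
student_dist = [1.5, 1.0, -0.5]
teacher = [0.2, 2.5, -1.0]  # the teacher disagrees with the ground-truth label 0
print(soft_distillation_loss(student_cls, student_dist, teacher, label=0))
print(hard_distillation_loss(student_cls, student_dist, teacher, label=0))
print(softmax(teacher, tau=1.0))  # sharper distribution
print(softmax(teacher, tau=3.0))  # softer distribution
```

Setting λ = 0 reduces the soft loss to plain cross-entropy, and when the teacher agrees with the label the hard loss simply pushes both heads toward the same class.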
The authors note that the learned class and distillation tokens converge towards different vectors (an average cosine similarity of only 0.06), but their embeddings become more and more similar in deeper transformer layers. Importantly, even at the last layer their similarity reaches only 0.93, below 1, which indicates that each token carries some information the other does not.
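The similarity measure mentioned here is the cosine similarity between the two token embeddings. A minimal stdlib sketch follows; the two short vectors are hypothetical stand-ins for the real embeddings, which have 192 to 768 dimensions depending on the variant.

```python
import math


def cosine_similarity(u, v):
    """cos(u, v) = <u, v> / (|u| |v|): 1 for parallel vectors, 0 for orthogonal ones."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Hypothetical toy embeddings standing in for the class and distillation tokens.
cls_embedding = [0.9, 0.1, 0.3]
dist_embedding = [0.8, 0.2, 0.5]
print(round(cosine_similarity(cls_embedding, dist_embedding), 3))
```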
Adding the distillation token and using the hard-distillation loss, the authors reported an ImageNet top-1 accuracy of 83.4% for pretraining and 84.2% for fine-tuning.
Although ViT is being glorified for outperforming convolutional networks, this only happens when it is pre-trained on very large datasets, like the publicly unavailable JFT-300M.
When trained on smaller datasets, ViT performs worse than the top CNNs. DeiT, on the other hand, closely approaches EfficientNet's accuracy/throughput curve on ImageNet-1k. Adding distillation (red line) clearly improves accuracy while having a negligible impact on throughput, compared with the model without distillation (maroon line).
To compensate for the smaller training dataset, the authors rely on data augmentation. They also tried various optimizers and regularization techniques to find the best set of hyper-parameters, to which transformers are usually highly sensitive. The final configuration was selected through an ablation study, which can be found in the DeiT paper.
In conclusion, DeiT sidesteps ViT’s “shortcut” (yes, pun intended) of pre-training on JFT-300M and offers a more realistic, accessible route to SotA performance on ImageNet with almost the same architecture.
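One part of that recipe is the learning-rate schedule. The sketch below assumes the values we recall from the DeiT paper's training details: AdamW with weight decay 0.05, 300 epochs, a base learning rate of 5e-4 scaled linearly with the batch size (lr = 5e-4 × batch/512), 5 warm-up epochs and cosine decay. The floor `min_lr` and the helper names are hypothetical, and the exact warm-up shape in the official code may differ.

```python
import math

# Values as recalled from the DeiT training recipe; treat them as assumptions.
BASE_LR = 5e-4      # learning rate for a batch size of 512
EPOCHS = 300
WARMUP_EPOCHS = 5


def scaled_lr(batch_size, base_lr=BASE_LR):
    """Linear scaling rule: lr = base_lr * batch_size / 512."""
    return base_lr * batch_size / 512


def lr_at(epoch, peak_lr, epochs=EPOCHS, warmup=WARMUP_EPOCHS, min_lr=1e-5):
    """Linear warm-up to peak_lr, then cosine decay down to min_lr (hypothetical floor)."""
    if epoch < warmup:
        return peak_lr * (epoch + 1) / warmup
    progress = (epoch - warmup) / max(1, epochs - warmup)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


peak = scaled_lr(1024)
for epoch in (0, 4, 5, 150, 299):
    print(epoch, f"{lr_at(epoch, peak):.2e}")
```

Warm-up keeps the first updates small, which matters for transformers trained with large batches; the cosine tail then anneals the learning rate smoothly instead of in steps.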
H. Touvron et al., Training data-efficient image transformers & distillation through attention (2020)
A. Dosovitskiy et al., An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale (2020), ICLR 2021 (under review)
G. Hinton et al., Distilling the Knowledge in a Neural Network (2014), NIPS 2014 Deep Learning Workshop
S. Abnar et al., Transferring inductive biases through knowledge distillation (2020)
A. Vaswani et al., Attention Is All You Need (2017), Proceedings of NIPS 2017
J. Devlin et al., BERT: Pre-training of deep bidirectional transformers for language understanding (2018)