Unified Distillation: Supervising the Student to Make Fewer Mistakes

Knowledge distillation has become a popular technique in modern deep learning, transferring knowledge from a cumbersome neural network, commonly called the "teacher model", to a much smaller network called the "student model". In the traditional knowledge distillation process, the student model is trained with two objectives, namely the hard target and the soft target. However, it is often hard to find a trade-off between them. We unify the two objectives into one, making knowledge distillation easier to perform, and propose a novel distilling method called "Unified Distillation" that supervises the student to make fewer mistakes. The model can correct wrong predictions according to the hard target while maintaining the advantages of knowledge distillation. Although our method can be used in almost all fields suitable for knowledge distillation, we choose neural machine translation as the object of study for its complexity. We conducted experiments on three neural machine translation tasks, using a fine-tuned BERT language model as the teacher and a Transformer base model as the student. The experimental results indicate that our method outperforms the traditional knowledge distillation method.


Introduction
Nowadays, deep learning has become ubiquitous in research fields such as computer vision (CV) and natural language processing (NLP). However, researchers' pursuit of high performance makes deep neural networks ever more complex, so these networks need more storage and computing resources. As a result, the extremely large model sizes prevent these neural networks from being deployed and operated on devices with limited computation power such as mobile phones and other handheld devices. Knowledge distillation (KD) came to the rescue of this problem and became one of the most widely used model compression techniques. The main idea of KD is transfer learning: training a simplified model, commonly referred to as the student model, to mimic the behavior of a complex pretrained model, commonly called the teacher model.
Cristian Bucilua called the dataset for training the student model a "transfer set" instead of a general training set [1]. Although transfer sets are not required to have ground-truth labels, researchers generally have transfer sets with ground-truth labels available, and these labels offer much help for model performance improvement [2][3][4][5][6][7][8]. In this situation, there are two objectives that the student model needs to aim at: one is L_soft, which measures the difference between the output probability distribution of the teacher, commonly called the "soft target", and that of the student; the other is L_hard, which measures the difference between the one-hot encoding of the ground-truth label, called the "hard target", and the output probability distribution of the student. The two objectives are combined as Equation (1).
L = α · L_soft + (1 − α) · L_hard    (1)

The ratio of the two objectives is tuned with a hyper-parameter α. Generally, α is set to 0.5 by default, but the best practice is to set α to a larger value to emphasize L_soft [3].
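As an illustration, the general KD objective of Equation (1) can be sketched in NumPy as follows. This is a minimal sketch under our own assumptions: the function names are illustrative, L_soft is taken as the KL divergence from teacher to student, L_hard as cross-entropy against the ground-truth class, and the temperature scaling of [3] is included for completeness.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over the last axis."""
    z = np.asarray(logits, dtype=float) / temperature
    z -= z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kd_loss(student_logits, teacher_logits, true_class, alpha=0.5, temperature=1.0):
    """General KD objective of Equation (1): L = alpha * L_soft + (1 - alpha) * L_hard."""
    p_s = softmax(student_logits, temperature)
    p_t = softmax(teacher_logits, temperature)
    # L_soft: KL divergence between the teacher's and the student's distributions
    l_soft = float(np.sum(p_t * (np.log(p_t) - np.log(p_s))))
    # L_hard: cross-entropy against the one-hot ground-truth label
    l_hard = float(-np.log(p_s[true_class]))
    return alpha * l_soft + (1.0 - alpha) * l_hard
```

With alpha = 1.0 the student only mimics the teacher; with alpha = 0.0 it only fits the hard labels, which is exactly the trade-off the next paragraphs discuss.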
However, it appears inelegant and casual to simply add L_soft and L_hard in a fixed proportion. If the teacher's prediction is consistent with the ground-truth label, it seems unnecessary and redundant to use L_hard. Otherwise, the student may get confused about which of the two objectives to aim at. As an old proverb goes, "A person who chases two rabbits catches neither." We propose a novel distilling method called "Unified Distillation" in order to disambiguate the two training objectives. Our method is shown in Figure 1. We first add the hard and soft target distributions together as one combined target. Notice that the combined target is not a probability distribution before normalization, because the sum of all probability values does not equal 1. The maximum probability value of the combined target can be very large or not, depending on whether the teacher's prediction is consistent with the ground-truth label, so we rectify it by rescaling the maximum probability value to a fixed value given by a hyper-parameter P_max. Rectifying the combined target serves two purposes: (1) it converts the combined target into a normalized probability distribution; and (2) it makes the target distribution smoother. Thus, not only can we maintain the supervision of the ground-truth label, but we can also highlight the information that the small probability values contain. Negative labels with small probability values carry plenty of information, which can be viewed as knowledge of the teacher: the probability values of some negative labels may be greater than those of other negative labels. For example, in an image classification task, the similarity between a truck and a garbage truck is much greater than that between a truck and a carrot [3].
In general knowledge distillation, we optimize the student's output distribution to minimize L in Equation (1). When α is set to 0.2, the influence of the hard target is emphasized.
Although the mistakes made by the teacher can be corrected, when the teacher's prediction is consistent with the hard target, the valuable small probability values are diminished, and it is hard for the student to obtain knowledge this way. When α is set to 0.5, the influences of the soft and hard targets are equal. The mistakes made by the teacher can be corrected, but the probability value of the mistaken prediction label falls only slightly short of that of the correct label, and the student may get confused about which label to follow. If the teacher's prediction is consistent with the hard target, the valuable small probability values are also diminished, though not as severely as when α is 0.2, which may still hamper the student from learning knowledge from the teacher. When α is set to 0.8, the soft target is highlighted, but the teacher's mistakes cannot be corrected. In knowledge distillation, the teacher is much more complex and powerful than the student, so the proportion of correct predictions is always much higher than that of incorrect ones. Consequently, the best practice is to set α to a higher value to favor the majority of correct predictions. Obviously, one fixed value of α cannot resolve the contradiction between the two targets across correct and incorrect predictions. Our Unified Distillation fabricates a virtual "teacher" that always makes correct predictions by combining the soft and hard targets into a unified target. The unified target is always correct and does not diminish the valuable small probability values, regardless of whether the teacher's prediction is correct.
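A toy example illustrates why the combined target always points at the correct class even when the teacher does not; the numbers are purely illustrative, not from our experiments.

```python
import numpy as np

# Toy 4-class case where the teacher is wrong: most mass is on class 1,
# but the ground-truth class is 0. All numbers are illustrative.
p_teacher = np.array([0.30, 0.55, 0.10, 0.05])
one_hot = np.array([1.0, 0.0, 0.0, 0.0])

combined = p_teacher + one_hot  # unnormalized combined target

# Adding the one-hot label always restores the correct argmax: the correct
# class gets p_c + 1 > 1, while every other class keeps p_i < 1.
assert combined.argmax() == 0

# The relative order of the negative labels (the teacher's "dark knowledge")
# is untouched: class 1 still outranks classes 2 and 3.
assert combined[1] > combined[2] > combined[3]
```

Because each teacher probability is below 1 and the correct class gains exactly +1, the correct class is guaranteed to dominate the combined target, while the ranking among negative labels is preserved.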

Figure 2. Comparison between our method (middle) and general KD method (right) in both cases.
Although our "Unified Distillation" is a general method that can be used in almost all fields suitable for knowledge distillation, we choose neural machine translation (NMT) as the object of study for its complexity. We believe that if it works on NMT tasks, it will work well in other fields.
We used a fine-tuned BERT language model as the teacher, which generates sequences of word probability values as soft targets for the training samples in the KD process, and a Transformer base model as the student, which learns from the teacher's outputs. Experimental results indicate that our method outperforms the general KD method, the Transformer baselines, and a previously proposed method that dynamically adjusts the distillation temperature [9]. Bucilua et al. first attempted to use the outputs of one network to train another neural network [1]. Ba and Caruana further used the logits of a cumbersome neural network as supervision to train a much smaller network so that the model size could be reduced [10]. Hinton et al. systematically summarized a widely used KD method that divides the logits by an extra hyper-parameter called "temperature" so that the output distributions can be smoothed [3]. Since then, KD has become one of the important methods of model compression.

Related works
Many studies have sprung up to enrich and improve the supervisory information for training student models, distilling knowledge efficiently in both CV and NLP tasks [1][4][12]. Some studies showed that the teacher's intermediate hidden layers can also provide essential information [12]. Tiancheng Wen et al. proposed a knowledge adjustment method that generates better supervision for the student model, increasing accuracy on image classification datasets [9]. In the NLP field, some works jointly utilize knowledge from the teacher at different levels, such as embeddings, representations output by hidden layers, the outputs of self-attention modules, and the final predictions of the teacher model [13][14].
An NMT task predicts a text sequence in which each word depends on the previous words, finally generating the whole sequence. NMT models are therefore typically so large that KD receives much attention. Yoon Kim et al. conducted KD not only at the word level but also used beam search at the sequence level, compressing an LSTM-based NMT model [8]. Yen-Chun Chen et al. proposed a new Conditional Masked Language Modeling (C-MLM) method, which enables fine-tuning a pre-trained BERT on a target dataset and leveraging it for text generation tasks such as NMT [7].

Methodology
We first detail our Unified Distillation method in 3.1, and then present the application of Unified Distillation for NMT Tasks in 3.2.

Unified Distillation
There are two steps in the Unified Distillation method. We first compute the combined target by directly adding the probability distribution from the teacher, namely the soft target, and the one-hot label together, as in Equation (2):

z(x) = p_t(x) + l(x)    (2)

where z(x) and p_t(x) denote the combined target and the output probability distribution of the teacher on sample x in a batch, respectively. l(x) is the one-hot label of sample x, and the i-th term of l(x) is defined as Equation (3):

l_i(x) = 1 if i = c(x), and 0 otherwise    (3)

where c(x) denotes the correct class of sample x. The second step is to rectify the combined target, as shown in Figure 1. Obviously, the maximum value of z(x) corresponds to the correct class. We rectify the combined target to make Equation (4) hold by tuning τ_x:

softmax(z(x) / τ_x)_{c(x)} = P_max    (4)

where τ_x denotes the temperature parameter for sample x, and the hyper-parameter P_max is set in the range of 0.5 to 1. Because Equation (4) has no closed-form solution, τ_x cannot be computed directly. We built a small machine learning model with the objective of minimizing the loss function L_τ in order to obtain an approximate solution of τ_x satisfying Equation (5):

L_τ = (1/N) · Σ_x MSE(softmax(z(x) / τ_x)_{c(x)}, P_max)    (5)

where N is the batch size and MSE stands for the mean squared error loss function. Algorithm 1 gives the process of solving τ, which consists of the N elements τ_x in a batch: for each of max_iter iterations, (1) compute the rectified targets with the current τ, (2) calculate L_τ using Equation (5), and (3) update τ by gradient descent. This algorithm does not significantly increase the training time of the student model because we set max_iter to a small value such as 10. The solved sample-wise temperature τ_x is applied in Equation (6) to compute the unified target p_uni(x) for sample x:

p_uni(x) = softmax(z(x) / τ_x)    (6)

and the student model is trained using the loss function given by Equation (7):

L = KL(p_uni(x) ‖ p_s(x))    (7)

where p_s(x) denotes the output probability distribution of the student and KL stands for the Kullback-Leibler divergence.

In order to verify the effectiveness of our proposed method, we performed Unified Distillation on NMT tasks.
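The two steps above can be sketched end to end as follows. One caveat: where Algorithm 1 fits τ_x with a small gradient-descent model, this sketch substitutes bisection, which finds the same root because the peak of softmax(z/τ) decreases monotonically in τ. All function names are illustrative.

```python
import numpy as np

def softmax(v, tau):
    """Temperature-scaled softmax of a vector."""
    e = np.exp((v - v.max()) / tau)
    return e / e.sum()

def solve_tau(z, p_max, lo=1e-3, hi=100.0, iters=50):
    """Find the per-sample temperature tau_x so that the rectified target's
    peak equals P_max (Eq. 4). The peak of softmax(z/tau) is monotonically
    decreasing in tau, so bisection converges to the root that Algorithm 1
    approximates by gradient descent."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if softmax(z, mid).max() > p_max:
            lo = mid  # peak too high -> need a higher (smoother) temperature
        else:
            hi = mid
    return 0.5 * (lo + hi)

def unified_target(p_teacher, true_class, p_max=0.75):
    """Combine soft and hard targets (Eq. 2) and rectify them (Eq. 6)."""
    one_hot = np.zeros_like(p_teacher)
    one_hot[true_class] = 1.0
    z = p_teacher + one_hot            # combined target z(x)
    tau = solve_tau(z, p_max)
    return softmax(z, tau)             # unified target p_uni(x)

def unified_kd_loss(student_logits, p_uni):
    """Student objective of Eq. (7): KL(p_uni || p_s)."""
    p_s = np.exp(student_logits - student_logits.max())
    p_s /= p_s.sum()
    return float(np.sum(p_uni * (np.log(p_uni) - np.log(p_s))))
```

For a teacher distribution of [0.1, 0.7, 0.15, 0.05] with true class 0 and P_max = 0.75, the unified target is a proper distribution whose peak sits at the correct class with value P_max, exactly the "virtual teacher that is always correct" described earlier.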
We followed the basic teacher-student framework in [7], applying a finetuned language model BERT as the teacher and a Transformer base model as the student, except that the general KD method was changed into Unified Distillation method illustrated in Section 3.1, as shown in Figure 3.

Unified Distillation for NMT tasks
We chose this teacher-student framework for two reasons: (1) It achieved competitive results on several NMT tasks, especially on IWSLT German-English dataset; and (2) The general KD method of combining both soft and hard targets together was applied in this framework, so it's easy to compare our method with the general KD method to verify that our method is superior.
Generally, BERT is trained with a generative objective via masked language modeling (MLM) during the pre-training stage, and this training objective forces the model to learn essential, bidirectional, contextual knowledge. Because BERT is not auto-regressive, it cannot be directly applied to auto-regressive NMT in practice. Y. C. Chen et al. proposed a novel method called C-MLM to fine-tune the pre-trained BERT model, which is similar to MLM but requires additional conditional inputs [7]. We applied C-MLM to fine-tune the pretrained BERT-base-multilingual-cased model on the target datasets. After the BERT model was fine-tuned, we set it as the teacher, which generates probability distributions over tokens (soft targets) for the training samples, and set the Transformer base model as the student. Because both the left and right contexts of each token can be fully exploited in the fine-tuned BERT model, the single-directional limitation of the general Seq2Seq training framework is compensated, and the student can effectively learn more information from the teacher's outputs. More details can be found in [10] and [7].
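To make the C-MLM setup concrete, here is a minimal sketch of how one fine-tuning instance might be assembled, based on our reading of [7]. Tokenization is simplified (no subwords, no real BERT vocabulary), and the helper name is ours, not an API of any library.

```python
import random

MASK = "[MASK]"

def make_cmlm_example(src_tokens, tgt_tokens, mask_rate=0.15, seed=0):
    """Sketch of a C-MLM training instance: the source sequence is supplied
    as the extra condition, only target tokens are masked, and the teacher
    is trained to recover the masked target tokens."""
    rng = random.Random(seed)
    masked_tgt, labels = [], []
    for tok in tgt_tokens:
        if rng.random() < mask_rate:
            masked_tgt.append(MASK)
            labels.append(tok)     # the teacher must predict this token
        else:
            masked_tgt.append(tok)
            labels.append(None)    # not a prediction position
    # Conditional input: source and (partially masked) target concatenated.
    model_input = ["[CLS]"] + src_tokens + ["[SEP]"] + masked_tgt + ["[SEP]"]
    return model_input, labels
```

The key difference from plain MLM is visible in the last line: the unmasked source sentence is always present, so the teacher's predictions for masked target tokens are conditioned on the full source as well as the bidirectional target context.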
To compute the soft targets of the fine-tuned BERT teacher, we applied the circular mask strategy proposed in [7] and obtained the probability distribution predicted by the BERT teacher for each output target token in its masked position as the soft target. The whole process of computing soft targets by circular mask is shown in Figure 3.

In an NMT task, a pair of source and target sentences is processed as one sample, and each sentence consists of tokens. We perform Unified Distillation for one token at a time in the target sequence. Take the third token as an example: as shown in Figure 4, we combine the probability distribution of the masked token t_3 computed by the teacher with the known one-hot label to obtain the combined target. Then the temperature value for token t_3 is solved using Algorithm 1 so that the combined target can be rectified. Thus the unified target is obtained, which supervises the training of the student model. The subscript n denotes the length of the target sentence. The circular mask strategy is applied by repeating the masking in a training pair 7 times to approximate the 15% masking rate; the masked pair is input to the fine-tuned BERT in every iteration, and the probability distribution of each masked token is computed and output.
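The circular mask schedule can be sketched as follows; this reflects our reading of the strategy in [7] (each pass masks every 7th position with a different offset, so 7 passes cover every target token exactly once at roughly a 1/7 ≈ 15% rate per pass), and the function name is ours.

```python
def circular_masks(seq_len, period=7):
    """Return `period` lists of masked positions; pass k masks positions
    k, k + period, k + 2*period, ... so that over all passes every
    position in the target sequence is masked exactly once."""
    return [[i for i in range(seq_len) if i % period == k]
            for k in range(period)]
```

Each pass is fed to the fine-tuned BERT teacher as one masked version of the training pair, and the teacher's distribution at each masked position becomes the soft target for that token.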
At the inference stage, the BERT teacher is not used, so the decoding speed is the same as that of the student Transformer model alone.

Experiments and Results
We conducted experiments on three NMT tasks, using the datasets and basic experimental settings described in 4.1. Detailed experiments and results are given, with BLEU scores presented for NMT evaluation.

Datasets and Basic Settings
We considered two relatively small-scale datasets, IWSLT14 German-English (De-En, 160k training samples) and IWSLT15 English-Vietnamese (En-Vi, 113k training samples) [15], and one relatively large-scale dataset, UM-Corpus (En-Zh, 2.2M training samples) [14]. For IWSLT14 De-En, we followed the pre-processing steps presented in [6]. For IWSLT15 En-Vi, we used tst2012 as the dev set and tst2013 as the test set. For UM-Corpus En-Zh, we used the UM-Testing dataset containing 5,000 samples as the test set, following the settings of [14]. We adapted a pre-trained BERT-base-multilingual-cased model by fine-tuning it with C-MLM as the teacher, and used a Transformer base model of 6 encoder/decoder blocks and 8 attention heads as the student [3], where the size of the hidden feed-forward layer is 2048 and the hidden size is 512. The hyper-parameter settings listed in Table 1 were applied to our experiments on the three datasets. Two GeForce GTX 1080ti GPUs were used in our experiments.
For all experiments, the learning rate schedule is set as Equation (8), following [3]:

lrate = factor · d_model^(-0.5) · min(step^(-0.5), step · warmup_steps^(-1.5))    (8)

where d_model is the hidden size and factor is a hyper-parameter. Before training, pairs of source and target sequences in the training set are preprocessed and batched into sequences of similar length, and the batch size is counted by the number of tokens instead of the number of pairs. In the KD process, the BERT teacher's prediction logits for the training data were computed, and we used top-K distillation [20] with K set to 8 throughout all experiments. The temperature value τ_x for each token is initialized to 1.
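For reference, Equation (8) and a top-K truncation of the soft targets can be sketched as follows. The exact form of top-K distillation in [20] may differ from this sketch; here we simply keep the K largest teacher probabilities, zero the long tail, and renormalize. Function names are ours.

```python
import numpy as np

def transformer_lr(step, d_model=512, warmup=4000, factor=1.0):
    """Learning-rate schedule of Equation (8):
    lrate = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # guard against step 0
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

def topk_soft_target(p_teacher, k=8):
    """Keep only the K largest teacher probabilities and renormalize,
    discarding the long tail of tiny values (our reading of top-K
    distillation [20])."""
    p = np.asarray(p_teacher, dtype=float)
    keep = np.argsort(p)[-k:]   # indices of the K largest values
    out = np.zeros_like(p)
    out[keep] = p[keep]
    return out / out.sum()
```

The schedule warms up linearly for `warmup` steps and then decays with the inverse square root of the step count, peaking at step == warmup.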

Experimental Results
We evaluated Unified Distillation on the IWSLT14 German-English, IWSLT15 English-Vietnamese, and UM-Corpus English-Chinese translation datasets. The evaluation results are presented in Table 2, Table 3, and Table 4. The value of the hyper-parameter P_max should be tuned according to the dataset. Table 2 indicates that on the IWSLT14 German-English dataset with P_max set to 0.75, our Unified Distillation method outperforms the Transformer baseline and the Transformer trained with the general KD method by 4.9% and 0.2% on the development set, respectively, and by 5.5% and 0.8% on the test set, respectively. Table 3 demonstrates that on the IWSLT15 English-Vietnamese dataset with P_max set to 0.90, our Unified Distillation method clearly outperforms the RNN baseline and the RNN trained with the general KD method presented in [7] by 19.0% and 10.9% on the development set, respectively, and by 17.3% and 13.9% on the test set, respectively. However, our Unified Distillation method is only on par with the results presented in the same literature. Because of the limited computing power available to us, we set the batch size much smaller than in the original experiment, which is perhaps why our method does not achieve prominent results on this dataset. Table 4 shows that on the UM-Corpus English-Chinese dataset, our Unified Distillation method outperforms the SMT baseline result presented in [14] by 0.78% with P_max set to 0.80.

Effects of P_max
We conducted experiments on the three NMT tasks with different values of P_max, and the results are listed in Table 5.

More Experiments
More experiments were conducted to compare the Unified Distillation method with the general KD method using different temperature adjustment strategies. The results are listed in Table 6.
The results show that our Unified Distillation method outperforms the general KD method using the fixed temperature value 1 for each token (the value used to initialize all temperature values in our method), as well as the dynamic temperature adjustment methods proposed in [9].

Conclusion and Discussion
We proposed a novel method called "Unified Distillation" that improves the general KD method by combining the two objectives into one target and then rectifying it by rescaling the maximum probability to a fixed value.
In this way, the information of the ground-truth label and the "dark knowledge" from the teacher model are unified into one target, avoiding the dilemma of chasing two targets at the same time. We conducted NMT experiments on three datasets to validate our method, adopting a fine-tuned BERT language model as the teacher and a Transformer base model as the student. Experimental results indicate that our method provides a better approach for knowledge distillation applications.