Breast cancer diagnosis through knowledge distillation of Swin transformer-based teacher–student models

Breast cancer is a significant global health concern, emphasizing the crucial need for a timely and accurate diagnosis to enhance survival rates. Traditional diagnostic methods rely on pathologists analyzing whole-slide images (WSIs) to identify and diagnose malignancies. However, this task is complex, demanding specialized expertise and imposing a substantial workload on pathologists. Additionally, existing deep learning models, commonly employed for classifying histopathology images, often need enhancements to ensure their suitability for real-time deployment on WSIs, especially when trained on small regions of interest (ROIs). This article introduces two Swin transformer-based architectures: the teacher model, characterized by its moderate size, and the lightweight student model. Both models are trained using a publicly available dataset of breast cancer histopathology images, focusing on ROIs with varying magnification factors. Transfer learning is applied to train the teacher model, and knowledge distillation (KD) transfers its capabilities to the student model. To enhance validation accuracy and minimize the total loss in KD, we employ the state-action-reward-state-action (SARSA) reinforcement learning algorithm. The algorithm dynamically computes the temperature and a weighting factor throughout the KD process to achieve high accuracy within a considerably shorter training timeframe. Additionally, the student model is deployed to analyze malignancies in WSIs. Despite being only one-third the size and FLOPs of the teacher model, the student achieves an impressive accuracy of 98.71%, slightly below the teacher's accuracy of 98.91%. Experimental results demonstrate that the student model can process WSIs at a throughput of 1.67 samples s⁻¹ with an accuracy of 82%. The proposed student model, trained using KD and the SARSA algorithm, exhibits promising breast cancer classification and WSI analysis performance.
These findings indicate its potential for assisting pathologists in diagnosing breast cancer accurately and effectively.


Introduction
Breast cancer remains a significant public health concern, with an estimated 287 850 new cases reported in the United States in 2022, accounting for 31% of all newly diagnosed cancer cases in women [1]. Furthermore, breast cancer is responsible for an estimated 43 250 deaths, approximately 15% of all female cancer-related deaths. Therefore, the timely detection of breast cancer is imperative, as it facilitates prompt intervention and improves overall outcomes. While non-invasive methods such as mammography commonly serve for initial screening, biopsy techniques are often necessary for confirmation. Histopathology images obtained from biopsies and stained with hematoxylin and eosin (H&E) are commonly employed for such analyses. Analyzing a high-resolution H&E-stained whole-slide image (WSI) involves meticulously checking each region of a large slide for possible tumors. However, accurately interpreting H&E-stained images can be challenging because of the intricate relationships between pixels within the images, necessitating frequent viewing at various magnification settings. This task demands substantial time and experience.
Although computer-aided diagnosis with traditional image processing simplifies the process, handling large WSIs remains challenging. Pathologists must meticulously scrutinize each region; given the heavy workload, incorrect predictions are inevitable. The application of deep learning techniques, specifically convolutional neural networks (CNNs) capable of automatically learning relevant features from labeled data, has demonstrated notable success in classifying and identifying tumors in medical images. Additionally, the recent introduction of vision transformers (ViTs) [2, 3] has yielded promising results in diverse computer vision applications, including image classification, object detection, and segmentation. ViT architectures can process images globally and capture long-range dependencies between pixels, an advantage over traditional CNNs, which are limited to local processing. However, training on WSIs poses a challenge due to the large scan size, requiring significant computational resources. Additionally, the region of interest (ROI) in a WSI is often small, and accurately creating segmentation masks for initial training takes time.
A promising approach is to generate labels for ROIs in histopathology images and then explicitly train the classifier on these ROIs. This eliminates the need to create complex segmentation masks; the WSI is subsequently split into patches, and the classifier predicts the location of tumors. Further, in deep learning, a complex model often yields better results than a sparse model, albeit at the expense of higher computation time; such a complex model therefore requires substantial time to process a large WSI. In contrast to the complexity of previous models, knowledge distillation (KD) [4] provides a solution by transferring expertise from a complex teacher to a simpler student model. The main objective of this paper is to develop a simple, high-performing, computationally efficient model for handling WSIs at varying magnification factors. The main contributions are:
• Propose both a teacher model and a lightweight student model, and train the teacher model on a dataset of breast cancer histopathology images.
• Implement the state-action-reward-state-action (SARSA) algorithm to dynamically compute the temperature and weight factor in KD, achieving high validation accuracy for the student model.
• Apply whole-slide images (WSIs) to the trained student model to find the coarse locations of suspected regions in a given slide.
To meet the above objectives, the article is structured as follows: section 2 analyzes related works; section 3 describes the proposed methodology; and section 4 outlines the experimentation process. Section 5 presents the results and discussion of the study. Additionally, section 6 compares the proposed model with recent works. Finally, section 7 concludes the paper, highlighting the study's limitations and suggesting future scope.

Related works
The initial works that explored malignancy in histopathology images involved extracting relevant handcrafted features using traditional image segmentation, threshold methods, and wavelet transformations. These features were then employed to train machine learning classifiers for malignancy prediction. For instance, Spanhol et al [5] examined six feature descriptors and validated results across four classifiers. Similarly, Chattoraj and Vishwakarma [6] used the relief algorithm to extract discriminatory features, training a support vector machine (SVM) for binary classification. However, these methods cannot localize the tumor in a WSI.
In contrast, recent approaches leverage CNN architectures for feature extraction, using pre-trained models or training from scratch. For example, Shallu and Mehra [7] utilized a pre-trained VGG-16 [8] with logistic regression, while more advanced methods involve hybrid architectures with multiple CNNs or employ transfer learning and fine-tuning. These modern strategies, often enhanced by data augmentation, improve model performance. Bardou et al [9] implemented a custom CNN with data augmentation, and Alom et al [10] introduced a hybrid model combining Inception v4 [11], ResNet [12], and a recurrent neural network (RNN) [13]. Although these models can localize the tumor, processing each magnification requires more time and yields comparatively lower accuracy. Similarly, Boumaraf et al [14] employed ResNet-18 transfer learning with block-wise fine-tuning, demonstrating enhanced model accuracy. The model is lightweight; nevertheless, it requires extensive training to reach this performance. Despite these advances, relatively light models for exploring malignancy in WSIs have never been combined with KD. This strategy could enhance model performance while simultaneously diminishing computational demands, and it holds promise as a prospective avenue for future research in the examination of extensive WSIs.

Proposed methodology
This section briefly overviews KD and its implementation using the SARSA algorithm. Additionally, the study explains the two datasets used.

KD
KD is a methodology that utilizes a more advanced teacher model to train a simpler student model using tailored training data specific to the problem. For efficient training of the student model, the overall loss function combines the Kullback-Leibler (KL) divergence loss [15] and the cross-entropy loss (CE_loss). The former measures the difference between the probability distributions of the teacher and student models; equation (1) defines the KL divergence:

KL(P_t ‖ P_s) = Σ_i P_t(i) · log(P_t(i) / P_s(i)).    (1)

In this context, P_t and P_s represent the probability distributions of predictions generated by the teacher model and the student model, respectively, where i is the index corresponding to each class. Similarly, the temperature term in the KL divergence controls the softness of the teacher model's probability distribution, allowing the model to explore more during training. Specifically, the logits (pre-softmax outputs) are scaled by the temperature parameter T before applying the softmax function, as indicated in equation (2):

P(i) = exp(logits_i / T) / Σ_j exp(logits_j / T),    (2)

where logits_i is the ith logit and T is the temperature parameter. A high value of the temperature T leads to a softened probability distribution, where probabilities are more evenly spread across classes and the distinctions between them are less pronounced. Conversely, CE_loss gauges the effectiveness of the model in classifying the training data, as given by equation (3):

CE_loss = −Σ_i y_true(i) · log(y_pred(i)),    (3)

where y_true is the true probability distribution of the labels and y_pred is the probability distribution predicted by the model. As the two losses measure distinct aspects of the model's performance, a scaling factor is commonly applied to the KL divergence loss to align its magnitude with the other loss term; specifically, the KL divergence loss is multiplied by the square of the temperature. Additionally, a weight factor balances CE_loss and the KL divergence during training. A conventional method formulates the overall loss as a weighted combination of the cross-entropy and KL divergence losses, with a single hyperparameter, denoted α, controlling the individual contributions, as in equation (4):

total_loss = α · CE_loss + (1 − α) · T² · KL_loss.    (4)

Ensuring optimal performance in terms of validation accuracy and loss for the student model relies on selecting suitable temperature and α values. However, manually choosing these parameters can be time-consuming, often requiring extensive training iterations. To circumvent this labor-intensive approach, we leverage the SARSA reinforcement learning (RL) algorithm, which dynamically computes the temperature and α values, providing an automated and efficient means of determining the best settings in each epoch for the student model without manual intervention.
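The combined loss described above can be sketched in PyTorch as follows. The placement of α and the T² scaling follow the conventional KD formulation referenced in the text; the paper's exact equation (4) may arrange the weighting differently.

```python
import torch
import torch.nn.functional as F

def kd_total_loss(student_logits, teacher_logits, labels, T=5.0, alpha=0.1):
    """Sketch of the KD objective: alpha * CE + (1 - alpha) * T^2 * KL.

    The weighting convention is an assumption based on the conventional
    formulation described in the text.
    """
    # Soften both distributions with the temperature T.
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    log_soft_student = F.log_softmax(student_logits / T, dim=1)

    # KL divergence between teacher and student soft predictions,
    # scaled by T^2 to keep its magnitude comparable to the CE term.
    kl = F.kl_div(log_soft_student, soft_teacher, reduction="batchmean") * (T ** 2)

    # Standard cross-entropy on the hard labels.
    ce = F.cross_entropy(student_logits, labels)

    return alpha * ce + (1.0 - alpha) * kl
```

When the student's logits equal the teacher's, the KL term vanishes and only the weighted cross-entropy remains, which is a quick sanity check on the implementation.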

SARSA
RL techniques are frequently employed to determine the best solution to a problem. One such technique is the SARSA algorithm, a model-free RL algorithm that uses a table to represent the Q-values [16, 17]. Beginning with an arbitrary Q-value function, the algorithm employs an epsilon-greedy policy from the outset.
During each iteration, it selects an action based on this policy, observes the next state and reward, and updates the Q-value function for the current state-action pair.The fundamental concept behind the SARSA algorithm is to learn an optimal policy by iteratively improving the current policy through greedy actions.
Hence, a suitable reward function is necessary to identify the optimal temperature and α values that maximize validation accuracy and minimize loss. To compute the reward for SARSA, at each episode step the total loss is calculated for a given temperature and α as per equation (4). Similarly, the training accuracy, validation accuracy, and validation loss of the student model are computed. Finally, the overfit_penalty is computed using equation (5), and the reward function is derived using equation (6):

overfit_penalty = max(0, training_accuracy − validation_accuracy),    (5)

reward = validation_accuracy − validation_loss − overfit_penalty.    (6)

The resulting reward is used in the RL algorithm to encourage high validation accuracy and low validation loss while penalizing over-fitting. Figure 1 illustrates the different stages involved in KD.
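A minimal sketch of a reward of this shape follows. The exact forms of equations (5) and (6) are assumptions here; only the stated behavior is taken from the text: reward high validation accuracy, penalize validation loss and over-fitting (training accuracy far above validation accuracy).

```python
def kd_reward(train_acc, val_acc, val_loss, penalty_weight=1.0):
    """Hypothetical reward for the SARSA search over (T, alpha).

    Rewards validation accuracy, penalizes validation loss, and adds an
    over-fitting penalty when training accuracy exceeds validation
    accuracy. penalty_weight is an assumed knob, not from the paper.
    """
    overfit_penalty = max(0.0, train_acc - val_acc)
    return val_acc - val_loss - penalty_weight * overfit_penalty
```

With equal validation metrics, a run whose training accuracy tracks its validation accuracy scores higher than one that over-fits, which is the behavior the text describes.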
Similarly, algorithm 1 presents the pseudo-code for the implementation of KD. It initializes the current temperature at 5 (in the range 1-10) and alpha at 0.1 (in the range 0-1). It creates a SARSA table to store the Q-values for each state-action pair and sets the learning rate (LR) η = 1 × 10⁻⁴ and the exploration rate ε = 0.1. It then iterates through 50 episodes, initializing the current state in each episode. Using an ε-greedy policy based on the SARSA table, the algorithm selects an action and trains the student model using the current temperature and alpha values. It computes the total loss, i.e. training loss, validation loss, and validation accuracy, and assigns an episode reward based on these metrics. The SARSA table is updated by considering the observed reward and the new state-action pair, and the current temperature and alpha values are adjusted based on the chosen action. The algorithm then selects a new action and updates the exploration rate. Finally, it returns the temperature and α values that maximize validation accuracy and minimize validation loss.
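The core SARSA update and ε-greedy selection used in algorithm 1 can be sketched as follows. The state/action encoding, η, and the discount factor γ are illustrative assumptions; here a state is a discretized (T, α) pair and an action is a small increment or decrement of each.

```python
import random
from collections import defaultdict

def sarsa_update(Q, s, a, r, s_next, a_next, eta=1e-4, gamma=0.99):
    """One tabular SARSA step: Q(s,a) += eta * (r + gamma*Q(s',a') - Q(s,a)).

    Q maps (state, action) pairs to values; eta and gamma are assumed.
    """
    Q[(s, a)] += eta * (r + gamma * Q[(s_next, a_next)] - Q[(s, a)])
    return Q

def epsilon_greedy(Q, s, actions, eps=0.1):
    """Pick a random action with probability eps, else the greedy one."""
    if random.random() < eps:
        return random.choice(actions)
    return max(actions, key=lambda a: Q[(s, a)])
```

Each training episode would call `epsilon_greedy` to pick the (T, α) adjustment, run one round of distillation, compute the reward, and then call `sarsa_update` with the observed transition, decaying ε over time as the algorithm states.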

Teacher model
The transformer based on the Swin transformer, Swin_Base_Patch4_Window7_224 [3], is considered the teacher model. The model is pre-trained on a substantial image dataset, ImageNet-22k, using images sized 224 × 224, and attains a top-1 accuracy of 85.2%. The model is divided into stages, each containing a series of Swin transformer blocks. The sequence of operations in each block is depicted in equations (7) and (8):

x̂ = MSA(LayerNorm(x)) + x,    (7)

y = FFN(LayerNorm(x̂)) + x̂.    (8)

Here, x can be either a sequence of image patches or a feature map. The block itself is composed of two main components: multi-head self-attention (MSA) and the feed-forward network (FFN). MSA aids the model in selectively focusing on various input segments by computing attention scores between each pair of positions in x across multiple scales. The Swin transformer block employs a hierarchical approach to attention computation, computing attention locally within small groups of patches and then between the resulting groups at different scales. Two versions of this hierarchical approach are available, namely window-based MSA and shifted-window-based MSA. The FFN, on the other hand, introduces a non-linear transformation of the input features, which aids in capturing more complex patterns in the data. Finally, the 'LayerNorm' operation stabilizes the training process by standardizing the input features across the feature dimension, helping to ensure consistent performance.
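Equations (7) and (8) describe a pre-norm residual block. A minimal sketch follows, with plain multi-head attention standing in for Swin's (shifted) window attention and relative position bias, which are omitted for readability; the dimensions are illustrative.

```python
import torch
import torch.nn as nn

class SwinBlockSketch(nn.Module):
    """Sketch of equations (7)-(8): pre-norm attention and FFN, each
    with a residual connection. Not a faithful Swin block: ordinary
    multi-head attention replaces (shifted) window attention here.
    """
    def __init__(self, dim=96, heads=3, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):  # x: (batch, tokens, dim)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # eq. (7)
        x = x + self.ffn(self.norm2(x))                    # eq. (8)
        return x
```

The pre-norm ordering (LayerNorm before MSA/FFN, residual added afterwards) matches the equations above and keeps the token shape unchanged through the block.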
After the input is passed through multiple stages of Swin transformer blocks, it is fed into the final layer, known as the 'head'. Typically, the head is a dense layer that performs a task-specific transformation of the input features. For binary classification, the output dimension of the head layer is set to two, allowing the model to output a probability distribution over the two classes.

Student model
In KD, the student model typically has fewer parameters than the teacher model. However, the performance of a smaller student network trained with KD deteriorates when there is a significant disparity between the student and teacher networks [18]. Hence, a model based on Swin transformer v2 [19], relatively sparse compared to the teacher, was created. Table 1 presents a comparative analysis of the teacher and student models in terms of size and throughput. The results reveal a significant difference in size, with the student model being approximately one-third the size of the teacher model. Similarly, the student model requires fewer floating-point operations (FLOPs), almost one-third of those needed by the teacher model. Furthermore, the student model achieves 2.3 times the throughput of the teacher model, rendering it more efficient in processing speed.
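Size comparisons of the kind reported in table 1 can be produced by counting trainable parameters, as in the generic utility below (not tied to the specific teacher/student architectures):

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters in a model, useful for
    teacher/student size comparisons such as the one in table 1."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

Applying this to the teacher and student models would confirm the roughly one-third size ratio stated above.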

Datasets descriptions
This work considers two different datasets. The first contains ROIs of breast cancer histopathology images in the benign and malignant classes. The second includes ROIs and WSIs, categorized as normal, benign, in situ carcinoma, and invasive carcinoma.
The first dataset used in this work is the BreakHis dataset [5, 20], which comprises 7909 H&E-stained histopathology images from open surgical biopsies, broadly categorized as malignant and benign across four magnification factors. Initially, stained slides were scanned at a magnification factor of ×40. Subsequently, the authors of the dataset magnified the ROIs in the images to ×100, ×200, and ×400 to identify tumors, then resized all slides to 700 × 460 pixels. Among these, 2480 images belong to the benign category and 5429 to the malignant one; 123 duplicate images were removed from the malignant category. The dataset was further split into training and validation sets. Table 2 summarizes the training and validation datasets for each category at the different magnification factors. A total of 6229 images were used for training, while the validation dataset comprised 1557. This dataset played the primary role in training the teacher and student models across the four magnification factors.
The second dataset is the Breast Cancer Histology (BACH) dataset [21], which comprises two databases: photos and WSI. In the photos database, each image is labeled as one of the normal, benign, in situ carcinoma, or invasive carcinoma categories, and each category consists of 100 images. The WSI database contains ten images and their corresponding masks; each mask may contain more than one category, and the size of the WSIs varies from image to image. Here, the WSI images are used to predict the suspected regions of malignancy.

Data preparation and pre-processing
Both the teacher and student models are trained on the BreakHis dataset. All images in the dataset are initially resized to 224 × 224 using bilinear interpolation. However, since the dataset is limited in size, on-the-fly data augmentation is used to enhance the model's performance. Three augmentation techniques are applied during training: random rotation, horizontal flip, and vertical flip, each with a probability of 0.75 [22]. This allows any of the four possibilities (the original image or one of the three augmented versions) to be generated during training. Additionally, the images are normalized using standard normalization [23] to further improve the training process.

Experimental settings

The teacher model training
The teacher model undergoes training on the BreakHis dataset through transfer learning. A small LR maintains the model's stability but may lengthen training; however, the benefits of transfer learning often outweigh this cost. Therefore, the chosen LR is 1 × 10⁻⁴, halved every 15 epochs, with a total of 50 epochs. A large batch size reduces training time, but its magnitude depends on the available memory of the graphics processing unit. A batch size of 32 is applied to both the teacher and student models to maintain consistency in the training process. Similarly, the AdamW optimizer [24], as used in the original Swin transformer paper, is applied in this context. In addition, the CE_loss function is used to calculate the loss. At the end of each epoch, the model is validated against the validation dataset. The model state corresponding to the maximum validation accuracy and the minimum validation loss is used as the teacher model for training the student model via KD.
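The optimizer and schedule described here can be sketched as follows; the placeholder module stands in for the Swin teacher, while the initial LR of 1 × 10⁻⁴, the halving every 15 epochs, and the 50-epoch budget follow the text.

```python
import torch

teacher = torch.nn.Linear(8, 2)  # placeholder for the Swin teacher model
optimizer = torch.optim.AdamW(teacher.parameters(), lr=1e-4)
# Halve the learning rate every 15 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

for epoch in range(50):
    # ... one epoch of training and validation would go here ...
    optimizer.step()   # optimizer steps precede scheduler steps
    scheduler.step()
```

After 50 epochs the LR has been halved three times (at epochs 15, 30, and 45), ending at 1.25 × 10⁻⁵.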

The student model training
Figure 1 illustrates the training process of the student model. The model is trained for 50 episodes using the SARSA algorithm, following a training strategy similar to the teacher's, except with a constant LR.

WSI processing
The WSI is typically stored in Aperio's proprietary format as a series of TIFF images, with dimensions varying from one image to another. To standardize the images and facilitate their analysis with the student model, they are resized to 2464 × 1792, a multiple of 224, the input size of the student model. The resized images are then converted into PNG format, as illustrated in figure 2(a). Similarly, figure 2(b) visually represents the WSI after overlaying a mask and a grid of 224 × 224 pixel cells. This process allows for the identification and highlighting of three types of tissue in the image: benign (red), in situ carcinoma (green), and invasive carcinoma (blue). The resulting masked images reveal the ROIs in the tissue sample. Some grid cells contain multiple ROIs, while a few patches in the bottom-right corner contain no histopathology content, and some patches in the bottom-left corner only partially cover the tissue sample. Each WSI is then divided into patches of size 224 × 224 to test for malignancy using the student model, resulting in a total of 88 patches per image.
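The tiling step can be sketched as below; for a 2464 × 1792 image split into non-overlapping 224 × 224 patches, this yields 11 × 8 = 88 tiles, matching the count above.

```python
import numpy as np

def tile_wsi(image, patch=224):
    """Split a resized WSI array of shape (H, W, C) into non-overlapping
    patch x patch tiles, scanned row by row."""
    h, w = image.shape[:2]
    tiles = []
    for top in range(0, h - patch + 1, patch):
        for left in range(0, w - patch + 1, patch):
            tiles.append(image[top:top + patch, left:left + patch])
    return tiles
```

Each returned tile can then be fed to the student model for the patch-level malignancy test.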
However, the student model performs binary classification, whereas WSIs contain three categories. Hence, two strategies were implemented to overcome this limitation. The first treats invasive and in situ carcinomas as a single entity, categorizing both as malignant. The second sets a threshold on the loss: if the loss exceeded 0.5, the model discarded that patch and considered it a normal or partially filled region. If the loss was less than 0.5 and the difference in confidence probabilities was more than 70%, the patch might contain both benign and malignant cells; in such cases, the model retained the patch and considered it for further analysis by magnifying it twice. The experiments were conducted on an Intel® Xeon® Gold 6226R with a 1 TB HDD, 64 GB of DDR4-2933 ECC memory, an NVIDIA® RTX™ A5000 with 24 GB of GDDR6 memory, and Ubuntu® Linux® 20.04 LTS. The implementation utilized the PyTorch package [25], and performance metrics were analyzed using scikit-learn [26].
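The second screening strategy can be sketched as below. Since no ground-truth labels are available at inference time, the per-patch "loss" is approximated here by the negative log of the top-class probability (an assumption); the 0.5 loss threshold and 70% confidence gap follow the text.

```python
import math
import torch

def keep_patch(logits, loss_threshold=0.5, conf_gap=0.70):
    """Sketch of the patch-screening rule. The proxy loss (negative log
    of the top-class probability) is an assumption; thresholds are from
    the text. Returns True if the patch is kept for further analysis.
    """
    probs = torch.softmax(logits, dim=-1)
    proxy_loss = -math.log(probs.max().item())
    if proxy_loss > loss_threshold:
        return False  # discarded: normal or partially filled region
    # Confident prediction with a large probability gap: keep the patch
    # for further analysis at double magnification.
    return (probs.max() - probs.min()).item() > conf_gap
```

A confidently classified patch passes both tests, while a near-uniform prediction (proxy loss above 0.5) is treated as normal or partially filled tissue.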

Results and discussion
This section analyzes the performance of both the teacher and the student models in terms of accuracy, loss, precision, recall, F-score, Matthews correlation coefficient (MCC), and confusion matrix for binary classification. Furthermore, the proposed methodology is applied to WSIs to evaluate the accuracy and latency of both models. Figure 3 displays the accuracy and loss of the teacher model throughout each epoch of training and validation on the BreakHis dataset. The model attained an accuracy of 98.91% and a loss of 0.040 on the validation set. The confusion matrix depicted in figure 4(a) provides insight into the teacher model's predictions: 14 images were misinterpreted as malignant and three as benign. Moreover, correctly predicted images had a lower loss than the model's validation loss, and their prediction probability was nearly 100%. However, wrongly predicted cases had a loss greater than 0.7, and their probability scores varied.
Similarly, during KD, the student model's accuracy and loss were recorded, as displayed in figure 5. The student model achieved an accuracy of 98.71% and a loss of 0.05, only slightly different from the teacher model. Moreover, the confusion matrix shown in figure 4(b) depicts that the student model misinterpreted 16 images as malignant and four as benign. Furthermore, we carried out a comparative analysis to evaluate how closely the models' misclassifications resemble each other. The results revealed that both models misclassified the same nine benign and one malignant images, indicating high similarity.
Apart from the accuracy, precision, F-score, recall, and MCC of both models, which are computed and displayed in table 3, their performance is also evaluated separately for each magnification factor of the BreakHis dataset. Table 4 displays the performance metrics of both the teacher and the student, expressed as accuracy, precision, F-score, recall, and MCC, for the four magnification factors. The results demonstrate that both models are robust to the magnification factor, thereby avoiding the need to train them separately for each individual magnification factor.
For the WSI analysis, both models were utilized independently. The teacher model achieved a higher accuracy of 84% compared to the student model's 82%. However, the student model demonstrated faster processing of the WSIs, with a throughput of approximately 1.67 samples s⁻¹ compared to the teacher model's 0.72 samples s⁻¹. Despite this, neither model could accurately predict patches in which malignant and benign conditions coexist, leading to lower accuracy than in the analysis of the BreakHis dataset. This can be attributed to the use of softmax layers as the final layer in both the teacher and student models: these models rely on the highest probability score to predict a single category, leading to one outcome. Addressing the coexistence effect may involve applying a one-hot encoder to each category and training the model as a multi-label classifier with a high threshold. A second approach involves training the model on a segmentation task in which each tumor is delineated with a boundary; however, while effective, this approach is computationally intensive.

Comparing results with recent works
The student model is evaluated against more recent models in the literature for different magnification factors. Table 5 demonstrates the superior performance of the proposed model compared to recently suggested models. Furthermore, the same model exhibits versatility across the various magnification factors.

Conclusion
This study introduced teacher and student models built upon the Swin transformer. The teacher model was trained via transfer learning, while the student model was trained using KD alongside the SARSA algorithm; they achieved accuracies of 98.91% and 98.71%, respectively. Moreover, the student model matched the teacher model's performance in just 50 episodes, demonstrating the algorithm's effectiveness in enhancing model performance while minimizing computational overhead. Furthermore, a comparative study between the teacher and student models on WSIs yielded 84% and 82% accuracy, respectively. Notably, the student model exhibited 2.3 times greater throughput than the teacher model, suggesting it can attain similar accuracy with significantly reduced processing time. However, both models' accuracy decreased when analyzing WSIs, possibly due to the coexistence of benign and malignant tumors within the same patch after splitting the WSI. Future work could train a segmentation model with masks containing both types of tumors at various magnification factors. Enhancing the model's capacity to accurately identify benign and malignant regions, regardless of image resolution, would ensure its effectiveness across diverse clinical scenarios. Finally, KD not only made the model better and faster but also showed promise for applying the same approach to other types of medical data.

Figure 1 .
Figure 1. Knowledge distillation with the SARSA algorithm involves splitting the dataset into training and validation sets. During KD, the teacher and student models are compared using the KL divergence loss, and the student model is evaluated with the cross-entropy loss. The total loss, computed by SARSA using the temperature T and the weighting factor α, guides the subsequent update of the student model through backpropagation with AdamW. SARSA updates the values of T and α based on training (loss and accuracy) and validation (loss and accuracy) metrics.

Figure 3 .
Figure 3. Accuracy and loss of the teacher model throughout each epoch of training and validation. The dataset comprises breast cancer histopathology images from the BreakHis dataset. (A) Accuracy improves as the number of epochs increases and is almost constant in the range of 45-50 epochs. (B) Loss decreases as the number of epochs increases.

Figure 4 .
Figure 4. Confusion matrices for the validation dataset are generated on teacher and student models.(A) For the teacher: three instances are wrongly predicted as benign, and 14 as malignant.(B) For the student: four instances are wrongly predicted as benign, and 16 as malignant.

Figure 5 .
Figure 5. Student model accuracy and loss are monitored during knowledge distillation.(A) Training and validation accuracy improve as the number of epochs progresses.(B) Both training and validation losses converge to almost the same value.
Algorithm 1 (excerpt).
15: update current T and α values based on action a: T ← T + a_T, α ← α + a_α
16: clip T and α values to the valid ranges [0, 10] and [0, 1], respectively
17: update state to s′ ← (new T, new α)
18: choose new action a′ from s′ using the ε-greedy policy derived from Q
19: set a_T and a_α based on the SARSA lookup table Q(T, α) and exploration rate ε
20: decay exploration rate ε over time
21: end for
22: return T, α that maximize validation accuracy and minimize validation loss

Table 1 .
Models analysis. Note: this table presents a comparison between the Swin v1 and Swin v2 models in terms of various metrics. Sl. No. denotes the serial number of the parameter being compared. Teacher and student represent the two models being compared, with the teacher being the more complex one (Swin v1) and the student the simpler one (Swin v2).

Table 2 .
Number of images per category and magnification factor.

Table 3 .
Teacher and student model performance analysis.

Table 4 .
Models performance for four magnification factors.

Table 5 .
Evaluating the performance of the student model in comparison with recent methods.