CTA-UNet: CNN-Transformer architecture UNet for dental CBCT image segmentation

Current deep learning models for segmenting dental cone-beam computed tomography (CBCT) images face several limitations: complex root morphology, fuzzy boundaries between tooth roots and alveolar bone, and the costly annotation of dental CBCT images. We collected dental CBCT data from 200 patients, annotated 45 of them for network training, and propose a CNN-Transformer Architecture UNet (CTA-UNet) that combines the advantages of CNNs and Transformers: the CNN component effectively extracts local features, while the Transformer captures long-range dependencies. Multiple spatial attention modules are included to enhance the network's ability to extract and represent spatial information. Additionally, we introduce a novel masked image modeling method to pre-train the CNN and Transformer modules simultaneously, mitigating the limitations of the small amount of labeled training data. Experimental results demonstrate that the proposed method achieves superior performance (DSC of 87.12%, IoU of 78.90%, HD95 of 0.525 mm, ASSD of 0.199 mm), provides a more efficient and effective approach to automatically and accurately segment dental CBCT images, and has real-world applicability in orthodontics and dental implant treatment.


Introduction
In the field of dentistry, accurate tooth segmentation in CBCT images and the construction of three-dimensional (3D) tooth models are critical for diagnosis and treatment planning. The technology is extensively applied in clinical stomatology, including implant surgery, oral and maxillofacial surgery, orthodontics, periodontics, endodontics, and treatment of temporomandibular disorders. Accurate tooth segmentation in CBCT images is essential for the construction of 3D tooth models. However, the complex morphology of tooth roots and the blurred boundary between roots and alveolar bone make it difficult to segment the tooth region from CBCT scans, while manual labeling requires dental professionals and is time-consuming and costly. Therefore, automatic and efficient tooth segmentation on CBCT is urgently needed.
Traditional methods for automated CBCT tooth segmentation often use the level set method (Gan et al 2015, Zichun et al 2020), which relies on manually selected initial points and is therefore difficult to fully automate, with poor segmentation results for roots and alveolar bone. With continuous advances in deep learning, convolutional neural networks (CNNs) are increasingly used in dental CBCT segmentation (Cui et al). For tooth segmentation on CBCT imaging, a CNN-based multi-stage fully automated tooth segmentation network was proposed by Lee et al (2020). A lightweight automated tooth segmentation network was proposed using dilated convolution and residual connections in Ma and Yang (2019). A two-stage method for precise segmentation of teeth and pulp cavities was proposed by Duan et al, using feature pyramid networks and region proposal networks to extract bounding boxes for individual teeth (Duan et al 2021). Despite these advantages, CNNs are limited in capturing full contextual feature information because they only extract features in local regions.
In recent years, the Transformer (Dosovitskiy et al 2020, Liu et al 2021a) has been widely used in various visual tasks; it captures global contextual information through attention, builds long-range dependencies on the target, and extracts more effective features. In the field of medical image segmentation, Valanarasu et al proposed a Gated Axial-Attention model and constructed the MedT network (Valanarasu et al 2021), which extends existing architectures by introducing an additional control mechanism in the self-attention module to achieve segmentation of medical images. Ji et al proposed the MCTrans network, which effectively establishes cross-scale dependencies and feature correlations to perform accurate medical image segmentation (Ji et al 2021). Although Transformers have gradually been applied to various medical image segmentation tasks, large-scale datasets are necessary to train their weights, and the limited size of medical image datasets makes it difficult to meet this requirement, thereby limiting model performance.
In this paper, we propose a CNN-Transformer Architecture UNet (CTA-UNet) that combines the advantages of CNNs and Transformers through a parallel architecture. This approach integrates local features extracted by convolutional operators and global representations obtained by self-attention modules to enhance the network's performance.
The contributions of this work are: (1) We collected CBCT scans from 200 patients and annotated 45 volumes to create a dataset of dental CBCT images. The dataset is available by contacting the corresponding author upon reasonable request.
(2) We proposed the U-shaped CTA-UNet with a parallel CNN-Transformer architecture, enhancing the network's ability to extract local features and capture long-range dependencies, and effectively merging multi-scale features.
(3) We proposed a masked image modeling (MIM) method, CTAMIM, for networks with a CNN-Transformer architecture, which enables the weight parameters of convolution operators to be trained more effectively by deferring the application of masks during training.
(4) We used CTA-UNet pre-trained with CTAMIM to achieve a Dice score of 87.12% in dental CBCT image segmentation, outperforming existing models.

Related works
Although CNNs have made progress in medical image segmentation tasks, existing pure CNN architectures have obvious drawbacks: CNNs cannot capture global feature information, and even with denser downsampling to expand the receptive field, their capture range remains limited while the resolution is lowered.

Vision transformers
Recently, Transformers have been used in a variety of visual tasks, represented by ViT (Dosovitskiy et al 2020) and Swin Transformer (Liu et al 2021a). ViT splits images into 16 × 16 patches and embeds location information to construct a token sequence, which is fed into the Transformer encoder to extract global contextual features using multi-head self-attention (MSA). Swin Transformer divides fixed-size sampling blocks into windows of different sizes according to the hierarchy and computes self-attention only within each window, improving computing efficiency and reducing overhead. Furthermore, Swin Transformer uses patch merging to implement an operation similar to downsampling, enabling Transformers to expand their receptive field like CNNs and improve network performance.
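As a concrete illustration, a ViT-style patch embedding can be written as a single strided convolution. The following is a minimal sketch; the module name, embedding dimension, and the 224 × 224 input are illustrative assumptions, not taken from the paper:

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """ViT-style patch embedding: a strided convolution splits the image
    into non-overlapping 16 x 16 patches and projects each to an embedding."""
    def __init__(self, in_ch=3, embed_dim=768, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (B, C, H, W)
        x = self.proj(x)                     # (B, embed_dim, H/16, W/16)
        return x.flatten(2).transpose(1, 2)  # token sequence (B, N, embed_dim)

tokens = PatchEmbed()(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 196, 768])
```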
TransUNet (Chen et al 2021) and Swin-Unet (Cao et al 2023) are typical networks that introduce Transformers into medical image segmentation. Both follow the U-shaped design of U-Net and use skip connections to link the encoder and decoder. TransUNet adopts a serial CNN-Transformer architecture: the encoder first uses CNNs to extract low-level features and perform downsampling, then sends the extracted feature maps to Transformer layers. This network combines the strengths of CNNs and Transformers and achieves higher performance than U-Net, Attention U-Net, ViT, and other networks in various segmentation tasks. Swin-Unet uses a pure Transformer architecture, replacing all convolutional blocks in U-Net with swin transformer blocks and using patch merging and patch expanding to achieve downsampling and upsampling. In multiple tasks, Swin-Unet achieves segmentation accuracy that surpasses TransUNet.
Transformers can capture the complex spatial transformations and long-distance feature dependencies of global representations. However, they are not as good as CNNs at local feature extraction, and a small amount of training data limits their performance (Liu et al 2021b). SimMIM (Xie et al 2022) is similar to MAE, as shown in figure 1(b); the main difference is that SimMIM directly replaces masked patches with trainable mask tokens and feeds them into the encoder, so the encoder's output shape need not be consistent with its input shape. Therefore, SimMIM can be used in networks like Swin Transformer. However, neither MAE nor SimMIM can be applied to CNNs.

Masked image modeling
The first method applying MIM to CNNs was A2MIM (Li et al 2022a), as shown in figure 2.
In the early stage, A2MIM replaces part of the patches with the RGB mean to minimize the local statistical variation caused by masks, and adds the mask token directly to the corresponding patch at one intermediate stage/layer of the encoder, thus enhancing the middle-order interactions of the network. A2MIM generalizes to both CNNs and Transformers, as it adapts to feature extraction from the low level to the semantic level. However, although A2MIM introduces MIM to CNNs, large-scale random or mean values in the early stages of the network are not conducive to the learning of CNNs: the patch size is usually set to 16 × 16 while the convolution operator is only 3 × 3, as shown in figure 3, which may make convolution operators hard to train.
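To make the early-stage masking concrete, the following minimal sketch replaces a random subset of non-overlapping patches with the per-channel mean, in the spirit of A2MIM; the function name and mask ratio are illustrative choices, not the original implementation:

```python
import torch

def mean_mask_patches(img, patch=16, mask_ratio=0.5):
    """Replace a random subset of non-overlapping patches with the image's
    per-channel mean (A2MIM-style early-stage masking, sketched)."""
    B, C, H, W = img.shape
    gh, gw = H // patch, W // patch
    mean = img.mean(dim=(2, 3), keepdim=True)             # (B, C, 1, 1) RGB mean
    keep = torch.rand(B, 1, gh, gw, device=img.device) > mask_ratio
    keep = keep.float().repeat_interleave(patch, 2).repeat_interleave(patch, 3)
    return img * keep + mean * (1 - keep), keep           # masked image + mask map

masked, mask = mean_mask_patches(torch.randn(2, 3, 512, 512))
```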

Methodologies
3.1. Network structure
Our proposed network comprises an encoder and a decoder in a U-shaped structure with skip connections, as shown in figure 4. The encoder consists of five stages, the first two composed of two sets of Conv Blocks. These blocks extract texture information from the image and reduce the size of the feature map. By combining convolution with Global Max Pooling (GMP), the network extracts features within a given range from a single channel of information and downsizes them to half the original size, achieving an effect similar to patch embedding; the implementation details of the Conv Block are shown in figure 5(a). Additionally, Xiao et al discovered that replacing the patchify stem with a convolutional stem can expedite model convergence and improve network performance (Xiao et al 2021).
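A minimal sketch of such a block follows. The kernel sizes, normalization layers, and the use of stride-2 max pooling for the halving step are assumptions for illustration; the exact configuration is given in figure 5(a):

```python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Sketch of a Conv Block: convolutions extract local texture features,
    then a stride-2 pooling halves the spatial size (details assumed)."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.down = nn.MaxPool2d(2)  # halves H and W, akin to patch embedding

    def forward(self, x):
        return self.down(self.conv(x))

out = ConvBlock(1, 32)(torch.randn(1, 1, 512, 512))  # -> (1, 32, 256, 256)
```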
From the third stage, features are extracted by parallel Conv Blocks and Trans Blocks. Each Trans Block consists of patch embedding, swin transformer blocks, an unpatchify operation, and batch normalization, as shown in figure 5. Patch embedding uses a 1 × 1 patch size to ensure the same receptive field size for both Conv Blocks and Trans Blocks, which facilitates feature fusion. The swin transformer blocks compute self-attention within shifted windows to reduce computational effort, and batch normalization finally aligns the feature representations of the CNN and Transformer branches.
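Structurally, a Trans Block can be sketched as below. For brevity, a vanilla transformer encoder layer stands in for the shifted-window Swin blocks the paper uses, and all dimensions are illustrative:

```python
import torch
import torch.nn as nn

class TransBlock(nn.Module):
    """Sketch of a Trans Block: 1 x 1 patch embedding -> attention blocks ->
    unpatchify -> BatchNorm. A plain transformer layer stands in here for
    the windowed Swin blocks used in the paper."""
    def __init__(self, dim, heads=4, depth=2):
        super().__init__()
        self.embed = nn.Conv2d(dim, dim, kernel_size=1)   # 1 x 1 patch embedding
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.BatchNorm2d(dim)   # aligns CNN / Transformer features

    def forward(self, x):                 # x: (B, C, H, W)
        B, C, H, W = x.shape
        t = self.embed(x).flatten(2).transpose(1, 2)      # (B, H*W, C) tokens
        t = self.blocks(t)
        t = t.transpose(1, 2).reshape(B, C, H, W)         # unpatchify
        return self.norm(t)

y = TransBlock(64)(torch.randn(1, 64, 32, 32))
```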
For the decoder, we use the Multi-Spatial Attention Block (MSAB) to process feature information from the encoder, as shown in figure 6. MSAB is based on the spatial attention mechanism, which weights the spatial dimensions of the features. The module uses GMP, GAP (Global Average Pooling), and 1 × 1 convolutions to extract spatial features and capture the regions of interest. Since the feature information of the encoder and decoder is concatenated and fused by a 1 × 1 convolution, which can be approximated as channel attention, we do not add an additional channel attention module, reducing memory consumption and training time. Finally, we upsample features at different scales using nearest-neighbor interpolation and concatenate them to maximize multi-scale feature fusion.
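The spatial-attention idea behind MSAB can be sketched as follows; the exact composition of the block is defined in figure 6, so the branch layout and fusion here are assumptions:

```python
import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    """Sketch of spatial attention: channel-wise max (GMP) and average (GAP)
    maps plus a 1 x 1 conv branch locate regions of interest, then the
    feature map is re-weighted along its spatial dimensions."""
    def __init__(self, channels):
        super().__init__()
        self.point = nn.Conv2d(channels, 1, kernel_size=1)  # 1 x 1 conv branch
        self.fuse = nn.Conv2d(3, 1, kernel_size=1)
        self.act = nn.Sigmoid()

    def forward(self, x):                         # x: (B, C, H, W)
        gmp = x.max(dim=1, keepdim=True).values   # spatial map via GMP
        gap = x.mean(dim=1, keepdim=True)         # spatial map via GAP
        attn = self.act(self.fuse(torch.cat([gmp, gap, self.point(x)], dim=1)))
        return x * attn                           # weight spatial positions

y = SpatialAttention(64)(torch.randn(1, 64, 32, 32))
```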
3.2. Self-supervised pre-training
As shown in figure 7, CTAMIM differs from traditional MIM methods. Traditional MIM methods use a masking strategy that divides the image into non-overlapping patches: MAE hides some patches and feeds the rest to the encoder, while SimMIM and A2MIM replace patches with randomly initialized learnable mask tokens or RGB mean values and feed them to the encoder. Existing MIM methods must therefore mask the image at the beginning of the network, which introduces large random or mean-valued blocks into the image; this may disturb the normal adjustment of convolution operator weights and limit the feature extraction capacity of the network. CTAMIM instead defers the position of the masks and feeds the input image directly to the encoder, where primary feature information is extracted by the Conv Stem of the network, which can be any convolutional module with downsampling. After the low-level features are extracted by the Conv Stem, the feature map is divided into patches and masked. Take CTA-UNet as an example: set the patch size to 16 × 16 and the Conv Stem of CTA-UNet to stages 1-2. There are two downsampling steps in the Conv Stem, with an input image of size 1 × 512 × 512 and an output of size 64 × 128 × 128. The receptive field size is now 4 × 4, and the feature map is divided into a set of non-overlapping patches with a 4 × 4 patch size, corresponding to patches of size 16 × 16 in the original image.
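The deferred masking step might look like the following sketch; the stem layout, channel counts, and the zero-valued mask token are assumptions for illustration (the paper uses randomly initialized learnable tokens):

```python
import torch
import torch.nn as nn

# Illustrative conv stem: two stride-2 stages, 1 x 512 x 512 -> 64 x 128 x 128.
conv_stem = nn.Sequential(
    nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(inplace=True),
)

def defer_mask(feat, patch=4, mask_ratio=0.5, mask_token=None):
    """CTAMIM-style deferred masking (sketch): mask 4 x 4 feature patches,
    which correspond to 16 x 16 patches of the original image."""
    B, C, H, W = feat.shape
    gh, gw = H // patch, W // patch
    keep = (torch.rand(B, 1, gh, gw, device=feat.device) > mask_ratio).float()
    keep = keep.repeat_interleave(patch, 2).repeat_interleave(patch, 3)
    token = torch.zeros(1, C, 1, 1) if mask_token is None else mask_token
    return feat * keep + token * (1 - keep)

img = torch.randn(1, 1, 512, 512)   # the image itself stays unmasked
feat = conv_stem(img)               # (1, 64, 128, 128)
masked = defer_mask(feat)           # masking applied only after the conv stem
```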
This deferred masking improves the low-level feature extraction capability of the network. The masked token sequence is fed into the remaining encoder, which can be any CNN, Transformer, or CNN-Transformer. For the decoder, SimMIM verified that good training results can be obtained with a lean decoder, so CTAMIM uses a linear layer to convert the feature size back to that of the original image. Similarly, CTAMIM uses an L1 loss to evaluate the prediction of the masked region.
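The reconstruction objective restricted to the masked region might be implemented as follows (a sketch; `mask` is assumed to be 1 at masked positions and 0 elsewhere):

```python
import torch

def masked_l1_loss(pred, target, mask):
    """L1 reconstruction loss computed only on masked pixels (sketch)."""
    loss = (pred - target).abs() * mask
    return loss.sum() / mask.sum().clamp(min=1)
```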

Dataset and experiment details
Dataset. The dataset used in this paper comes from retrospective CBCT data collected by a cooperating dental hospital over the past two years; several volumes in the dataset contain metal artifacts. The dataset comprises scans of 200 patients (120 kV, 5 mA) with a voxel size of 0.3 mm. It is important to note that the image data do not contain any patient identifiers and the work does not breach any ethical considerations.
We randomly selected and annotated 45 volumes from the dataset, each containing 440 2D slices. All data were used in the pre-training tasks, and the annotated data were used for fine-tuning and direct training of the models. For dataset splitting, the 45 annotated volumes were randomly divided into non-overlapping training, validation, and test sets at a ratio of 6:2:2.
Experiment details. The experiments used four NVIDIA A100 GPUs for training. All models used the same input image size of 1 × 512 × 512, and the only data augmentation strategy was random horizontal flipping with a probability of 0.5. For input normalization, the mean and standard deviation were computed over all data, giving a mean of 0.1765 and a standard deviation of 0.1739.
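Expressed as torchvision transforms, this preprocessing is roughly the following sketch (for segmentation the flip must be applied jointly to image and label, which is omitted here):

```python
import torchvision.transforms as T

# Preprocessing as stated in the text: random horizontal flip (p = 0.5)
# and normalization with the dataset-wide mean / std of the CBCT slices.
train_transform = T.Compose([
    T.ToTensor(),                      # slice -> float tensor in [0, 1]
    T.RandomHorizontalFlip(p=0.5),
    T.Normalize(mean=[0.1765], std=[0.1739]),
])
```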
We used the AdamW optimizer for 100 epochs, with a linear warm-up over the first 10 epochs, an initial learning rate of 0.001, and a weight decay of 0.05. We used a cosine learning rate scheduler to facilitate convergence by adjusting the learning rate during training. During training, we validated the segmentation results of each epoch on the validation set, selected the best-performing weights as the result of training, and evaluated the final segmentation performance on the test set. We also used gradient accumulation and set the batch size to 2048. For all models, the loss function used for pre-training is the L1 loss, and the loss function used for fine-tuning and direct training is a mixed loss of Dice loss and focal loss, as shown in equation (1):

$$L = \left(1 - \frac{2\sum_{i=1}^{N} y_i \hat{y}_i}{\sum_{i=1}^{N} y_i + \sum_{i=1}^{N} \hat{y}_i}\right) - \frac{1}{N}\sum_{i=1}^{N}\left[\alpha\,(1-\hat{y}_i)^{\gamma}\, y_i \log \hat{y}_i + (1-\alpha)\,\hat{y}_i^{\gamma}\,(1-y_i)\log(1-\hat{y}_i)\right] \tag{1}$$

where y is the ground truth, ŷ is the predicted value, and N is the total number of predicted pixels. γ in the focal loss was set to 4 and α to 0.5. In this paper, we used the Dice similarity coefficient (DSC), intersection over union (IoU), 95% Hausdorff distance (HD95), and average symmetric surface distance (ASSD) as the evaluation metrics for segmentation results.
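A sketch of this mixed loss for binary segmentation follows, assuming sigmoid probabilities as input; the mean reduction and smoothing term are implementation choices, not from the paper:

```python
import torch

def dice_focal_loss(y_hat, y, alpha=0.5, gamma=4, eps=1e-6):
    """Mixed loss (sketch): Dice loss + binary focal loss with the stated
    hyper-parameters (gamma = 4, alpha = 0.5). y_hat are probabilities."""
    y_hat, y = y_hat.flatten(), y.float().flatten()
    dice = 1 - (2 * (y_hat * y).sum() + eps) / (y_hat.sum() + y.sum() + eps)
    p_t = y * y_hat + (1 - y) * (1 - y_hat)           # prob of the true class
    alpha_t = y * alpha + (1 - y) * (1 - alpha)
    focal = (alpha_t * (1 - p_t).pow(gamma) * -(p_t + eps).log()).mean()
    return dice + focal
```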

Comparison with existing models
To test the effectiveness of the network, we first trained it directly and compared it with representative existing medical image segmentation networks: U-Net and Attention U-Net are CNNs, Swin-Unet is a pure Transformer network, and TransUNet is a CNN-Transformer network. The segmentation results of these models are shown in table 1, and the visualization results in figure 8. With similar training parameters, CTA-UNet outperforms the other models. U-Net and Attention U-Net also achieve high accuracy, while the accuracy of Swin-Unet is relatively low, which may be attributed to the small data volume degrading Transformer training, whereas CNNs generalize well on small datasets. Although TransUNet is a CNN-Transformer network, its performance is limited by connecting the two in series, while CTA-UNet's parallel connection combines the advantages of CNNs and Transformers for better performance.

Pre-train and fine-tune
The pre-training of the models was implemented on OpenMixup (Li et al 2022b). The MIM methods used for comparison are MAE, SimMIM, and A2MIM. Since MAE and SimMIM cannot be applied to CNNs, we chose the classical Transformer backbones ViT and Swin Transformer to test the effect of the different MIM pre-training methods.
In this paper, we used a patch size of 16 × 16 and a random mask ratio of 50%. The ResNet-50 and CTA-Encoder backbones pre-trained by A2MIM followed the original source code, substituting 50% of the patches with mean-value masks in the early stage and adding mask tokens in the middle layers of the network to increase its mid-order interactions. CTAMIM, by contrast, did not use masks in the early stage and instead inserted randomly initialized mask tokens in the middle layers of the network. The visualization results of the models are shown in figure 9. To ensure fairness, we added a similar decoder to each of the above backbones for fine-tuning. The decoder was designed to be as simple as possible to reduce its influence on network performance, so that the pre-training effects could be compared; its structure is shown in figure 10. The decoder consists of five sets of convolutions; note that for Transformers, patches must be unpatchified before being fed into the decoder, while for CNNs and CNN-Transformers, features are upsampled before the convolution sets to restore their original sizes. After the decoder, a 1 × 1 convolution is used to segment the image.
Fine-tuning has two stages. In the first 90 epochs, we froze the encoder's weights and trained only the decoder; in the last 10 epochs, we fine-tuned the parameters of the whole network. The fine-tuning results are shown in table 2. The CTA-Encoder pre-trained by our proposed MIM method performs best after fine-tuning, followed by the CTA-Encoder pre-trained by A2MIM. The results demonstrate that the CTAMIM method proposed in this paper is well suited to self-supervised pre-training of CNN-Transformers, and the pre-trained networks show superior performance on dental CBCT image segmentation tasks. The visualization of the segmentation results of these models is shown in figure 11.
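In PyTorch, this two-stage schedule can be sketched as follows; the `encoder` attribute name is an illustrative assumption:

```python
import torch.nn as nn

def set_finetune_stage(model: nn.Module, stage: int) -> None:
    """Two-stage fine-tuning (sketch): stage 1 freezes the encoder and trains
    only the decoder; stage 2 unfreezes the whole network. Assumes the model
    exposes an `encoder` submodule (an illustrative naming choice)."""
    freeze = (stage == 1)
    for p in model.encoder.parameters():
        p.requires_grad = not freeze

# Usage, matching the 90 + 10 epoch split described above:
# for epoch in range(100):
#     set_finetune_stage(model, stage=1 if epoch < 90 else 2)
```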
In addition, we observed that the segmentation accuracy of the ResNet-50 network pre-trained by A2MIM was lower than expected, presumably because some feature information is lost through continuous downsampling and upsampling in CNNs, making it difficult to restore details in the later part of the model, whereas the Transformer backbones have more skip connections, which improve the feature representation. Therefore, we added skip connections to the decoder and fine-tuned the ResNet-50 and CTA-Encoder backbones once again; the results are shown in table 3.
After adding skip connections to the decoder, the fine-tuned segmentation results improved greatly, which supports our speculation and shows that the parallel CNN-Transformer structure of the CTA-Encoder performs better in extracting and fusing features of different sizes. Finally, we took the CTA-Encoder pre-trained by CTAMIM, built the full CTA-UNet network, and fine-tuned it. The DSC (Dice similarity coefficient) of the segmentation results reached 87.12%, and the proposed method also outperformed the comparison methods on the three other metrics. The segmentation results of these models are shown in figure 12.

Generalization experiment
To assess the generalization performance of the proposed method, we evaluated the trained model on an external dataset introduced by Cui et al (2022). The dataset comprises data collected by devices with varying parameters, featuring four different voxel sizes of {0.2, 0.3, 0.4, 1} mm. However, the data with a voxel size of 1 mm have a low sampling resolution, resulting in an image size of only 200 × 200, which differs significantly from the input size of the model, so we included only the remaining 100 volumes. To ensure consistency, all image data were uniformly upsampled to a size of 512 × 512.
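In code, this resizing step might look like the following sketch; the interpolation mode is an assumption, since the text only states the target size:

```python
import torch.nn.functional as F

def resize_slice(x):
    """Upsample a batch of slices to the model's 512 x 512 input size
    (sketch; bilinear interpolation is assumed). x: (B, 1, H, W)."""
    return F.interpolate(x, size=(512, 512), mode='bilinear',
                         align_corners=False)
```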
The handling of the dataset was otherwise consistent with the process described in section 4.1. The experimental results are presented in table 4.
The results in table 4 show that the trained model exhibits a certain segmentation ability on the new dataset, with the pre-trained CTA-UNet achieving 74.83% DSC, suggesting that the model has incorporated some general knowledge from the previous training. However, due to differences in data distribution, its accuracy on the external dataset is lower than on the dataset we constructed. To adapt the model to the distribution of the external dataset and achieve higher segmentation accuracy, we froze the encoder part of the model and fine-tuned it on the external dataset for just one epoch; these segmentation results are also presented in table 4. Remarkably, the model quickly adapts to the distribution of the external dataset within a single epoch, yielding high accuracy. Moreover, the pre-trained CTA-UNet delivers the highest segmentation accuracy in this experiment, indicating strong generalization ability and transfer performance.

Ablation studies
The results of the ablation experiments are shown in table 5. In experiments 1-3, we pre-trained the CTA-Encoder with different MIM methods and added a decoder for fine-tuning. The fine-tuned CTA-Encoder using CTAMIM achieves higher segmentation accuracy than the versions using A2MIM or no MIM, indicating that deferring the position of the masks during training is conducive to training the convolutional weight parameters. In experiments 1 and 4, we examined the performance of the CTA-Decoder: replacing the simple skip-connection decoder with the CTA-Decoder greatly improves segmentation accuracy, implying that the CTA-Decoder better merges features of different sizes. In experiments 4 and 5, we applied all the enhancements and fine-tuned the pre-trained CTA-UNet model; the resulting segmentation results were superior to those of the comparison models in terms of the DSC, IoU, HD95, and ASSD metrics.

Conclusions
In this paper, we studied the limitations of traditional CNNs and Transformers in the dental CBCT image segmentation task, explored the feasibility of introducing Transformers into a task with limited annotated data, and focused on how to alleviate the limitations that small datasets impose on the model. We collected CBCT data from 200 patients and annotated 45 volumes to create a new dataset for dental CBCT image segmentation. We proposed CTA-UNet, a U-shaped segmentation network with a CNN-Transformer architecture: the network has a parallel structure of CNNs and Transformers, using convolution operators to extract local features, self-attention modules to capture long-range dependencies, and a multi-spatial attention block with a multi-scale feature fusion head to fuse the extracted features. In addition, we proposed CTAMIM, a novel MIM method suited to the CNN-Transformer architecture, which alleviates the limitations caused by the lack of annotated data and helps train the weights of convolution operators effectively. Experimental results show that CTA-UNet pre-trained by CTAMIM outperforms traditional models in dental CBCT image segmentation tasks, which is of practical significance in orthodontics and dental implants.

Data availability statement
The data cannot be made publicly available upon publication because no suitable repository exists for hosting data in this field of study. The data that support the findings of this study are available upon reasonable request from the authors.