In this paper, instead of fine-tuning the SAM model, we propose Med SAM Adapter, which integrates medical-specific domain knowledge into the segmentation model through a simple yet effective adaptation technique.
The main reason for SAM's failure on medical images is the lack of training data: although SAM established a sophisticated and efficient data engine for training, it collected few medical cases. In this paper, we attempt to extend SAM to prompt-based medical image segmentation with minimum effort. Technically, we choose to fine-tune the pre-trained SAM using a parameter-efficient fine-tuning (PEFT) technique called Adaptation.
Preliminary: SAM architecture
This part first reviews the SAM architecture; for details, see the original paper and the related literature.
Instead of fully tuning all parameters, we keep the pre-trained SAM parameters frozen and insert Adapter modules at specific positions in the architecture.
In the SAM encoder, we deploy two Adapters for each ViT block. For a standard ViT block (as shown in Fig. 2 (a)), we place the first Adapter after the multi-head attention and before the residual connection (as shown in Fig. 2 (b)), and the second Adapter in the residual path of the MLP layer that follows the multi-head attention. Immediately after the second Adapter, we scale the adapted embedding by a scaling factor before adding it back to the residual path.
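To make the placement concrete, below is a minimal PyTorch sketch of a bottleneck Adapter and its insertion into a frozen encoder block. The submodule names (attn, mlp, norm1, norm2), the bottleneck ratio, and the scale factor s are illustrative assumptions, not the exact implementation.

```python
import torch.nn as nn

class Adapter(nn.Module):
    """Bottleneck Adapter: down-projection, ReLU, up-projection."""
    def __init__(self, dim, bottleneck_ratio=0.25):
        super().__init__()
        hidden = int(dim * bottleneck_ratio)
        self.down = nn.Linear(dim, hidden)   # compress the embedding
        self.act = nn.ReLU()
        self.up = nn.Linear(hidden, dim)     # project back to model width

    def forward(self, x):
        return self.up(self.act(self.down(x)))

class AdaptedEncoderBlock(nn.Module):
    """A frozen SAM ViT block with the two trainable Adapters described
    above; submodule names and the scale factor s are assumptions."""
    def __init__(self, block, dim, s=0.5):
        super().__init__()
        self.block = block
        self.adapter1 = Adapter(dim)   # after the multi-head attention
        self.adapter2 = Adapter(dim)   # in the residual path of the MLP
        self.s = s
        for p in self.block.parameters():   # keep the SAM weights frozen
            p.requires_grad = False

    def forward(self, x):
        # First Adapter: after attention, before the residual addition
        x = x + self.adapter1(self.block.attn(self.block.norm1(x)))
        # Second Adapter: in the MLP residual path, its output scaled by s
        h = self.block.norm2(x)
        x = x + self.block.mlp(h) + self.s * self.adapter2(h)
        return x
```

Only the Adapter parameters receive gradients, so the trainable parameter count stays a small fraction of the full model.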
In the SAM decoder, we deploy three Adapters for each ViT block. The first Adapter is placed after the prompt-to-image-embedding multi-head cross-attention, with a residual addition of the prompt embedding. We make a small modification to the Adapter to integrate the prompt information into the module: specifically, we use another down-projection to compress the prompt embedding, and add it to the Adapter's compressed embedding before the ReLU activation. This modification lets the Adapter tune its parameters conditioned on the prompt information, making it more flexible and general across modalities and downstream tasks. The second Adapter in the decoder is deployed in exactly the same way as in the encoder, to adapt the MLP-enhanced embedding. The third Adapter is deployed after the residual connection of the image-embedding-to-prompt cross-attention. Another residual connection and layer normalization follow the adaptation to output the final results. Note that we only deploy the Adapters in the first of the Decoder's two blocks; the second block and the mask prediction head are fully tuned on the given data.
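The prompt-conditioned variant can be sketched as follows; the class name, the bottleneck width, and the broadcasting convention for the prompt embedding are assumptions for illustration.

```python
import torch.nn as nn

class PromptConditionedAdapter(nn.Module):
    """Decoder Adapter variant: a second down-projection compresses the
    prompt embedding and adds it to the compressed token embedding before
    the ReLU, so the adaptation is conditioned on the prompt."""
    def __init__(self, dim, prompt_dim, hidden=64):
        super().__init__()
        self.down = nn.Linear(dim, hidden)                # compress tokens
        self.down_prompt = nn.Linear(prompt_dim, hidden)  # compress prompt
        self.act = nn.ReLU()
        self.up = nn.Linear(hidden, dim)

    def forward(self, x, prompt):
        # x: (B, N, dim); prompt: (B, 1, prompt_dim), broadcast over tokens
        h = self.down(x) + self.down_prompt(prompt)  # inject prompt info
        return self.up(self.act(h))                  # before the non-linearity
```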
To address the correlation between consecutive slices of 3D medical images, we propose a novel adaptation method inspired by image-to-video adaptation, with some modifications. The specific architecture is shown in Fig. 2 (c). In each block, we split the attention operation into two branches: a space branch and a depth branch. For a given 3D sample with depth D, we send a D x N x L tensor to the multi-head attention in the space branch, where N is the number of embeddings and L is the length of each embedding. Here, D acts as the number of parallel operations, and the interaction is applied over N x L to learn and abstract spatial correlations as embeddings. In the depth branch, we first transpose the input matrix to N x D x L and then send it to the same multi-head attention. Although we use the same attention mechanism, the interaction is now applied over D x L; in this way, depth correlations are learned and abstracted. Finally, we transpose the result of the depth branch back to its original shape and add it to the result of the space branch.
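The two-branch scheme reduces to reshaping which dimension plays the role of the batch; a minimal sketch, assuming a shared nn.MultiheadAttention and an embedding width divisible by the head count:

```python
import torch.nn as nn

class SpaceDepthAttention(nn.Module):
    """Sketch of the two-branch attention: the same multi-head attention is
    applied over the space dimension and, after a transpose, over the depth
    dimension. Input shape follows the text: (D, N, L)."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def _self_attn(self, x):
        out, _ = self.attn(x, x, x)
        return out

    def forward(self, x):                              # x: (D, N, L)
        space = self._self_attn(x)                     # interaction over N x L
        depth = self._self_attn(x.transpose(0, 1))     # (N, D, L): over D x L
        return space + depth.transpose(0, 1)           # back to (D, N, L)
```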
Training Strategy
Pre-training:
Instead of relying solely on the MAE pre-training used in SAM, we use a combination of several self-supervised learning methods for pre-training. The first two are Contrastive Embedding-Mixup (e-Mix) and Shuffled Embedding Prediction (ShED), following [32]. e-Mix is a contrastive objective that additively mixes a batch of original input embeddings, weighting them with different coefficients; it then trains an encoder to produce, for a mixed embedding, a vector that is close to the original inputs' embeddings in proportion to their mixing coefficients. ShED shuffles a fraction of the embeddings and trains the encoder with a classifier to predict which embeddings were perturbed. We also use a Masked Autoencoder (MAE) following the original implementation of SAM, which masks a given fraction of the input embeddings and trains the model to reconstruct them.
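As an illustration of the ShED objective, the sketch below shuffles a fraction of the token embeddings and trains a per-token classifier to flag the perturbed positions; the shuffle fraction, the encoder/head placeholders, and the per-token binary formulation are our assumptions, with the exact scheme specified in [32].

```python
import torch
import torch.nn.functional as F

def shed_loss(tokens, encoder, head, shuffle_frac=0.25):
    """One ShED step. tokens: (B, N, L) input embeddings; `encoder` is the
    ViT backbone; `head` maps each token feature to one logit."""
    tokens = tokens.clone()                       # keep the originals intact
    B, N, _ = tokens.shape
    n_shuf = max(2, int(N * shuffle_frac))
    labels = torch.zeros(B, N, device=tokens.device)
    for b in range(B):
        idx = torch.randperm(N, device=tokens.device)[:n_shuf]
        perm = idx[torch.randperm(n_shuf, device=tokens.device)]
        tokens[b, idx] = tokens[b, perm]          # shuffle the chosen tokens
        labels[b, idx] = 1.0                      # mark them as perturbed
    logits = head(encoder(tokens)).squeeze(-1)    # (B, N)
    return F.binary_cross_entropy_with_logits(logits, labels)
```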
We use a text prompt training strategy different from SAM's. In SAM, the authors use the CLIP image embedding of the target object's crop as a training-time stand-in for the text prompt, relying on CLIP to keep that image embedding close to the embedding of the corresponding text description or definition. However, since CLIP is barely trained on medical image datasets, it can hardly relate the organs/lesions in an image to the corresponding text definition. Instead, we first randomly generate several free texts containing the definition of the target (e.g., optic disc, brain tumor) as the keyword from ChatGPT, and then extract the embedding of each text using CLIP as the prompt for training. One free text may contain multiple targets, in which case we supervise the model with all of their corresponding masks.
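Extracting such prompt embeddings is straightforward with an off-the-shelf CLIP text encoder; a minimal sketch using the Hugging Face transformers API, where the checkpoint name and the example free texts are assumptions:

```python
import torch
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# ChatGPT-generated free texts containing the target definitions as keywords
free_texts = [
    "a fundus photograph showing the optic disc",       # single target
    "an MRI slice containing a brain tumor and edema",  # multiple targets
]
inputs = processor(text=free_texts, return_tensors="pt", padding=True)
with torch.no_grad():
    text_embeds = model.get_text_features(**inputs)  # (num_texts, embed_dim)
# Each embedding serves as the text prompt; for a multi-target text, the
# model is supervised with all of the corresponding masks.
```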