This paper shows that masked autoencoders (MAE) are scalable self-supervised learners for computer vision. Our MAE approach is simple: we mask random patches of the input image and reconstruct the missing pixels. It is based on two core designs. First, we develop an asymmetric encoder-decoder architecture, with an encoder that operates only on the visible subset of patches (without mask tokens), together with a lightweight decoder that reconstructs the original image from the latent representation and mask tokens. Second, we find that masking a high proportion of the input image, e.g. 75%, yields a nontrivial and meaningful self-supervisory task. Coupling these two designs enables us to train large models efficiently and effectively: we accelerate training (by 3x or more) and improve accuracy. Our scalable approach allows learning high-capacity models that generalize well: e.g. a vanilla ViT-Huge model achieves the best accuracy (87.8%) among methods that use only ImageNet-1K data. Transfer performance on downstream tasks outperforms supervised pre-training and shows promising scaling behavior.
Our MAE architecture. During pre-training, a large random subset of image patches (e.g. 75%) is masked out. The encoder is applied to the small subset of visible patches. Mask tokens are introduced after the encoder, and the full set of encoded patches and mask tokens is processed by a small decoder that reconstructs the original image in pixels. After pre-training, the decoder is discarded and the encoder is applied to uncorrupted images (full sets of patches) for recognition tasks.
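In the official implementation (models_mae.py), this random masking is done by per-sample shuffling of patch indices. The sketch below condenses that idea; function and variable names follow the repository only loosely, and device handling is omitted:

import torch

def random_masking(x, mask_ratio=0.75):
    # x: [N, L, D] sequence of patch embeddings (batch, patches, dim)
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L)                         # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)        # ascending: lowest scores are kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # the encoder only ever sees this visible subset (25% of the patches)
    ids_keep = ids_shuffle[:, :len_keep]
    x_visible = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # binary mask in the original patch order: 0 = visible, 1 = masked
    mask = torch.ones(N, L)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_visible, mask, ids_restore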
conda create -n mae python=3.8 -y
conda activate mae
pip install timm==0.3.2 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
pip install tensorboard -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
import torch
if torch.cuda.is_available():
print("GPU 可用")
else:
print("GPU 不可用")
Workaround (for the container_abcs import error from torch._six when using timm 0.3.2 on newer PyTorch):
import torch
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs
else:
    import collections.abc as container_abcs
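With timm 0.3.2 on PyTorch 1.8+, the failing import typically sits in timm/models/layers/helpers.py; under that assumption, either apply the version check above inside that file or simply replace the import, roughly as follows:

# timm/models/layers/helpers.py (patched; a sketch under the assumption above)
from itertools import repeat
import collections.abc as container_abcs  # was: from torch._six import container_abcs


def _ntuple(n):
    def parse(x):
        # pass iterables through unchanged, broadcast scalars to an n-tuple
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)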
pip install --force-reinstall --no-deps numpy==1.23.5
pip install numpy==1.23.5 -i https://pypi.tuna.tsinghua.edu.cn/simple
(1) Evaluate ViT-Base on a single GPU (${IMAGENET_DIR} is a directory containing the {train, val} sets of ImageNet):
python main_finetune.py --eval --resume mae_finetuned_vit_base.pth --model vit_base_patch16 --batch_size 16 --data_path ${IMAGENET_DIR}
python main_finetune.py --eval --resume models/mae_finetuned_vit_base.pth --model vit_base_patch16 --batch_size 16 --data_path dataset/imagenet_1k
(2) Evaluate ViT-Large:
python main_finetune.py --eval --resume mae_finetuned_vit_large.pth --model vit_large_patch16 --batch_size 16 --data_path ${IMAGENET_DIR}
python main_finetune.py --eval --resume models/mae_finetuned_vit_large.pth --model vit_large_patch16 --batch_size 16 --data_path dataset/imagenet_1k
(3) Evaluate ViT-Huge:
python main_finetune.py --eval --resume mae_finetuned_vit_huge.pth --model vit_huge_patch14 --batch_size 16 --data_path ${IMAGENET_DIR}
python main_finetune.py --eval --resume models/mae_finetuned_vit_huge.pth --model vit_huge_patch14 --batch_size 16 --data_path dataset/imagenet_1k
(1) To fine-tune with multi-node distributed training, run the following command on 4 nodes with 8 GPUs each:
First, install submitit:
pip install submitit -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
python submitit_finetune.py \
--job_dir ${JOB_DIR} \
--nodes 4 \
--batch_size 32 \
--model vit_base_patch16 \
--finetune ${PRETRAIN_CHKPT} \
--epochs 100 \
--blr 5e-4 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
--dist_eval --data_path ${IMAGENET_DIR}
Here the effective batch size is 32 (batch_size per GPU) x 4 (nodes) x 8 (GPUs per node) = 1024.
blr is the base learning rate. The absolute lr is computed by the linear scaling rule: lr = blr * effective batch size / 256.
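As a quick sanity check of the rule for the ViT-Base command above (the helper function is purely illustrative):

def absolute_lr(blr, batch_size_per_gpu, accum_iter, num_gpus):
    # linear lr scaling rule: lr = blr * effective_batch_size / 256
    effective_batch_size = batch_size_per_gpu * accum_iter * num_gpus
    return blr * effective_batch_size / 256, effective_batch_size


lr, eff = absolute_lr(blr=5e-4, batch_size_per_gpu=32, accum_iter=1, num_gpus=4 * 8)
print(eff, lr)  # 1024 0.002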
We ran 4 trials with different random seeds. The results are 83.63, 83.66, 83.52, 83.46 (mean 83.57, std 0.08).
Training time is about 7h11m on 32 V100 GPUs.
(2) Script for ViT-Large:
python submitit_finetune.py \
--job_dir ${JOB_DIR} \
--nodes 4 --use_volta32 \
--batch_size 32 \
--model vit_large_patch16 \
--finetune ${PRETRAIN_CHKPT} \
--epochs 50 \
--blr 1e-3 --layer_decay 0.75 \
--weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
--dist_eval --data_path ${IMAGENET_DIR}
We ran 4 trials with different random seeds. The results are 85.95, 85.87, 85.76, 85.88 (mean 85.87, std 0.07).
Training time is about 8h52m on 32 V100 GPUs.
(3) Script for ViT-Huge:
python submitit_finetune.py \
--job_dir ${JOB_DIR} \
--nodes 8 --use_volta32 \
--batch_size 16 \
--model vit_huge_patch14 \
--finetune ${PRETRAIN_CHKPT} \
--epochs 50 \
--blr 1e-3 --layer_decay 0.75 \
--weight_decay 0.05 --drop_path 0.3 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
--dist_eval --data_path ${IMAGENET_DIR}
Training time is about 13h9m on 64 V100 GPUs.
(1) To fine-tune our pre-trained ViT-Base with single-node training, run the following command on a single node with 8 GPUs:
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
--accum_iter 4 \
--batch_size 32 \
--model vit_base_patch16 \
--finetune ${PRETRAIN_CHKPT} \
--epochs 100 \
--blr 5e-4 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
--dist_eval --data_path ${IMAGENET_DIR}
Here the effective batch size is 32 (batch_size per GPU) x 4 (accum_iter) x 8 (GPUs) = 1024; --accum_iter 4 simulates the 4 nodes of the multi-node recipe above.
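--accum_iter is plain gradient accumulation: each GPU sums gradients over several micro-batches and takes one optimizer step per group, so a single node can match the multi-node effective batch size. A generic sketch of the idea (not the repository's actual engine_finetune.py loop):

def train_one_epoch(model, criterion, optimizer, data_loader, accum_iter=4):
    optimizer.zero_grad()
    for step, (images, targets) in enumerate(data_loader):
        loss = criterion(model(images), targets) / accum_iter  # scale so the summed grads average out
        loss.backward()                                         # gradients accumulate in .grad
        if (step + 1) % accum_iter == 0:
            optimizer.step()        # one update per accum_iter micro-batches
            optimizer.zero_grad()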
For linear probing, run the following command on 4 nodes with 8 GPUs each:
python submitit_linprobe.py \
--job_dir ${JOB_DIR} \
--nodes 4 \
--batch_size 512 \
--model vit_base_patch16 --cls_token \
--finetune ${PRETRAIN_CHKPT} \
--epochs 90 \
--blr 0.1 \
--weight_decay 0.0 \
--dist_eval --data_path ${IMAGENET_DIR}
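Linear probing trains only a linear classifier on top of the frozen pre-trained encoder (--cls_token selects the class-token feature). A minimal sketch of that setup, assuming a timm-style ViT with a .head attribute; the official main_linprobe.py additionally inserts a non-affine BatchNorm before the head:

import torch.nn as nn


def build_linear_probe(vit, embed_dim=768, num_classes=1000):
    # freeze every pre-trained parameter; only the new head will be trained
    for param in vit.parameters():
        param.requires_grad = False
    vit.head = nn.Linear(embed_dim, num_classes)  # fresh, trainable linear classifier
    return vit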
To pre-train ViT-Large (the recommended default) with multi-node distributed training, run the following command on 8 nodes with 8 GPUs each:
python submitit_pretrain.py \
--job_dir ${JOB_DIR} \
--nodes 8 \
--use_volta32 \
--batch_size 64 \
--model mae_vit_large_patch16 \
--norm_pix_loss \
--mask_ratio 0.75 \
--epochs 800 \
--warmup_epochs 40 \
--blr 1.5e-4 --weight_decay 0.05 \
--data_path ${IMAGENET_DIR}
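--norm_pix_loss uses per-patch normalized pixels as the reconstruction target, which the paper reports as improving representation quality. A simplified sketch of the loss under that flag (modeled on forward_loss in the official models_mae.py; computed only over masked patches):

def masked_reconstruction_loss(pred, target, mask, norm_pix_loss=True):
    # pred, target: [N, L, p*p*3] patchified pixels; mask: [N, L] with 1 = masked
    if norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1e-6) ** 0.5  # normalize each patch independently
    loss = ((pred - target) ** 2).mean(dim=-1)          # mean squared error per patch
    return (loss * mask).sum() / mask.sum()             # average over masked patches only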
To use pre-trained models from other repositories, their checkpoint keys need to be converted.
MMSegmentation provides a script, beit2mmseg.py, in the tools directory to convert the keys of an MAE model from the official repository to the MMSegmentation style.
python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
python tools/model_converters/beit2mmseg.py https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth pretrain/mae_pretrain_vit_base_mmcls.pth
python tools/model_converters/beit2mmseg.py https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth pretrain/mae_pretrain_vit_large_mmcls.pth
This script converts the model at PRETRAIN_PATH and stores the converted model at STORE_PATH.
In the default setting, the pre-trained model is specified via the pretrained field of the config (see the full config further below).
To evaluate the single-scale results of the model:
sh tools/dist_test.sh \
  work_dirs/runs/train/selfup/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py checkpoints/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
python tools/test.py work_dirs/runs/train/selfup/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py checkpoints/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth --eval mIoU
Since the relative position embedding requires the input height and width to be equal, multi-scale inference uses a sliding window. We therefore set min_size=512, i.e. the shortest edge is 512. Multi-scale inference is thus handled by a separate config rather than by '--aug-test'. For multi-scale inference:
sh tools/dist_test.sh \
configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py \
upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)
from mmcv.runner import ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer


class MAEAttention(BEiTAttention):
    """Multi-head self-attention with relative position bias used in MAE.

    This module is different from ``BEiTAttention`` by initializing the
    relative bias table with zeros.
    """

    def init_weights(self):
        """Initialize relative position bias with zeros."""
        # As MAE initializes relative position bias as zeros and this class
        # inherited from BEiT which initializes relative position bias
        # with `trunc_normal`, `init_weights` here does
        # nothing and just passes directly
        pass


class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
    """Implements one encoder layer in Vision Transformer.

    This module is different from ``BEiTTransformerEncoderLayer`` by replacing
    ``BEiTAttention`` with ``MAEAttention``.
    """

    def build_attn(self, attn_cfg):
        self.attn = MAEAttention(**attn_cfg)


@BACKBONES.register_module()
class MAE(BEiT):
    """VisionTransformer with support for patch.

    Args:
        img_size (int | tuple): Input image size. Default: 224.
        patch_size (int): The patch size. Default: 16.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): embedding dimension. Default: 768.
        num_layers (int): depth of transformer. Default: 12.
        num_heads (int): number of attention heads. Default: 12.
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim. Default: 4.
        out_indices (list | tuple | int): Output from which stages.
            Default: -1.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): stochastic depth rate. Default 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
            Default: False.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Default: False.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        pretrained (str, optional): model pretrained path. Default: None.
        init_values (float): Initialize the values of Attention and FFN
            with learnable scaling. Defaults to 0.1.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dims=768,
                 num_layers=12,
                 num_heads=12,
                 mlp_ratio=4,
                 out_indices=-1,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
                 num_fcs=2,
                 norm_eval=False,
                 pretrained=None,
                 init_values=0.1,
                 init_cfg=None):
        super(MAE, self).__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            out_indices=out_indices,
            qv_bias=False,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            patch_norm=patch_norm,
            final_norm=final_norm,
            num_fcs=num_fcs,
            norm_eval=norm_eval,
            pretrained=pretrained,
            init_values=init_values,
            init_cfg=init_cfg)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        self.num_patches = self.patch_shape[0] * self.patch_shape[1]
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dims))

    def _build_layers(self):
        dpr = [
            x.item()
            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
        ]
        self.layers = ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                MAETransformerEncoderLayer(
                    embed_dims=self.embed_dims,
                    num_heads=self.num_heads,
                    feedforward_channels=self.mlp_ratio * self.embed_dims,
                    attn_drop_rate=self.attn_drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=self.num_fcs,
                    bias=True,
                    act_cfg=self.act_cfg,
                    norm_cfg=self.norm_cfg,
                    window_size=self.patch_shape,
                    init_values=self.init_values))

    def fix_init_weight(self):
        """Rescale the initialization according to layer id.

        This function is copied from
        https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
        Copyright (c) Microsoft Corporation
        Licensed under the MIT License
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.layers):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data, layer_id + 1)

    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
        self.fix_init_weight()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
            state_dict = self.resize_rel_pos_embed(checkpoint)
            state_dict = self.resize_abs_pos_embed(state_dict)
            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super(MAE, self).init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            # Copyright 2019 Ross Wightman
            # Licensed under the Apache License, Version 2.0 (the "License")
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)

    def resize_abs_pos_embed(self, state_dict):
        if 'pos_embed' in state_dict:
            pos_embed_checkpoint = state_dict['pos_embed']
            embedding_size = pos_embed_checkpoint.shape[-1]
            num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
            # height (== width) for the checkpoint position embedding
            orig_size = int(
                (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
            # height (== width) for the new position embedding
            new_size = int(self.num_patches**0.5)
            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                                embedding_size).permute(
                                                    0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens,
                    size=(new_size, new_size),
                    mode='bicubic',
                    align_corners=False)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                state_dict['pos_embed'] = new_pos_embed
        return state_dict

    def forward(self, inputs):
        B = inputs.shape[0]

        x, hw_shape = self.patch_embed(inputs)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == len(self.layers) - 1:
                if self.final_norm:
                    x = self.norm1(x)
            if i in self.out_indices:
                out = x[:, 1:]
                B, _, C = out.shape
                out = out.reshape(B, hw_shape[0], hw_shape[1],
                                  C).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)
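As a quick shape check of the backbone above: it returns one feature map per entry of out_indices, each at 1/16 of the input resolution with 768 channels. A hypothetical standalone run, assuming an MMSegmentation 0.x environment where this backbone is registered:

import torch
from mmseg.models import build_backbone

backbone = build_backbone(
    dict(
        type='MAE',
        img_size=(512, 512),
        patch_size=16,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        out_indices=[3, 5, 7, 11]))
backbone.init_weights()

x = torch.randn(1, 3, 512, 512)
feats = backbone(x)
print([tuple(f.shape) for f in feats])  # expected: four maps of (1, 768, 32, 32)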
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',  # type of model to build
    pretrained='work_dirs/pretrain/mae/mae_pretrain_vit_base_mmcls.pth',  # path to the pre-trained (converted) checkpoint
    backbone=dict(
        type='MAE',  # type of backbone module
        img_size=(512, 512),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=[3, 5, 7, 11],
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_cfg=dict(type='LN', eps=1e-06),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        init_values=1.0),
    # neck connecting the ViT backbone to the decode heads
    neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[768, 768, 768, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=768,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=768,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    train_cfg=dict(),
    test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)))
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=4,
    train=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook')
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))
optimizer_config = dict()
lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-06,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=16000)
evaluation = dict(
    interval=16000,
    metric=['mIoU', 'mFscore'],
    pre_eval=True,
    save_best='auto')
fp16 = dict(loss_scale='dynamic')
work_dir = 'work_dirs/runs/train/selfup/mae'
gpu_ids = [0]
auto_resume = False
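The LayerDecayOptimizerConstructor in this config applies layer-wise learning-rate decay: each parameter group has its lr multiplied by layer_decay_rate raised to its depth from the top. A rough sketch of the scaling, assuming the usual BEiT/MAE convention (patch embedding = layer 0, head = layer num_layers + 1); the constructor's exact grouping may differ:

def layer_lr_scales(num_layers=12, layer_decay_rate=0.65):
    # index 0 = patch embedding, indices 1..num_layers = transformer blocks,
    # index num_layers + 1 = decode head (scale 1.0)
    return [layer_decay_rate ** (num_layers + 1 - i) for i in range(num_layers + 2)]


scales = layer_lr_scales()
print(round(scales[0], 4), round(scales[1], 4), scales[-1])  # ~0.0037 ~0.0057 1.0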
def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
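For reference, the forward pass of MaskedAutoencoderViT in the official repository returns the reconstruction loss, the predicted patch pixels, and the binary mask. A usage sketch, assuming the official models_mae.py from facebookresearch/mae is importable:

import torch
from models_mae import mae_vit_large_patch16_dec512d8b  # official repo module (assumed on PYTHONPATH)

model = mae_vit_large_patch16_dec512d8b(norm_pix_loss=True)
imgs = torch.randn(2, 3, 224, 224)
loss, pred, mask = model(imgs, mask_ratio=0.75)
print(loss.item(), pred.shape, mask.shape)  # pred: [2, 196, 768], mask: [2, 196]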
Positional encoding
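MAE uses fixed 2D sine-cosine positional embeddings rather than learned ones (get_2d_sincos_pos_embed in util/pos_embed.py of the official repository). A condensed sketch of that construction; the exact grid ordering and the optional class-token row are simplified here:

import numpy as np

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # embed_dim must be even; pos is an array of positions
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2) / (embed_dim / 2.0))
    out = np.einsum('m,d->md', pos.reshape(-1), omega)    # (M, embed_dim/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

def get_2d_sincos_pos_embed(embed_dim, grid_size):
    # half of the channels encode the row index, the other half the column index
    grid_h, grid_w = np.meshgrid(np.arange(grid_size, dtype=np.float32),
                                 np.arange(grid_size, dtype=np.float32),
                                 indexing='ij')
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w)
    return np.concatenate([emb_h, emb_w], axis=1)          # (grid_size**2, embed_dim)

print(get_2d_sincos_pos_embed(1024, 14).shape)  # (196, 1024) for ViT-Large at 224x224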