赞
踩
github来源
原文论文
最近在学习TransUnet的算法,希望能应用到降雨预测中
论文中本人认为没有对此结构进行非常清晰的解释,尤其是CNN代码块部分。而源码因本人刚开始接触深度学习内容阅读起来也较为困难。因此打算分成两个部分进行学习,首先是TransUnet的结构部分,重点为数据的输入输出。其次是源代码中自己使用数据的替换。暂时不对结构的深层次原理做过多研究。
首先非常感谢这位博主提供的TransUnet具体实现图,为本人理解此网络结构帮助巨大
https://blog.csdn.net/zjiafbaodaozmj/article/details/119063094?utm_medium=distribute.pc_aggpage_search_result.none-task-blog-2~aggregatepage~first_rank_ecpm_v1~rank_v31_ecpm-1-119063094.pc_agg_new_rank&utm_term=transunet%E8%AF%A6%E8%A7%A3&spm=1000.2123.3001.4430https://blog.csdn.net/zjiafbaodaozmj/article/details/119063094?utm_medium=distribute.pc_aggpage_search_result.none-task-blog-2~aggregatepage~first_rank_ecpm_v1~rank_v31_ecpm-1-119063094.pc_agg_new_rank&utm_term=transunet%E8%AF%A6%E8%A7%A3&spm=1000.2123.3001.4430TransUnet的网络形式仿照Unet,由encoder和decoder组成U型结构。
Encoder部分加入了Transformer机制,最终得到了一个一维向量。Decoder部分做了三次上采样,最终将此一维向量恢复成了原来的图像。Encoder和Decoder部分还做了三次跳跃连接。
下面对每个部分进行详细解读。
Encoder:
结构图中的CNN是最困惑笔者的部分。论文中解释这部分包含Image Sequentialization和Patch Embedding两个部分。
翻阅源码发现,Image Sequentialization被封装在了名为ResNetV2的class类中。类的root部分包括卷积、GroupNorm,和最大池化。body部分产生了3个block,用于与decoder对应大小进行跳跃连接。
Patch Embedding被封装在Embeddings的类中,通过卷积,flatten,position和dropout最终输出带有位置信息的一维向量。
一维向量再被输入Transformer块中重复12次,输出与原始大小相同的一维向量
Decoder:
Decoder部分的目的是将一维向量重新恢复成原来的图像,并与encoder进行跳跃连接。被封装在DecoderCup类中。
Reshape首先将H*W/P的图像恢复为H/P*W/P。再输入Conv2dReLU中,此代码块包括Conv2d,ReLU和BatchNormal。接下来进行3次DecoderBlock,具体过程见结构详细图解。最终再进行一次上采样就得到了恢复的图像。
最终输出的形式,以及如何将输入数据替换成自己的数据,笔者将在下一章讲解。
TransUnet结构详细图解(包含对应代码块)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。