当前位置:   article > 正文

Unet的Pytorch实现_unet pytorch

unet pytorch

最近在开发一个基于Unet的剪枝模型,于是从论文到代码把Unet撸了一遍。本篇是基于Pytorch的Unet开源实现,复现Kaggle上的一个算法竞赛“ Carvana Image Masking Challenge”。

源码地址:https://github.com/milesial/Pytorch-UNet

原始论文地址:U-Net: Convolutional Networks for Biomedical Image Segmentation

1. 网络结构

Unet的提出最开始是为了解决医学图像分割的问题。

Unet的网络结构如下图所示:

从图中可以看出, Unet包含两条路径,左边的为收缩路径(Contracting path),右边的为扩张路径(Expansive path)。

收缩路径遵循典型的卷积神经网络结构,包括两个重复的3x3卷积(no padding),每个卷积后面跟着一个ReLU和一个步长为2的2x2 max pooling,以达到下采样的目的。每个下采样步骤中,将特征通道数增加一倍。

在扩张路径上,每一个步骤包含一个对feature map的上采样,然后是一个2x2的up-convolution,使得通道数减半;接下来是copy and crop,即把收缩路径中相同层的feature map经过裁剪之后拼接在当前层(由于左侧路径中的feature map比右侧对应路径中的feature map要大一些,因此需要crop之后才能做拼接),接着是两个3x3卷积+ReLU。在最后一个层,使用1x1卷积将64元素的特征向量映射到不同的类别。

Unet共有23个卷积层。

2. 开源实现

篇幅限制,本篇先不涉及具体代码,只讲工程操作。

2.1 运行容器

由于服务器上已经安装了比较新版的Docker以及nvidia-docker2,因此直接用Docker运行。通过如下命令运行容器:

sudo docker run --rm --shm-size=8g --ulimit memlock=-1 --gpus all -it milesial/unet

第一次执行以上命令时,由于本地没有 milesial/unet这个docker image,因此会自动从docker hub下载,时间略长一些。另外,如果不想命令行终端关闭的时候容器退出,可以把容器改为后台运行,最好指定一个容器名称,方便后续容器操作,命令行如下:

sudo docker run --rm --shm-size=8g --ulimit memlock=-1 --gpus all --name=unet -itd milesial/unet

2.2 下载数据集

容器运行起来之后,如果需要在容器中进行模型训练,我么需要去下载相应地数据集。本工程中,下载的是Kaggle上的竞赛数据集carvana-image-masking-challenge。执行工程中的脚本进行下载:

bash scripts/download_data.sh

脚本内容如下:

  1. #!/bin/bash
  2. if [[ ! -f ~/.kaggle/kaggle.json ]]; then
  3. echo -n "Kaggle username: "
  4. read USERNAME
  5. echo
  6. echo -n "Kaggle API key: "
  7. read APIKEY
  8. mkdir -p ~/.kaggle
  9. echo "{\"username\":\"$USERNAME\",\"key\":\"$APIKEY\"}" > ~/.kaggle/kaggle.json
  10. chmod 600 ~/.kaggle/kaggle.json
  11. fi
  12. pip install kaggle --upgrade
  13. kaggle competitions download -c carvana-image-masking-challenge -f train_hq.zip
  14. unzip train_hq.zip
  15. mv train_hq/* data/imgs/
  16. rm -d train_hq
  17. rm train_hq.zip
  18. kaggle competitions download -c carvana-image-masking-challenge -f train_masks.zip
  19. unzip train_masks.zip
  20. mv train_masks/* data/masks/
  21. rm -d train_masks
  22. rm train_masks.zip

其中的步骤也可以分开执行。需要注意的是,下载数据需要注册Kaggle账号并拿到API Key,然后需要在Kaggle竞赛界面进行验证和授权,这个我在上一篇博客中有详细介绍:

Ubuntu从Kaggle上下载数据集出现403 - Forbidden

2.3 模型训练

运行训练程序:

python train.py --amp

 正常情况下,执行以上语句后,即开始训练了:

想要修改训练参数,可以参考以下参数说明:

  1. > python train.py -h
  2. usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR]
  3. [--load LOAD] [--scale SCALE] [--validation VAL] [--amp]
  4. Train the UNet on images and target masks
  5. optional arguments:
  6. -h, --help show this help message and exit
  7. --epochs E, -e E Number of epochs
  8. --batch-size B, -b B Batch size
  9. --learning-rate LR, -l LR
  10. Learning rate
  11. --load LOAD, -f LOAD Load model from a .pth file
  12. --scale SCALE, -s SCALE
  13. Downscaling factor of the images
  14. --validation VAL, -v VAL
  15. Percent of the data that is used as validation (0-100)
  16. --amp Use mixed precision

训练完成后,默认会在checkpoints路径下保存每个epoch的中间模型。

2.4 预测 

可以用训练出来的模型进行预测。预测默认使用的模型为“MODEL.pth”,可以把上一步保存的模型重命名为“MODEL.pth”,也可以通过-m选项指定一个模型。

预测的参数设置可参考如下说明:

  1. > python predict.py -h
  2. usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
  3. [--output INPUT [INPUT ...]] [--viz] [--no-save]
  4. [--mask-threshold MASK_THRESHOLD] [--scale SCALE]
  5. Predict masks from input images
  6. optional arguments:
  7. -h, --help show this help message and exit
  8. --model FILE, -m FILE
  9. Specify the file in which the model is stored
  10. --input INPUT [INPUT ...], -i INPUT [INPUT ...]
  11. Filenames of input images
  12. --output INPUT [INPUT ...], -o INPUT [INPUT ...]
  13. Filenames of output images
  14. --viz, -v Visualize the images as they are processed
  15. --no-save, -n Do not save the output masks
  16. --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
  17. Minimum probability value to consider a mask pixel white
  18. --scale SCALE, -s SCALE
  19. Scale factor for the input images

在这里,我们就按照默认的模型名称来做测试: 

python predict.py -i test_img.jpg -o output.jpg

 我们来看一下预测结果:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/87911
推荐阅读
相关标签
  

闽ICP备14008679号