当前位置:   article > 正文

【深度学习】使用tensorflow实现VGG19网络_tensorflow训练vgg19

tensorflow训练vgg19

【深度学习】使用tensorflow实现VGG19网络

 

 

 

本文章向大家介绍【深度学习】使用tensorflow实现VGG19网络,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

 

 

 

 

VGG网络与AlexNet类似,也是一种CNN。VGG在2014年ILSVRC的定位(localization)和分类(classification)两项任务上分别取得了第一名和第二名。VGG网络非常深,通常有16~19层,卷积核大小均为3×3;VGG16与VGG19的区别在于后三个卷积部分(conv3~conv5)中卷积层的数量:VGG16每部分3层,VGG19每部分4层。这是我第二个用tensorflow独立完成的小玩意儿......

 

 

模型结构

可以看到VGG的前几层为卷积和maxpool的交替,每个卷积部分包含多个卷积层,最后是三个全连接层。激活函数采用ReLU,训练采用了dropout,但并没有像AlexNet一样采用LRN(论文给出的理由是加LRN实验效果不好)。

模型定义

  1. def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"):
  2. """max-pooling"""
  3. return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1],
  4. strides = [1, strideX, strideY, 1], padding = padding, name = name)
  5. def dropout(x, keepPro, name = None):
  6. """dropout"""
  7. return tf.nn.dropout(x, keepPro, name)
  8. def fcLayer(x, inputD, outputD, reluFlag, name):
  9. """fully-connect"""
  10. with tf.variable_scope(name) as scope:
  11. w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float")
  12. b = tf.get_variable("b", [outputD], dtype = "float")
  13. out = tf.nn.xw_plus_b(x, w, b, name = scope.name)
  14. if reluFlag:
  15. return tf.nn.relu(out)
  16. else:
  17. return out
  18. def convLayer(x, kHeight, kWidth, strideX, strideY,
  19. featureNum, name, padding = "SAME"):
  20. """convlutional"""
  21. channel = int(x.get_shape()[-1]) #获取channel数
  22. with tf.variable_scope(name) as scope:
  23. w = tf.get_variable("w", shape = [kHeight, kWidth, channel, featureNum])
  24. b = tf.get_variable("b", shape = [featureNum])
  25. featureMap = tf.nn.conv2d(x, w, strides = [1, strideY, strideX, 1], padding = padding)
  26. out = tf.nn.bias_add(featureMap, b)
  27. return tf.nn.relu(tf.reshape(out, featureMap.get_shape().as_list()), name = scope.name)

定义了卷积、pooling、dropout、全连接四个模块,使用了上一篇AlexNet中的代码,其中卷积模块去除了group参数,因为网络没有像AlexNet一样分成两部分。接下来定义VGG19:

  1. class VGG19(object):
  2. """VGG model"""
  3. def __init__(self, x, keepPro, classNum, skip, modelPath = "vgg19.npy"):
  4. self.X = x
  5. self.KEEPPRO = keepPro
  6. self.CLASSNUM = classNum
  7. self.SKIP = skip
  8. self.MODELPATH = modelPath
  9. #build CNN
  10. self.buildCNN()
  11. def buildCNN(self):
  12. """build model"""
  13. conv1_1 = convLayer(self.X, 3, 3, 1, 1, 64, "conv1_1" )
  14. conv1_2 = convLayer(conv1_1, 3, 3, 1, 1, 64, "conv1_2")
  15. pool1 = maxPoolLayer(conv1_2, 2, 2, 2, 2, "pool1")
  16. conv2_1 = convLayer(pool1, 3, 3, 1, 1, 128, "conv2_1")
  17. conv2_2 = convLayer(conv2_1, 3, 3, 1, 1, 128, "conv2_2")
  18. pool2 = maxPoolLayer(conv2_2, 2, 2, 2, 2, "pool2")
  19. conv3_1 = convLayer(pool2, 3, 3, 1, 1, 256, "conv3_1")
  20. conv3_2 = convLayer(conv3_1, 3, 3, 1, 1, 256, "conv3_2")
  21. conv3_3 = convLayer(conv3_2, 3, 3, 1, 1, 256, "conv3_3")
  22. conv3_4 = convLayer(conv3_3, 3, 3, 1, 1, 256, "conv3_4")
  23. pool3 = maxPoolLayer(conv3_4, 2, 2, 2, 2, "pool3")
  24. conv4_1 = convLayer(pool3, 3, 3, 1, 1, 512, "conv4_1")
  25. conv4_2 = convLayer(conv4_1, 3, 3, 1, 1, 512, "conv4_2")
  26. conv4_3 = convLayer(conv4_2, 3, 3, 1, 1, 512, "conv4_3")
  27. conv4_4 = convLayer(conv4_3, 3, 3, 1, 1, 512, "conv4_4")
  28. pool4 = maxPoolLayer(conv4_4, 2, 2, 2, 2, "pool4")
  29. conv5_1 = convLayer(pool4, 3, 3, 1, 1, 512, "conv5_1")
  30. conv5_2 = convLayer(conv5_1, 3, 3, 1, 1, 512, "conv5_2")
  31. conv5_3 = convLayer(conv5_2, 3, 3, 1, 1, 512, "conv5_3")
  32. conv5_4 = convLayer(conv5_3, 3, 3, 1, 1, 512, "conv5_4")
  33. pool5 = maxPoolLayer(conv5_4, 2, 2, 2, 2, "pool5")
  34. fcIn = tf.reshape(pool5, [-1, 7*7*512])
  35. fc6 = fcLayer(fcIn, 7*7*512, 4096, True, "fc6")
  36. dropout1 = dropout(fc6, self.KEEPPRO)
  37. fc7 = fcLayer(dropout1, 4096, 4096, True, "fc7")
  38. dropout2 = dropout(fc7, self.KEEPPRO)
  39. self.fc8 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8")
  40. def loadModel(self, sess):
  41. """load model"""
  42. wDict = np.load(self.MODELPATH, encoding = "bytes").item()
  43. #for layers in model
  44. for name in wDict:
  45. if name not in self.SKIP:
  46. with tf.variable_scope(name, reuse = True):
  47. for p in wDict[name]:
  48. if len(p.shape) == 1:
  49. #bias 只有一维
  50. sess.run(tf.get_variable('b', trainable = False).assign(p))
  51. else:
  52. #weights
  53. sess.run(tf.get_variable('w', trainable = False).assign(p))

buildCNN函数完全按照VGG的结构搭建网络。

loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。 至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的VGG有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的图片,为了避免审美疲劳,我们不只用渣土车,还要用挖掘机、采沙船:

然后编写测试代码:

  1. parser = argparse.ArgumentParser(description='Classify some images.')
  2. parser.add_argument('mode', choices=['folder', 'url'], default='folder')
  3. parser.add_argument('path', help='Specify a path [e.g. testModel]')
  4. args = parser.parse_args(sys.argv[1:])
  5. if args.mode == 'folder': #测试方式为本地文件夹
  6. #get testImage
  7. withPath = lambda f: '{}/{}'.format(args.path,f)
  8. testImg = dict((f,cv2.imread(withPath(f))) for f in os.listdir(args.path) if os.path.isfile(withPath(f)))
  9. elif args.mode == 'url': #测试方式为URL
  10. def url2img(url): #获取URL图像
  11. '''url to image'''
  12. resp = urllib.request.urlopen(url)
  13. image = np.asarray(bytearray(resp.read()), dtype="uint8")
  14. image = cv2.imdecode(image, cv2.IMREAD_COLOR)
  15. return image
  16. testImg = {args.path:url2img(args.path)}
  17. if testImg.values():
  18. #some params
  19. dropoutPro = 1
  20. classNum = 1000
  21. skip = []
  22. imgMean = np.array([104, 117, 124], np.float)
  23. x = tf.placeholder("float", [1, 224, 224, 3])
  24. model = vgg19.VGG19(x, dropoutPro, classNum, skip)
  25. score = model.fc8
  26. softmax = tf.nn.softmax(score)
  27. with tf.Session() as sess:
  28. sess.run(tf.global_variables_initializer())
  29. model.loadModel(sess) #加载模型
  30. for key,img in testImg.items():
  31. #img preprocess
  32. resized = cv2.resize(img.astype(np.float), (224, 224)) - imgMean #去均值
  33. maxx = np.argmax(sess.run(softmax, feed_dict = {x: resized.reshape((1, 224, 224, 3))})) #网络输入为224*224
  34. res = caffe_classes.class_names[maxx]
  35. font = cv2.FONT_HERSHEY_SIMPLEX
  36. cv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2) #在图像上绘制结果
  37. print("{}: {}n----".format(key,res)) #输出测试结果
  38. cv2.imshow("demo", img)
  39. cv2.waitKey(0)

如果你看完了我AlexNet的博客,那么一定会发现我这里的测试代码做了一些小的修改:增加了URL测试的功能,可以测试网上的图像。测试结果如下:

 

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号