
Machine Learning: Feature Extraction with TensorFlow and a Pre-trained Model (MobileNet V1)


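In traditional CV pipelines, feature extraction and building/training the classifier are usually two separate steps, whereas a CNN combines both in a single network. Many experiments have shown that a CNN trained on a large amount of data also works well as a general-purpose feature extractor, a kind of feature transfer.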

Today I'll show how to extract features with TensorFlow and a pre-trained model. We can use the pre-trained models published in the official TensorFlow GitHub repository:

https://github.com/tensorflow/models/tree/master/research/slim

These pre-trained models were trained on ImageNet, and the page offers several different architectures, such as VGG, ResNet and Inception.

Here we use a lightweight model, MobileNet_v1, for feature extraction. First, download the trained checkpoint mobilenet_v1_1.0_224 (ckpt) from that page.

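Together with the network definition, the checkpoint also lets us inspect the structure of the whole network and the feature map of every layer.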

First, import the required modules (this code uses the TensorFlow 1.x API):

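# TensorFlow 1.x only: tf.contrib (and with it tf.contrib.slim) was removed in TF 2.x
import tensorflow as tf
import numpy as np
import glob
# `nets` is the package in models/research/slim; that directory must be on PYTHONPATH
from nets import mobilenet_v1
slim = tf.contrib.slim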


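Next, define a function that takes an image path, reads and decodes the image, preprocesses it, and returns it as a tensor (together with a dummy label):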

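def mobi_parse_fun(x_in, y_label=1):
    # read the raw (still JPEG-encoded) bytes of the file
    img_bytes = tf.read_file(x_in)
    # decode into a uint8 tensor of shape (H, W, 3)
    img_decode = tf.io.decode_jpeg(img_bytes, channels=3)
    # resize to the 224 x 224 input size of mobilenet_v1_1.0_224
    img = tf.image.resize_images(img_decode, [224, 224])
    # scale pixel values from [0, 255] to [-1, 1]
    img = tf.cast(img, tf.float32) / 127.5 - 1.0

    # y_label is only a placeholder: feature extraction needs no labels
    return img, y_label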
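
The last line of mobi_parse_fun maps 8-bit pixel values into the [-1, 1] range expected by the slim MobileNet checkpoints. As a quick check of that arithmetic, here is a plain-Python sketch of the mapping and its inverse; to_model_range and from_model_range are hypothetical helper names, not part of TensorFlow or slim.

```python
def to_model_range(pixels):
    """Scale pixel values from [0, 255] to [-1, 1], as mobi_parse_fun does: v / 127.5 - 1."""
    return [v / 127.5 - 1.0 for v in pixels]


def from_model_range(values):
    """Inverse mapping from [-1, 1] back to [0, 255], e.g. to view a preprocessed image."""
    return [(v + 1.0) * 127.5 for v in values]


# black, mid-grey and white pixels
print(to_model_range([0, 127.5, 255]))  # [-1.0, 0.0, 1.0]
```

Feeding images in another range (for example raw 0 to 255 floats) would still run, but the features would probably be degraded, since the network never saw such inputs during training.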

Next, we use TensorFlow's tf.data Dataset API to build the input pipeline:

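# placeholder that will be fed the list of image paths
X_in = tf.placeholder(tf.string, None)
# Y_in = tf.placeholder(tf.int32, None)   # not needed: no labels are used here
train_data = tf.data.Dataset.from_tensor_slices(X_in)
# decode and preprocess every path
train_data = train_data.map(mobi_parse_fun)
# one image per batch, so every tensor gets a leading batch dimension of 1
train_data = train_data.batch(1)

# a reinitializable iterator: it can be (re)started with a new list of paths
iter_ = tf.data.Iterator.from_structure(train_data.output_types,
                                        train_data.output_shapes)
x_batch, y_batch = iter_.get_next()
train_init_op = iter_.make_initializer(train_data)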
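
The Dataset calls above build a lazy pipeline: each path becomes one element, map preprocesses it, and batch(1) wraps it in a batch of one. The generator below is only an analogy in plain Python (it is not how tf.data is implemented, and the name pipeline is hypothetical), but it shows one property that matters when running the graph: every pull from the iterator, like every sess.run that evaluates iter_.get_next(), advances to the next element.

```python
def pipeline(paths, parse_fn, batch_size=1):
    """Pure-Python analogy of
    tf.data.Dataset.from_tensor_slices(paths).map(parse_fn).batch(batch_size).
    It only illustrates the semantics; it is not how tf.data works internally."""
    batch = []
    for path in paths:                  # from_tensor_slices: one element per path
        batch.append(parse_fn(path))    # map: preprocess every element
        if len(batch) == batch_size:    # batch: group consecutive elements
            yield batch
            batch = []
    if batch:                           # like drop_remainder=False: keep the last partial batch
        yield batch


it = pipeline(['cat1.jpg', 'cat2.jpg', 'cat3.jpg'], lambda p: 'decoded ' + p)
print(next(it))  # ['decoded cat1.jpg']
print(next(it))  # ['decoded cat2.jpg'] (a second pull moves on to the next image)
```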

Then build the network from its definition and set the path of the checkpoint:

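# is_training=False makes batch norm use the moving averages stored in the checkpoint
# (instead of the statistics of the current batch of one image) and disables dropout
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
    # 1001 classes = 1000 ImageNet classes + 1 background class, as in the checkpoint
    logits, endpoints = mobilenet_v1.mobilenet_v1(x_batch, num_classes=1001,
                                                  is_training=False)

# raw string, so the backslashes of the Windows path are not treated as escapes;
# this is the checkpoint prefix shared by the .index / .data-* files
ckpt_path = r'D:\Python_Code\mobilenet_v1_1.0_224\mobilenet_v1_1.0_224.ckpt'
saver = tf.train.Saver()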

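
saver.restore expects the checkpoint prefix rather than a single file: a checkpoint in TensorFlow's V2 format is stored as prefix.index plus one or more prefix.data-XXXXX-of-YYYYY files, and the downloaded archive typically also contains a prefix.meta graph file. A small sketch for recovering the prefix from a directory listing; checkpoint_prefixes and find_checkpoints are hypothetical helpers.

```python
import os


def checkpoint_prefixes(filenames):
    """Return the checkpoint prefixes (what saver.restore expects) among `filenames`.
    Assumes TensorFlow's V2 checkpoint layout: <prefix>.index + <prefix>.data-*."""
    suffix = '.index'
    return sorted(name[:-len(suffix)] for name in filenames if name.endswith(suffix))


def find_checkpoints(directory):
    """Full checkpoint prefixes found in `directory` (hypothetical helper)."""
    return [os.path.join(directory, p) for p in checkpoint_prefixes(os.listdir(directory))]


files = ['mobilenet_v1_1.0_224.ckpt.data-00000-of-00001',
         'mobilenet_v1_1.0_224.ckpt.index',
         'mobilenet_v1_1.0_224.ckpt.meta']
print(checkpoint_prefixes(files))  # ['mobilenet_v1_1.0_224.ckpt']
```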

Collect the paths of the images:

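# raw string for the Windows path; glob's order is arbitrary, so sort it for reproducibility
img_path = r'F:\cute\*.jpg'
img_list = sorted(glob.glob(img_path))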
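
glob.glob returns paths in an arbitrary, platform-dependent order, and on Linux the pattern *.jpg does not match files ending in .JPG (on Windows the match is case-insensitive). If you want all JPEG files in a reproducible order, a sketch along these lines could build img_list instead; filter_images and list_images are hypothetical helpers, and the extension list is an assumption chosen to match the JPEG decoder used in mobi_parse_fun.

```python
from pathlib import Path

# assumption: only JPEG files, to match tf.io.decode_jpeg in mobi_parse_fun
IMAGE_EXTS = ('.jpg', '.jpeg')


def filter_images(names, exts=IMAGE_EXTS):
    """Keep names whose extension is in `exts` (case-insensitive), in sorted order."""
    return sorted(n for n in names if Path(n).suffix.lower() in exts)


def list_images(directory, exts=IMAGE_EXTS):
    """Sorted JPEG paths in `directory`, as strings ready for feed_dict={X_in: ...}."""
    return filter_images([str(p) for p in Path(directory).iterdir() if p.is_file()], exts)


print(filter_images(['b.JPG', 'a.jpg', 'notes.txt', 'c.jpeg', 'd.png']))
# ['a.jpg', 'b.JPG', 'c.jpeg']
```

With the folder used above, this would be called as img_list = list_images(r'F:\cute').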

Now we can create a session, load the model into it, and run it:

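with tf.Session() as sess:

    # load the pre-trained weights (and the batch-norm moving averages)
    saver.restore(sess, ckpt_path)

    ## print the parameters of every layer:
    print('print the trainable parameters: ')
    for eval_ in tf.trainable_variables():
        print(eval_.name)
        w_val = sess.run(eval_)
        print(w_val.shape)

    sess.run(train_init_op, feed_dict={X_in: img_list})

    #---------------------------------------------
    #---------------------------------------------
    # print the feature map of every layer.
    # Each sess.run() that depends on the iterator consumes one image, so all
    # endpoints are fetched in a single call (sess.run also accepts a dict)
    feat_maps = sess.run(endpoints)
    print('print the feature maps: ')
    for name_, feat_map in feat_maps.items():
        print(name_)
        print(feat_map.shape)

    # AvgPool_1a has shape (1, 1, 1, 1024); squeezing axes 1 and 2 gives (1, 1024)
    fc_map = endpoints['AvgPool_1a']
    fc_feat = tf.squeeze(fc_map, [1, 2])

    # restart the iterator so that the features line up with img_list
    sess.run(train_init_op, feed_dict={X_in: img_list})

    features = []
    for img_name in img_list:
        print(img_name)

        # fetch image, label and feature in ONE run so they all come from the same image
        x_bat, y_bat, fc_feature = sess.run([x_batch, y_batch, fc_feat])
        print(x_bat.shape, y_bat.shape)
        print(fc_feature.shape)

        features.append(fc_feature[0])

    # one 1024-d feature vector per image, shape (len(img_list), 1024)
    features = np.stack(features)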
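
The feature we keep comes from AvgPool_1a: for a 224 × 224 input the last feature map, Conv2d_13_pointwise, is 7 × 7 × 1024, and averaging each channel over those 49 positions yields a (1, 1, 1, 1024) tensor, which tf.squeeze(fc_map, [1, 2]) turns into (1, 1024). A plain-Python sketch of that global average for one image, with nested lists standing in for the tensor; global_avg_pool is a hypothetical name.

```python
def global_avg_pool(fmap):
    """Average every channel of one H x W x C feature map (nested lists) over all
    H * W positions; returns a flat list of C values, i.e. the squeezed feature."""
    h, w, c = len(fmap), len(fmap[0]), len(fmap[0][0])
    sums = [0.0] * c
    for row in fmap:
        for pixel in row:
            for k, value in enumerate(pixel):
                sums[k] += value
    return [s / (h * w) for s in sums]


# a toy 2 x 2 map with 2 channels (the real one is 7 x 7 x 1024)
toy = [[[1.0, 10.0], [3.0, 20.0]],
       [[5.0, 30.0], [7.0, 40.0]]]
print(global_avg_pool(toy))  # [4.0, 25.0]
```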

The structure of MobileNet_V1 (every trainable variable and its shape) looks like this:

MobilenetV1/Conv2d_0/weights:0 (3, 3, 3, 32)
MobilenetV1/Conv2d_0/BatchNorm/gamma:0 (32,)
MobilenetV1/Conv2d_0/BatchNorm/beta:0 (32,)
MobilenetV1/Conv2d_1_depthwise/depthwise_weights:0 (3, 3, 32, 1)
MobilenetV1/Conv2d_1_depthwise/BatchNorm/gamma:0 (32,)
MobilenetV1/Conv2d_1_depthwise/BatchNorm/beta:0 (32,)
MobilenetV1/Conv2d_1_pointwise/weights:0 (1, 1, 32, 64)
MobilenetV1/Conv2d_1_pointwise/BatchNorm/gamma:0 (64,)
MobilenetV1/Conv2d_1_pointwise/BatchNorm/beta:0 (64,)
MobilenetV1/Conv2d_2_depthwise/depthwise_weights:0 (3, 3, 64, 1)
MobilenetV1/Conv2d_2_depthwise/BatchNorm/gamma:0 (64,)
MobilenetV1/Conv2d_2_depthwise/BatchNorm/beta:0 (64,)
MobilenetV1/Conv2d_2_pointwise/weights:0 (1, 1, 64, 128)
MobilenetV1/Conv2d_2_pointwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_2_pointwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_3_depthwise/depthwise_weights:0 (3, 3, 128, 1)
MobilenetV1/Conv2d_3_depthwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_3_depthwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_3_pointwise/weights:0 (1, 1, 128, 128)
MobilenetV1/Conv2d_3_pointwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_3_pointwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_4_depthwise/depthwise_weights:0 (3, 3, 128, 1)
MobilenetV1/Conv2d_4_depthwise/BatchNorm/gamma:0 (128,)
MobilenetV1/Conv2d_4_depthwise/BatchNorm/beta:0 (128,)
MobilenetV1/Conv2d_4_pointwise/weights:0 (1, 1, 128, 256)
MobilenetV1/Conv2d_4_pointwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_4_pointwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_5_depthwise/depthwise_weights:0 (3, 3, 256, 1)
MobilenetV1/Conv2d_5_depthwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_5_pointwise/weights:0 (1, 1, 256, 256)
MobilenetV1/Conv2d_5_pointwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_5_pointwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_6_depthwise/depthwise_weights:0 (3, 3, 256, 1)
MobilenetV1/Conv2d_6_depthwise/BatchNorm/gamma:0 (256,)
MobilenetV1/Conv2d_6_depthwise/BatchNorm/beta:0 (256,)
MobilenetV1/Conv2d_6_pointwise/weights:0 (1, 1, 256, 512)
MobilenetV1/Conv2d_6_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_6_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_7_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_7_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_7_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_7_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_7_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_8_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_8_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_8_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_8_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_8_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_9_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_9_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_10_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_10_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_10_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_10_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_11_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_11_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_11_pointwise/weights:0 (1, 1, 512, 512)
MobilenetV1/Conv2d_11_pointwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_11_pointwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_12_depthwise/depthwise_weights:0 (3, 3, 512, 1)
MobilenetV1/Conv2d_12_depthwise/BatchNorm/gamma:0 (512,)
MobilenetV1/Conv2d_12_depthwise/BatchNorm/beta:0 (512,)
MobilenetV1/Conv2d_12_pointwise/weights:0 (1, 1, 512, 1024)
MobilenetV1/Conv2d_12_pointwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_12_pointwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Conv2d_13_depthwise/depthwise_weights:0 (3, 3, 1024, 1)
MobilenetV1/Conv2d_13_depthwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_13_depthwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Conv2d_13_pointwise/weights:0 (1, 1, 1024, 1024)
MobilenetV1/Conv2d_13_pointwise/BatchNorm/gamma:0 (1024,)
MobilenetV1/Conv2d_13_pointwise/BatchNorm/beta:0 (1024,)
MobilenetV1/Logits/Conv2d_1c_1x1/weights:0 (1, 1, 1024, 1001)
MobilenetV1/Logits/Conv2d_1c_1x1/biases:0 (1001,)
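
The shapes above show MobileNet's depthwise separable structure: each block is a 3 × 3 depthwise convolution (one filter per input channel, shape (3, 3, C, 1)) followed by a 1 × 1 pointwise convolution (shape (1, 1, C_in, C_out)), each followed by batch norm with gamma and beta. The sketch below adds up the shapes listed above and compares the block weights with what ordinary 3 × 3 convolutions of the same widths would need. The channel table is read off the listing; the function names are hypothetical.

```python
# (C_in, C_out) of the 13 depthwise separable blocks Conv2d_1 ... Conv2d_13,
# read off the variable listing above
BLOCKS = [(32, 64), (64, 128), (128, 128), (128, 256), (256, 256), (256, 512),
          (512, 512), (512, 512), (512, 512), (512, 512), (512, 512),
          (512, 1024), (1024, 1024)]


def count_trainable(blocks=BLOCKS, num_classes=1001, k=3):
    """Total number of trainable parameters implied by the listing above."""
    total = k * k * 3 * 32 + 2 * 32            # Conv2d_0 weights + BatchNorm gamma/beta
    for c_in, c_out in blocks:
        total += k * k * c_in + 2 * c_in       # depthwise weights + BatchNorm
        total += c_in * c_out + 2 * c_out      # pointwise 1x1 weights + BatchNorm
    last = blocks[-1][1]
    total += last * num_classes + num_classes  # Logits/Conv2d_1c_1x1 weights + biases
    return total


def separable_conv_weights(blocks=BLOCKS, k=3):
    """Weights of the 13 blocks as depthwise + pointwise convolutions."""
    return sum(k * k * c_in + c_in * c_out for c_in, c_out in blocks)


def standard_conv_weights(blocks=BLOCKS, k=3):
    """Weights the same 13 layers would need as ordinary k x k convolutions."""
    return sum(k * k * c_in * c_out for c_in, c_out in blocks)


print(count_trainable())          # 4233001
print(separable_conv_weights())   # 3184224
print(standard_conv_weights())    # 28256256
```

The separable blocks need almost 9 times fewer weights than plain 3 × 3 convolutions, which is what makes the model lightweight. The batch-norm moving means and variances are also stored in the checkpoint, but they are not trainable and so are not counted here.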

And the feature maps of the network (one endpoint per line, for a batch of one 224 × 224 image) are:

print the feature maps:
Conv2d_0 (1, 112, 112, 32)
Conv2d_1_depthwise (1, 112, 112, 32)
Conv2d_1_pointwise (1, 112, 112, 64)
Conv2d_2_depthwise (1, 56, 56, 64)
Conv2d_2_pointwise (1, 56, 56, 128)
Conv2d_3_depthwise (1, 56, 56, 128)
Conv2d_3_pointwise (1, 56, 56, 128)
Conv2d_4_depthwise (1, 28, 28, 128)
Conv2d_4_pointwise (1, 28, 28, 256)
Conv2d_5_depthwise (1, 28, 28, 256)
Conv2d_5_pointwise (1, 28, 28, 256)
Conv2d_6_depthwise (1, 14, 14, 256)
Conv2d_6_pointwise (1, 14, 14, 512)
Conv2d_7_depthwise (1, 14, 14, 512)
Conv2d_7_pointwise (1, 14, 14, 512)
Conv2d_8_depthwise (1, 14, 14, 512)
Conv2d_8_pointwise (1, 14, 14, 512)
Conv2d_9_depthwise (1, 14, 14, 512)
Conv2d_9_pointwise (1, 14, 14, 512)
Conv2d_10_depthwise (1, 14, 14, 512)
Conv2d_10_pointwise (1, 14, 14, 512)
Conv2d_11_depthwise (1, 14, 14, 512)
Conv2d_11_pointwise (1, 14, 14, 512)
Conv2d_12_depthwise (1, 7, 7, 512)
Conv2d_12_pointwise (1, 7, 7, 1024)
Conv2d_13_depthwise (1, 7, 7, 1024)
Conv2d_13_pointwise (1, 7, 7, 1024)
AvgPool_1a (1, 1, 1, 1024)
Logits (1, 1001)
Predictions (1, 1001)
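
The spatial sizes in this listing follow from the strides: Conv2d_0 and the depthwise convolutions of blocks 2, 4, 6 and 12 use stride 2, every other layer uses stride 1, and with 'SAME' padding each layer outputs ceil(input / stride). A sketch that reproduces the sizes above; the stride table is read off the listing, and layer_names and spatial_sizes are hypothetical names.

```python
import math

# layers with stride 2, read off the feature-map listing; all other layers use stride 1
STRIDE_2 = {'Conv2d_0', 'Conv2d_2_depthwise', 'Conv2d_4_depthwise',
            'Conv2d_6_depthwise', 'Conv2d_12_depthwise'}


def layer_names():
    """Endpoint names from Conv2d_0 to Conv2d_13_pointwise, in network order."""
    names = ['Conv2d_0']
    for i in range(1, 14):
        names += ['Conv2d_%d_depthwise' % i, 'Conv2d_%d_pointwise' % i]
    return names


def spatial_sizes(input_size=224):
    """Side length of every feature map, assuming 'SAME' padding: out = ceil(in / stride)."""
    sizes, size = {}, input_size
    for name in layer_names():
        stride = 2 if name in STRIDE_2 else 1
        size = math.ceil(size / stride)
        sizes[name] = size
    return sizes


sizes = spatial_sizes()
print(sizes['Conv2d_1_pointwise'], sizes['Conv2d_5_pointwise'], sizes['Conv2d_13_pointwise'])
# 112 28 7
```

The final 7 × 7 map is what AvgPool_1a averages down to 1 × 1; the overall stride is 32, so a 224 input ends at 224 / 32 = 7.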

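We can see that AvgPool_1a is the feature map closest to the classifier (the 1 × 1 convolution Logits/Conv2d_1c_1x1, which plays the role of the FC layer). If we extract this layer's feature map and squeeze it into a 1024-dimensional vector, we can use it directly as the feature of the input image.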
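
Once each image is reduced to a 1024-dimensional vector, a common next step (not part of the walkthrough above) is to compare images by the cosine similarity of their vectors, for example to find the most similar image in a folder. A minimal plain-Python sketch, assuming each vector is a plain list of floats (a NumPy vector can be converted with .tolist()); cosine_similarity and most_similar are hypothetical names.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two feature vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def most_similar(query, features, names):
    """Name and score of the stored feature vector closest to `query`."""
    scores = [cosine_similarity(query, f) for f in features]
    best = max(range(len(scores)), key=scores.__getitem__)
    return names[best], scores[best]


# toy 3-d vectors; real AvgPool_1a features have 1024 dimensions
names = ['cat.jpg', 'dog.jpg', 'car.jpg']
features = [[0.9, 0.1, 0.0], [0.7, 0.3, 0.1], [0.0, 0.1, 0.9]]
print(most_similar([0.8, 0.2, 0.0], features, names)[0])  # cat.jpg
```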

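Disclaimer: The content of this article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author, and this site assumes no legal liability. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/正经夜光杯/article/detail/875470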