当前位置:   article > 正文

tensorflow one-hot独热编码_tensorflow的onehot编码

tensorflow的onehot编码

1 基本概念

解释下什么叫做独热编码(one-hot encoding),独热编码一般是在有监督学习中对数据集进行标注时候使用的,指的是在分类问题中,将存在数据类别的那一类用X表示,不存在的用Y表示,这里的X常常是1, Y常常是0。,举个例子: 
        比如我们有一个5类分类问题,我们有数据(x_{i},y_{i})(x_{i},y_{i}),其中类别y_{i}有五种取值(因为是五类分类问题),所以如果y_{j}为第一类那么其独热编码为: [1,0,0,0,0],如果是第二类那么独热编码为:[0,1,0,0,0],也就是说只对存在有该类别的数的位置上进行标记为1,其他皆为0。这个编码方式经常用于多分类问题,特别是损失函数为交叉熵函数的时候。接下来我们再介绍下TensorFlow中自带的对数据进行独热编码的函数tf.one_hot(),首先先贴出其API手册:

  1. one_hot(
  2. indices,
  3. depth,
  4. on_value=None,
  5. off_value=None,
  6. axis=None,
  7. dtype=None,
  8. name=None
  9. )

需要指定indices,和depth,其中depth是编码深度,on_valueoff_value相当于是编码后的开闭值,如同我们刚才描述的X值和Y值,需要和dtype相同类型(指定了dtype的情况下),axis指定编码的轴。这里给个小的实例:

2 一维one-hot实例

  1. import tensorflow as tf
  2. var1 = tf.one_hot(indices=[0,1,2,3],depth=6,axis=0)
  3. var2 = tf.one_hot(indices=[0,1,2,3],depth=6,axis=1)
  4. var3 = tf.one_hot(indices=[0,1,2,3],depth=4,axis=1)
  5. with tf.Session() as sess:
  6. sess.run(tf.global_variables_initializer())
  7. a,b,c = sess.run([var1,var2,var3])
  8. print(a)
  9. print(b)
  10. print(c)
  11. '''
  12. # depth为编码深度,它的大小取决于你的分类数目,axis是维度的扩展方向,取0按照第一维度扩展为6
  13. [[1. 0. 0. 0.]
  14. [0. 1. 0. 0.]
  15. [0. 0. 1. 0.]
  16. [0. 0. 0. 1.]
  17. [0. 0. 0. 0.]
  18. [0. 0. 0. 0.]]
  19. # 取1按照第二维度扩展为6
  20. [[1. 0. 0. 0. 0. 0.]
  21. [0. 1. 0. 0. 0. 0.]
  22. [0. 0. 1. 0. 0. 0.]
  23. [0. 0. 0. 1. 0. 0.]]
  24. # 下面是假定我们要对四个类别进行独热编码,需要在第二个维度进行扩展为4
  25. [[1. 0. 0. 0.]
  26. [0. 1. 0. 0.]
  27. [0. 0. 1. 0.]
  28. [0. 0. 0. 1.]]
  29. '''

3 二维矩阵的独热编码

语义分割的标签一般都是单通道的图片,标签类别都是从0开始编码,比如pasal voc数据集有20个类别,那么在分割数据集中的ground truth中是从0-20的数据,有时候为了计算损失函数我们需要将这个标签转化为独热编码的形式,然后与预测的结果计算损失,下面演示一个四分类的语义分割标签如何做独热编码。

  1. import tensorflow as tf
  2. segmentation_gt = [
  3. [0,0,1,1,1,0,0,0,0],
  4. [0,0,0,1,0,0,2,2,0],
  5. [0,3,0,0,0,0,2,2,0],
  6. [0,3,3,0,0,0,2,0,0],
  7. [3,3,0,0,2,0,0,0,0],
  8. [0,0,0,0,2,2,0,0,0]
  9. ]
  10. seg_onehot = tf.one_hot(indices=segmentation_gt,depth=4,axis=2)
  11. with tf.Session() as sess:
  12. sess.run(tf.global_variables_initializer())
  13. a = sess.run(seg_onehot)
  14. # 查看第三个通道的one-hot编码,我们发现在所有为2的位置,值变为1,其它地方的编码值为0
  15. print(a[...,2])
  16. '''
  17. [[0. 0. 0. 0. 0. 0. 0. 0. 0.]
  18. [0. 0. 0. 0. 0. 0. 1. 1. 0.]
  19. [0. 0. 0. 0. 0. 0. 1. 1. 0.]
  20. [0. 0. 0. 0. 0. 0. 1. 0. 0.]
  21. [0. 0. 0. 0. 1. 0. 0. 0. 0.]
  22. [0. 0. 0. 0. 1. 1. 0. 0. 0.]]
  23. '''

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Guff_9hys/article/detail/922681
推荐阅读
相关标签
  

闽ICP备14008679号