首先生成一个主对角全为1的其余全为0的矩阵,比如有n个分类就是n * n,效果如下:
随后根据标签列表(或者numpy数组)选取合适的行,比如标签是[9, 1, 0, 0], 那么就会选择上图矩阵中对应的9、1、0、0行,得到one_hot标签,如果不熟悉numpy数组的列表切片的(就是说numpy_array[slice]中的slice是列表) ,可以看下这篇Python Numpy数组使用列表索引
恢复的话就是找列表中为1的下标即可。
代码如下:
# encoding = utf-8
'''
author : James-J
time : 2019/05/29
'''
import numpy as np
if __name__ == '__main__':
one_hot = np.eye(10) # 10*10的矩阵 对角线上是1
print('np.eye(10)\n', one_hot)
# 两种方法 传一维的numpy数组和列表都可以
label = np.array([1, 4, 8, 9, 5, 0])
one_hot_label = one_hot[label.astype(np.int32)] # 表示选取矩阵上面的第几行
# label = [1, 4, 8, 9, 5, 0]
# one_hot_label = one_hot[label]
print('-----------------one_hot--------------------')
print(one_hot_label)
label = [one_label.tolist().index(1) for one_label in one_hot_label] # 找到下标是1的位置
print('------------------label---------------------')
print(label)
得到的结果: