应用场景
已知有“猫、狗、兔”三个类别(这里我们先固定这三个的顺序不变),假设你有一张图片,它的类别是狗,也就是第二个类别,那么,我们可以用一个向量来表示这个信息: [0 1 0]。其中这个向量的元素值都表示概率值,里面的第 i 个元素,对应于之前约定好的第 i 个类别,而且向量里的元素取值只能为0或1概率值,0表示不可能(0%)属于这个类别,1表示100%属于这个类别,而且约定这个向量只有一个元素为1,其余元素为0,不难得出其实向量所有元素之和为1,也就是概率和为1。所以 [0 1 0] ,只有第二个元素为1,其余元素都为0,表示这张图片类别是狗。这就是 one-hot 编码的运用过程。
为什么要用One-hot编码?
对于分类任务,对label使用one-hot编码,实际上是将数据离散特征扩展到欧式空间,而离散特征的取值就对应空间中的一个点,在回归、分类、聚类等算法中,距离计算或相似度计算是必要的,而我们常用的距离和相似度都是在欧式空间计算的。将离散特征独热编码是为了计算距离更为合理。
例如:预测的label是苹果,雪梨,香蕉,草莓这四个,显然他们不直接构成比较关系,但如果我们用1,2,3,4来做label就会出现了比较关系,labe之间的距离也不同:第一个label 和最后一个label的距离太远,影响模型的学习。模型觉得label 1和label 2最像,label 1和最后一个label最不像。此时:C0(第0类)错分成C3(第3类)的损失是C0错分成C1的损失的4倍!
不过当你的label之间存在直接的比较关系,就可以直接用数字当label。例如你做一个风控模型,预测的是四个风险类别[低,中,高,紧急],也可以用1,2,3,4来做label,因为类别之间的比较关系和数字之间的距离关系可以看成一样的。
若Loss是交叉熵损失,用于分类问题,则PyTorch下不需要做One-hot编码;若使用MSE,用于回归问题,则需要手动做One-hot编码。
代码:
>>> class_num = 10 >>> batch_size = 4 # label must be LongTensor, from 0 to class_num - 1. >>> label = torch.empty(batch_size, dtype=torch.long).random_(class_num) >>> print(label) tensor([0, 2, 9, 3]) >>> print(label.size()) torch.Size([4]) >>> label = label.unsqueeze(1) >>> print(label.size()) torch.Size([4, 1]) # The 2nd way to generate random label. >>> label_2 = torch.LongTensor(batch_size, 1).random_() % class_num # % class_num 代表对随机数取余数 >>> print(label_2) tensor([[8], [8], [2], [9]]) one_hot = torch.zeros(batch_size, class_num, dtype=torch.long).scatter_(1, label, 1) >>> print(one_hot) tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]]) # The easiest way to generate One-hot: # When Pytorch >= 1.1, use torch.nn.functional.one_hot import torch import torch.nn.functional as F one_hot = F.one_hot(label_2, num_classes=10) >>> print(onehot) tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]], [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]], [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]]) >>> print(onehot.size()) torch.Size([4, 1, 10]) >>> onehot = onehot.squeeze() >>> print(onehot.size()) torch.Size([4, 10]) >>> print(onehot) tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])