Pytorch 中的scatter函数
scatter
在数据准备时构建one-hot向量有很大的作用。可以简化构建方式。本文结构为:
- 一维one-hot向量构建
- 二维one-hot向量构建
- scatter函数原理
一维one-hot向量构建:
本文中所说的一维one-hot向量构建是指类似于在NLP中词向量的构建,或在离散分类任务中类别one-hot向量的构建。具体地:当输入label是对应的类别标签(这里假设是10个类别)我们的输入label是(当batch_size=4
)$label = [[1],[2],[5],[9]]$ 我们需要构建batch_size
个one-hot向量:
这时就可以使用scatter
函数来进行构建:1
2
3
4
5
6batch_size = 4
class_num = 10
oneHot_size = (batch_size,class_num)
lable_oht = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
label = torch.tensor([[1],[2],[5],[9]])
lable_oht = label_oht.scatter_(1,label.long(),1)
通过上面的代码就可以得到one-hot
向量,特别需要注意的就是我们的label的shape
一定要是标准的(batch_size,1)
。
二维one-hot向量构建:
在instance分割时我遇到了二维one-hot的需求。具体地:当输入的是一张含有像素的类别的mask标签时,我们需要构建一个shape为(batch_size,class_num,H,W)
的one-hot矩阵。每一个channel
对应的是属于该类别的像素。这个应用很广泛,在cityscape数据集,对于图像中不同属性进行分类处理都是很重要的预处理操作。1
2
3
4
5batch_size, class_num, height, width = 1, 4, 2, 2
input_map= torch.tensor([[[[0,1],[0,3]]]]) # torch.Size([1, 1, 2, 2])
oneHot_size = (batch_size, class_num, height, width)
input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, input_map.long(), 1.0)
此时input_map
为:
一共四个channle,每个channle对应一个类的pixel。上面的例子可以解读为,对于一个$2 \times 2$的mask图而言:其中0类的pixel在input_map
的第一个channle中;1类的pixel在input_map
的第二个channle中; 没有2类;3类的pixel在input_map
的第四个channle中。