Pytorch scatter()函数

Pytorch 中的scatter函数

scatter在数据准备时构建one-hot向量有很大的作用。可以简化构建方式。本文结构为:

  1. 一维one-hot向量构建
  2. 二维one-hot向量构建
  3. scatter函数原理

一维one-hot向量构建:

本文中所说的一维one-hot向量构建是指类似于在NLP中词向量的构建,或在离散分类任务中类别one-hot向量的构建。具体地:当输入label是对应的类别标签(这里假设是10个类别)我们的输入label是(当batch_size=4)$label = [[1],[2],[5],[9]]$ 我们需要构建batch_size个one-hot向量:

这时就可以使用scatter函数来进行构建:

1
2
3
4
5
6
batch_size = 4
class_num = 10
oneHot_size = (batch_size,class_num)
lable_oht = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
label = torch.tensor([[1],[2],[5],[9]])
lable_oht = label_oht.scatter_(1,label.long(),1)

通过上面的代码就可以得到one-hot向量,特别需要注意的就是我们的label的shape一定要是标准的(batch_size,1)

二维one-hot向量构建:

在instance分割时我遇到了二维one-hot的需求。具体地:当输入的是一张含有像素的类别的mask标签时,我们需要构建一个shape为(batch_size,class_num,H,W)的one-hot矩阵。每一个channel对应的是属于该类别的像素。这个应用很广泛,在cityscape数据集,对于图像中不同属性进行分类处理都是很重要的预处理操作。

1
2
3
4
5
batch_size, class_num, height, width = 1, 4, 2, 2
input_map= torch.tensor([[[[0,1],[0,3]]]]) # torch.Size([1, 1, 2, 2])
oneHot_size = (batch_size, class_num, height, width)
input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, input_map.long(), 1.0)

此时input_map为:

一共四个channle,每个channle对应一个类的pixel。上面的例子可以解读为,对于一个$2 \times 2$的mask图而言:其中0类的pixel在input_map的第一个channle中;1类的pixel在input_map的第二个channle中; 没有2类;3类的pixel在input_map的第四个channle中。

scatter函数原理:

参考文章