导言:入门Pytorch,Mnist的训练!当然这篇文章肯定不会只是教你进行一个Mnist训练(这种网上已经够多了),我会将里面每一个使用的函数、或者Pytoch中的工具进行稍微的深入讲解。
训练你的MNIST模型:
先给出一段训练源码:这段源码是不完整地,一些完整地源码可以在网上找到,在这里只是对一些函数分开进行讲解。
1 | import torch |
损失函数
cost = torch.nn.CrossEntropyLoss()
softmax
损失- 定义cost对象,之后计算只用传入数据即可
- 需要注意的是
loss = cost(y_pred,var_y)
输入时有顺序的先写预测,再写label
cost = torch.nn.MSELoss()
欧式损失
优化算法
optimizer = torch.optim.Adam(model.parameters())
- 定义opt对象,可采用自定义优化算法,需要将model的参数全部传入,以算法进行计算。
- 需要注意的是:
optimizer.zero_grad()
,每次更新w
前,需要将上一次迭代的梯度清理。
关于data_loader_train
、data
的说明:
data_loader_train
是整个数据集data_loader_train
一次生成batch_size
个数据
shape
= (64,1,28,28)
所有第二个for是整个batch_size
的矩阵计算
Model的forward()
自定义Model继承了Module后需要重写父类的 forward()函数
即计算整个计算图。y_pred = model(var_x)
需要注意的是:
model(var_x)
的输入参数必须是Variable
类型,需要转换。
torch.max()
排列得到最优解,
torch.max(data,dim)
与python
中argmax(data,dim)
有着差不多的功能,其返回值有2个,一个是获得最大值的值,以及其索引。
1
2
3
4
5
6a = torch.tensor([[1,3,2],[4,1,3]])
_,pred = torch.max(a,0)
print(_)
#[4,3,3]
print(pred)
#[1,0,1]类似与python中这种带维度的比较技巧
torch.max(a,dim)
dim 维度 依次递增,其他维度不变,遍历得到各个数据之间相互比较得到样例
1
2
3
4
5torch.max(a,0)
a = [
[1,3,2],
[3,2,1]
]
保持第1维度不变,0维度遍历。
a[0][0] a[1][0] a[2][0] 比较一次 输出一次结果
GPU训练MNIST
数据的载入:
Mnist的读入有专门的接口函数封装在torchvision.Datasets中可以直接使用,这对新手很友好。
datasets
分多种,在MNIST
中使用的是1
2
3
4
5
6data_train = datasets.MNIST(
root = "./data/",
transform = transform_,
train = True, #是训练集
download = False #使用本地的训练集
)
在读入了新的数据集后,使用DataLoader进行数据的提取,具体的Dataloader的源码分析可以参考我的这篇文章
使用torch.utils.data.DataLoader()1
2
3
4
5data_loader_train = torch.utils.data.DataLoader(
batch_size = 64,
dataset = data_train,
shuffle = True
)
batch_size
指定了数据中每次读入的数据量。dataset
指定了数据集data_train
是torchvision.datasets
类。
从Dataloader中读出数据:
1 | x,y = next(iter(dataloader["train"])) |
将dataloader
强制转为迭代器,然后用next
读出一个batch
注意读出的数据是tensor
形式,有时需要转为numpy
的格式,使用tensor.numpy()
即可。而从numpy
构造tensor
可以 torch.from_numpy(array).float()
,需要指定数据格式,torch
中需要float
的数据。
Transforms
包名torchvision.transforms
,在数据集的读入时使用,目的时将数据集进行各种变化,数据加强等操作。1
2
3
4
5
6transforms.Compose(
[
transform.ToTensor(),
transform.Nomalize(mean = .5,std = .5)
]
)
总结下:
Pytorch读入数据时需要的步骤:
- 数据集的准备
- 数据装载
对于数据集的准备主要有以下几种方式:
使用已经封装好的函数
比如之前说的Mnist的载入。
使用文件夹的分类创建数据集
下面再举一个图像数据集的例子,使用文件夹的分类创建数据集:raw data
:可以是文件夹中的图片:直接用dataset.
torchvision.datasets:中
torchvision.datasets.DatasetFolder
(暂无信息。)torchvision.datasets.ImageFolder
1
dset.ImageFolder(root="root folder path", [transform, target_transform])
torchvision.datasets
是继承自torch.utils.data.Dataset
torch.utils.data.TensorDataset(data_tensor, target_tensor)
这让函数可以将tensor数据转为数据集
使用numpy中的数据构造数据集
新建两个np.array
类型的数据转为tensor
再转为dataset
将np->转为Tensor: torch.from_numpy(train_x).float()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40#tain data
train_x = np.linspace(-6, 8, self.N_SAMPEL)[:,np.newaxis]
# print(type(self.x))
self.bias = 5
self.noisy = np.random.normal(0,2,train_x.shape)
train_y = np.power(train_x,2) + self.bias + self.noisy
# print(self.x.shape)
# self.plot_data(train_x,train_y)
# test data
test_x = np.linspace(-7, 10, 200)[:, np.newaxis]
noise = np.random.normal(0, 2, test_x.shape)
test_y = np.square(test_x) - 5 + noise
# self.plot_data(test_x,test_y)
# plt.plot(train_x,train_y,'o',color= "blue")
# plt.plot(test_x,test_y,'+',color = "red")
# plt.show()
#装载到dataloader里面:
train_x = torch.from_numpy(train_x).float()
train_y = torch.from_numpy(train_y).float()
test_x = torch.from_numpy(test_x).float()
test_y = torch.from_numpy(test_y).float()
data = {
"train":
{
"x":train_x,
"y":train_y
},
"test":
{
"x":test_x,
"y":test_y
}
}
dataset = {x: Data.dataset.TensorDataset(data[x]["x"],data[x]["y"]) for x in ["train","test"] }
dataloader = {x : Data.DataLoader(dataset = dataset[x], batch_size = self.BATCH_SIZE, shuffle= True) for x in ["train","test"] }
使用GPU处理数据:
- model = Model()
model.cuda() 开启GPU训练模式 变量的转换:
1
2
3for data_t in data_loader_train:
x,y = data_t
var_x,var_y = Variable(x).cuda(),Variable(y).cuda()cpu->gpu:
variable(x).cuda()
gpu->cpu:
cuda_var(x).cpu()
加载以及存储模型权值
1 | class Model(torch.nn.Module): |
加载模型
model.load_state_dict(torch.load("checkPoints/LLModel_0708_21:54:25.pth"))
model
的加载是使用.load_state_dict()
torch.load("checkPoints/LLModel_0708_21:54:25.pth")
这个得到的是一个OrderDict
,包含了每一层的参数(有参数的参数层,像Pool层是没有的)
1 | model_path = "/home/joey/Documents/models" |