经典GAN loss原理以及其实现

导言:前几天同门问起我GAN loss的实现,我发现自己在一些符号、细节上对GAN loss还是有没有记牢的地方。于是写下这篇blog来加深印象。

GANS loss

原理

经典GAN loss是最原始的loss:

这个loss是最开始提出的GANS loss,q(x)是真实数据的分布,p(z)是随机分布(一般是高斯)。

实现:

Dis 实现:

注意在深度学习框架中,都是使用Gradient Decent来寻找目标函数的最小值.

在优化Discriminator时我们需要将max换做min,即去相反数。

在Generator固定的情况下,上面等式等价于:

即计算真实图像和生成图像在Discriminator的输出负期望值。

BCEloss

BCEloss

Gen实现:

在目标函数中,第一项与G无关,所以在优化时可以去掉。

只用优化第二部分即可:

但是在原文中作者发现1-log D(G(z))的梯度在最开始时很小,在后面变大,这与我们训练思路相悖,应该开始大后面小。所以将其变为-log(D(G(z))

样例模板:

先更新Generator的模板

注意detach()操作的位置。detach函数可以切断反向梯度传播。
由于G,D都使用了不同的optim,所以在更新时可以不用detach。只是需要optimizer_D.zero_grad()将上一步的梯度清除。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
...
# 先更新generator的模板:
adversarial_loss = torch.nn.BCELoss()
generator = Generator()
discriminator = Discriminator()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

for epoch in range(opt.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
real_label = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
fake_label = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

# Configure input
real_imgs = Variable(imgs.type(Tensor))
optimizer_G.zero_grad()
...

# Generate a batch of images
gen_imgs = generator(z)

# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), real_label)
...

optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(real_imgs), real_label)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake_label)
d_loss = (real_loss + fake_loss) / 2

d_loss.backward()
optimizer_D.step()
...

先更新Discriminator的模板1

注意backward(retain_graph=True)的使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真实标签,都是1
fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假标签,都是0
#----------
# 训练判别器
#----------
real_imgs = imgs.to(device)
z = torch.randn((imgs.shape[0], 100)).to(device) # 噪声
gen_imgs = generator(z) # 从噪声中生成假数据
pred_gen = discriminator(gen_imgs) # 判别器对假数据的输出
pred_real = discriminator(real_imgs) # 判别器对真数据的输出
optimizer_D.zero_grad() # 把判别器中所有参数的梯度归零
real_loss = adversarial_loss(pred_real, valid) # 判别器对真实样本的损失
fake_loss = adversarial_loss(pred_gen, fake) # 判别器对假样本的损失
d_loss = (real_loss + fake_loss) / 2 # 两项损失相加取平均
# 下面这行代码十分重要,将在正文着重讲解
d_loss.backward(retain_graph=True) # retain_graph 十分重要,否则计算图内存将会被释放
optimizer_D.step() # 判别器参数更新
#---------
#训练生成器
#---------
g_loss = adversarial_loss(pred_gen, valid) # 生成器的损失函数
optimizer_G.zero_grad() # 生成器参数梯度归零
g_loss.backward() # 生成器的损失函数梯度反向传播
optimizer_G.step() # 生成器参数更新

Pytorch 中默认一个计算图只计算一次反向传播,反向传播后,这个计算图的内存就被释放了。而后面的 generator 算梯度时还要用到这个计算图,所以用这个参数控制计算图不能被释放。

在先更新Gen的代码中,以为Dis不会更新Gen的参数,所以不需要保留计算图。

先更新Discriminator的模板2

如果在更新G的参数时再计算一边discriminator(gen_imgs)也可以不用写retain_graph=True

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

gen_imgs = generator(z) # 从噪声中生成假数据
pred_gen = discriminator(gen_imgs) # 判别器对假数据的输出
pred_real = discriminator(real_imgs) # 判别器对真数据的输出
optimizer_D.zero_grad() # 把判别器中所有参数的梯度归零
real_loss = adversarial_loss(pred_real, valid) # 判别器对真实样本的损失
fake_loss = adversarial_loss(pred_gen, fake) # 判别器对假样本的损失
d_loss = (real_loss + fake_loss) / 2 # 两项损失相加取平均
# 下面这行代码十分重要,将在正文着重讲解
d_loss.backward() # 可以直接释放dis,因为后面又会再创建一次图
optimizer_D.step() # 判别器参数更新
#---------
#训练生成器
#---------
g_loss = adversarial_loss(discriminator(gen_imgs), valid) # 改变!!再计算一次得到计算图
optimizer_G.zero_grad() # 生成器参数梯度归零
g_loss.backward() # 生成器的损失函数梯度反向传播
optimizer_G.step() # 生成器参数更新

参考blog
retain_graph以及detach参考