导言:前几天同门问起我GAN loss的实现,我发现自己在一些符号、细节上对GAN loss还是有没有记牢的地方。于是写下这篇blog来加深印象。
GANS loss
原理
经典GAN loss是最原始的loss:
这个loss是最开始提出的GANS loss,q(x)
是真实数据的分布,p(z)
是随机分布(一般是高斯)。
实现:
Dis 实现:
注意在深度学习框架中,都是使用Gradient Decent来寻找目标函数的最小值.
在优化Discriminator时我们需要将max换做min,即去相反数。
在Generator固定的情况下,上面等式等价于:
即计算真实图像和生成图像在Discriminator的输出负期望值。
Gen实现:
在目标函数中,第一项与G无关,所以在优化时可以去掉。
只用优化第二部分即可:
但是在原文中作者发现1-log D(G(z))
的梯度在最开始时很小,在后面变大,这与我们训练思路相悖,应该开始大后面小。所以将其变为-log(D(G(z))
样例模板:
先更新Generator的模板
注意detach()
操作的位置。detach
函数可以切断反向梯度传播。
由于G,D都使用了不同的optim,所以在更新时可以不用detach。只是需要optimizer_D.zero_grad()
将上一步的梯度清除。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34...
# 先更新generator的模板:
adversarial_loss = torch.nn.BCELoss()
generator = Generator()
discriminator = Discriminator()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
for epoch in range(opt.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
real_label = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
fake_label = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
# Configure input
real_imgs = Variable(imgs.type(Tensor))
optimizer_G.zero_grad()
...
# Generate a batch of images
gen_imgs = generator(z)
# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), real_label)
...
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(real_imgs), real_label)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake_label)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
...
先更新Discriminator的模板1
注意backward(retain_graph=True)
的使用。
1 | valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真实标签,都是1 |
Pytorch 中默认一个计算图只计算一次反向传播,反向传播后,这个计算图的内存就被释放了。而后面的 generator 算梯度时还要用到这个计算图,所以用这个参数控制计算图不能被释放。
在先更新Gen的代码中,以为Dis不会更新Gen的参数,所以不需要保留计算图。
先更新Discriminator的模板2
如果在更新G的参数时再计算一边discriminator(gen_imgs)
也可以不用写retain_graph=True
1 |
|