导言:前几天同门问起我GAN loss的实现,我发现自己在一些符号、细节上对GAN loss还是有没有记牢的地方。于是写下这篇blog来加深印象。
GANS loss
原理
经典GAN loss是最原始的loss:
这个loss是最开始提出的GANS loss,q(x)
是真实数据的分布,p(z)
是随机分布(一般是高斯)。
实现:
Dis 实现:
注意在深度学习框架中,都是使用Gradient Decent来寻找目标函数的最小值.
在优化Discriminator时我们需要将max换做min,即去相反数。
在Generator固定的情况下,上面等式等价于:
即计算真实图像和生成图像在Discriminator的输出负期望值。
Gen实现:
在目标函数中,第一项与G无关,所以在优化时可以去掉。
只用优化第二部分即可:
但是在原文中作者发现1-log D(G(z))
的梯度在最开始时很小,在后面变大,这与我们训练思路相悖,应该开始大后面小。所以将其变为-log(D(G(z))
样例模板:
先更新Generator的模板
注意detach()
操作的位置。detach
函数可以切断反向梯度传播。
由于G,D都使用了不同的optim,所以在更新时可以不用detach。只是需要optimizer_D.zero_grad()
将上一步的梯度清除。
1 | ... |
先更新Discriminator的模板1
注意backward(retain_graph=True)
的使用。
1 | valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真实标签,都是1 |
Pytorch 中默认一个计算图只计算一次反向传播,反向传播后,这个计算图的内存就被释放了。而后面的 generator 算梯度时还要用到这个计算图,所以用这个参数控制计算图不能被释放。
在先更新Gen的代码中,以为Dis不会更新Gen的参数,所以不需要保留计算图。
先更新Discriminator的模板2
如果在更新G的参数时再计算一边discriminator(gen_imgs)
也可以不用写retain_graph=True
1 |
|