详解 GAN 生成对抗网络

  • 小编 发布于 2019-12-08 07:00:12
  • 栏目:科技
  • 来源:Alice机器学习干货铺
  • 7686 人围观

GAN : Generative adversarial network 生成对抗网络

详解 GAN 生成对抗网络

https://www.kdnuggets.com/2017/01/generative-adver

Yan Lecun 给这个模型很高评价,认为它是机器学习领域紧十年来最酷的模型。

关于 GAN 的论文就有好多,下面这个repo里面比较全的列出了相关论文:

https://github.com/hindupuravinash/the-gan-zoo/blob/master/README.md

从2017年开始关于它的论文每个月都在不断大幅增长:

图片来源:

https://deephunt.in/the-gan-zoo-79597dc8c347


GAN 主要是用来生成东西,

在图像领域是生成图像,给它一个随机的向量,这个向量的每个元素一般来说是代表图像的一种特征,输入给模型后,它可以生成一张图片,也就是一个高维向量,向量的每个维度对应一个像素的颜色,

在自然语言处理领域是生成文字,比如说写诗写文章,给它一个随机向量,它可以输出一句话。


GAN 模型包括一个 generator 生成器和一个 discriminator 辨别器,生成器和辨别器之间的关系就好像是被捕食者和捕食者的关系。:


详解 GAN 生成对抗网络


首先看 discriminator,实际上它是一个神经网络。

它的输入是一张图片,输出是一个标量,这个值代表图片的质量,数值越大,生成的结果质量越高,所以,当输入图片很真实的时候,discriminator 给它的得分越高,相反则得分越低。

再看 generator,它也是一个神经网络,开始它的输入是随机的,因为它也不知道要怎么生成图片,所以一开始的输出也是比较模糊的东西。

生成器生成的结果,辨别器要做的就是判断这张图片是由生成器生成的还是像是真实的图片,辨别器会给评分,评分低的话会被督促着进化,生成更好地结果,就像被捕食者为了不被灭族就要进化,但是捕食者为了不被饿着也会进化,就这样互相督促着一点点改进,最后会生成非常好的结果。

所以第一代生成器生成第一代结果,第一代辨别器评分,然后第二代生成器要做的事,就是想办法骗过第一代的辨别器。例如,第一代的辨别器通过是否有颜色这个特征来区分真实图片和第一代生成器生成的图片,那么第二代生成器为了骗过第一代的辨别器就会给图片加上了颜色。

同样第二代的辨别器也跟着进化,它要判断真实图片和第二代生成器生成的图片,这时候不能根据是否有颜色了,而是通过其他特征,例如是否有嘴巴。

就这样生成器和辨别器之间的关系就像是相互对抗的天敌,经过不断地进化,生成器就可以生成更高质量更接近真实的图片。


GAN 模型的算法过程

生成器和辨别器都是神经网络,训练模型之前先随机生成它们的参数,然后进行迭代去训练生成器和辨别器。

在每个迭代中有两个步骤:

第一步,固定生成器的参数,只去训练辨别器的参数。

  • 具体做法是将一些随机向量投给生成器,生成器就会生成一些效果不好的图片,
  • 然后从真实图片库中采样一些样本,
  • 接着要去训练辨别器的参数,

方法就是,如果这个图片是从真实数据集合中采用出来的,就给高分,如果是生成器生成的,就给低分,这可以是一个分类问题。

第二步,固定辨别器,只去训练生成器的参数。

  • 先把一个向量输入给生成器,会生成一个图片,
  • 接着将这个图片输入给辨别器,辨别器会给这个图片一个分数,
  • 因为生成器的目的是要骗过辨别器,所以希望得到的分数可以越高越好,相当于生成的图片过了辨别器这一关,生成了比较真实的图片,也就是这时候要固定辨别器的参数,去调节生成器的参数。

在实际训练时,会将生成器和辨别器放在一起,组成一个大的神经网络。

例如,生成器和辨别器都有五层,将它们连在一起成为一个十层的网络,

输入是一个向量,输出是一个值,中间有一层输出代表一个图片,这一层会特别宽,和图片的展开纬度是一样的。

在训练的时候,先固定后面五层隐藏层,只去训练前面五层,就是在训练生成器,目标就是要让整个网络的输出值越大越好。


详解 GAN 生成对抗网络


---

下面这个图就是 GAN 的详细算法:


详解 GAN 生成对抗网络


接下来对照算法详细讲解,

生成器的参数是 theta g,辨别器的参数是 theta d。

  • 在每次迭代中,先从数据库中采样出 m 个图片,
  • 再从一个分布中采样出 m 个噪音样本向量,这个分布可以是高斯分布,
  • x ~ 表示生成器生成的图片,
  • 然后去调整辨别器,
  • 前一部分是训练辨别器,目的是要让这个目标函数越大越好,

目标函数的意义是,首先拿出 m 张真实的图片,给辨别器得到一个分数,取 log 对数,在做平均,因为目标是要让这个目标函数越大越好,就是让第一项越大越好,也就是让他给真实的图片的得分越大越好。

目标函数第二项的含义是将生成器生成的这些假的图片传递给辨别器,经过 sigmoid 函数,得到 0~1 之间的数,同样也是取对数,求平均。

因为整体是希望这个目标函数越大越好,那么第二项中的 D(X) 就是需要越小越好,也就是生成器生成的图片得分越小越好。

至于如何让这个目标函数越大越好,就可以用梯度算法等优化算法来更新参数。

算法的上面这部分代码是要训练辨别器,下面这部分是要训练生成器。

  • 首先采样出 m 个向量,这些向量和前面的采样的向量是不一样的,
  • 生成器的目标是想办法骗过辨别器,V ~ 就是生成器的目标函数,

把 m 个向量输入给生成器,生成一张图片就是 G(Z),

把这个图片丢给辨别器,得到的是 D(G(Z)),同样取对数,求平均值。

最终的目标是希望这个函数越大越好,意思是生成器生成的图片输入给辨别器之后得到的分数可以越大越好。

同样可以用梯度的算法来调节参数,使得目标函数越大越好。

这样上面一部分是训练辨别器,下面是训练生成器,就这样反复交替地执行这两个步骤。


推荐学习资料:

https://youtu.be/DQNNMiAP5lw

本文是 李宏毅 GAN Lecture 1 Introduction 的学习笔记。

转载请说明出处:866热点网 ©