WGAN基本原理及Pytorch实现WGAN
作者:mmseoamin日期:2024-01-25

目录

1.WGAN产生背景

(1)超参数敏感

(2)模型崩塌

2.WGAN主要解决的问题

3.不同距离的度量方式

(1)方式一

(2)方式二

(3)方式三

(4)方式四

4.WGAN原理

(1)p和q分布下的距离计算 

(2)EM距离转换优化目标推导

(3)判别器和生成器的优化目标

5.WGAN训练算法 

6.WGAN网络结构

7.数据集下载

8.WGAN代码实现 

9.mainWindow窗口显示生成器生成的图片

10.模型下载 


GAN原理及Pytorch框架实现GAN(比较容易理解)

Pytorch框架实现DCGAN(比较容易理解)

CycleGAN的基本原理以及Pytorch框架实现

1.WGAN产生背景

        之所以会产生WGAN,主要是因为GAN网络模型训练困难的问题,其中主要体现在GAN模型对超参数比较敏感,需要精心挑选才能使模型训练起来,并且也会出现模式崩塌的现象。

(1)超参数敏感

        超参数敏感是指网络的结构设定,学习率,初始化状态等超参数对网络的训练过程影响比较大,微量的超参数调整将可能导致网络的训练结果截然不同。

WGAN基本原理及Pytorch实现WGAN,第1张

左图:表示使用WGAN算法训练的结果;

右图:表示标准的GAN在不使用Batch Normalization层导致网络训练不稳定,无法收敛,生成的样本与真实样本之间差距很大。

        为了更好的训练GAN网络,DCGAN论文的作者提出了不使用Pooling层,多使用Batch Normalization层,不使用全连接层,生成网络中激活函数应使用ReLU,最后一层使用tanh激活函数,判别网络激活函数应使用LeakReLU等一系列经验性的训练技巧。

        但是上面的技巧仅仅能在一定程度上避免出现训练不稳定的现象,并没有从理论上解释为什么会出现训练困难以及如何解决训练不稳定的问题。

(2)模型崩塌

        模型崩塌(Mode Collapse)是指模型生成的样本单一,多样性很差的现象。

        由于判别器只能鉴别单个样本是否为真实样本分布,并没有对多样性进行显式约束,导致生成模型可能倾向于生成真实分布的部分区间中的少量高质量样本,以此来在判别器中获得较高的概率值,而不会学习到全部的真实分布。

        模式崩塌在GAN的训练过程中比较常见。在训练过程中,通过可视化生成网络的样本,可以看到,生成的图片种类非常单一,生成网络总是倾向于生成某一种单一风格的样本图像。

2.WGAN主要解决的问题

  • 引入了一种新的分布距离度量方法:Wasserstein距离,也称为(Earth-Mover Distance)简称EM距离,表示从一个分布变换到另一个分布的最小代价。
  • 定义了一种称为Wasserstein GAN的GAN形式,该形式使EM距离的合理有效近似最小化,并且本文从理论上证明了相应的优化问题是合理的。
  • WGAN解决了GANs的主要训练问题。特别是,训练WGAN不需要维护在鉴别器和生成器的训练中保持谨慎的平衡,并且也不需要对网络架构进行仔细的设计。模式在GANs中典型的下降现象也显著减少。WGAN最引人注目的实际好处之一是能够通过训练鉴别器进行运算来连续地估计EM距离。绘制这些学习曲线不仅对调试和超参数搜索,但也与观察到的样品质量。

    3.不同距离的度量方式

    提示:下面的一些公式可能看起来很枯燥无味,但是如果读者可以坚持读完,将是不小的收获,而且下面给出的公式还是只是论文中推导公式的冰山一角。

    (1)方式一

    WGAN基本原理及Pytorch实现WGAN,第2张

    (2)方式二

    WGAN基本原理及Pytorch实现WGAN,第3张 

    (3)方式三

    WGAN基本原理及Pytorch实现WGAN,第4张 

    (4)方式四

    WGAN基本原理及Pytorch实现WGAN,第5张

     

    4.WGAN原理

    (1)p和q分布下的距离计算 

            导致GAN训练不稳定的原因是因为JS散度在不重叠的分布p和q上的梯度曲面是恒定为0,的。当分布p和q不重叠时,JS散度始终为0,从而导致此时GAN的训练梯度出现梯度弥散现象(或者梯度消失),参数长时间得不到更新,网络无法收敛。

    WGAN基本原理及Pytorch实现WGAN,第6张 

            可以看到上面结果给出,当两个分布完全不重叠时,无论分布之间的距离远近,JS散度为恒定值log2,此时JS散度将无法产生有效的梯度信息;当两个分布出现重叠时,JS散度才会平滑变动,产生有效梯度信息;当完全重叠之后,JS散度最小值为0.

    WGAN基本原理及Pytorch实现WGAN,第7张

    WGAN基本原理及Pytorch实现WGAN,第8张

    WGAN基本原理及Pytorch实现WGAN,第9张

            学习区分两个高斯时的最佳判别器(Discriminator)和critic。正如本文所看到的,极小极大GAN的鉴别器饱和并导致梯度消失。本文的WGANcritic在空间的所有部分都提供了非常平滑的渐变。

    (2)EM距离转换优化目标推导

     WGAN基本原理及Pytorch实现WGAN,第10张

    (3)判别器和生成器的优化目标

     WGAN基本原理及Pytorch实现WGAN,第11张

    WGAN基本原理及Pytorch实现WGAN,第12张 

    5.WGAN训练算法 

    WGAN基本原理及Pytorch实现WGAN,第13张

            WGAN基本原理及Pytorch实现WGAN,第14张

    具体实现代码如下:

    for epoch in range(NUM_EPOCHS):
        for batch_idx,(data,_) in enumerate(dataLoader):
            data = data.to(device)
            cur_batch_size = data.shape[0]
            #Train: Critic : max[critic(real)] - E[critic(fake)]
            loss_critic = 0
            for _ in range(CRITIC_ITERATIONS):
                noise = torch.randn(size = (cur_batch_size,Z_DIM,1,1),device=device)
                fake_img = gen(noise)
                #使用reshape主要是将最后的维度从[1,1,1,1]=>[1]
                critic_real = critic(data).reshape(-1)
                critic_fake = critic(fake_img).reshape(-1)
                loss_critic = (torch.mean(critic_real)- torch.mean(critic_fake))
                opt_critic.zero_grad()
                loss_critic.backward(retain_graph=True)
                opt_critic.step()
                #clip critic weight between -0.01 , 0.01
                for p in critic.parameters():
                    p.data.clamp_(-WEIGHT_CLIP,WEIGHT_CLIP)
            #将维度从[1,1,1,1]=>[1]
            gen_fake = critic(fake_img).reshape(-1)
            #max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
            loss_gen = -torch.mean(gen_fake)
            opt_gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

    6.WGAN网络结构

    Pytorch框架实现DCGAN(比较容易理解)

     

    7.数据集下载

    链接:https://pan.baidu.com/s/1i_VU3aQpLkCx4Z5fhDVKHA 

    提取码:79y3

     

    8.WGAN代码实现 

    提示:代码放在了Github上,本文的代码是参考下面这位博主写的,但是自己其中只是做了一下修改,并且其中加了一个mainWindows界面代码,方便后面训练的模型进行图像风格的转换。

    参考博主的代码:https://b23.tv/QUc0CNb

    本文的代码下载:https://github.com/KeepTryingTo/Pytorch-GAN
    WGAN基本原理及Pytorch实现WGAN,第15张

     

    9.mainWindow窗口显示生成器生成的图片

    提示:这里编写了一个显示生成器显示图片的程序(mainWindow.py),加载之前训练之后保存的生成器模型,之后可使用该模型进行随机生成图片,如下:

    (1)运行mainWindow.py 初始界面如下

    WGAN基本原理及Pytorch实现WGAN,第16张

     点击随机生成图片:

    WGAN基本原理及Pytorch实现WGAN,第17张

     

    WGAN基本原理及Pytorch实现WGAN,第18张

     

     

    10.模型下载 

     链接:https://pan.baidu.com/s/1dBbz6yyaRHMHl6Dl5Q24Dg 

    提取码:6t7u

    参考文章:

    参考博主的代码:https://b23.tv/QUc0CNb

    《TensorFlow深度学习》