閱讀66 返回首頁    go 阿裏雲 go 技術社區[雲棲]


在TensorFlow中對比兩大生成模型:VAE與GAN(附測試代碼)

項目鏈接:https://github.com/kvmanohar22/ Generative-Models
變分自編碼器(VAE)與生成對抗網絡(GAN)是複雜分布上無監督學習最具前景的兩類方法。
本項目總結了使用變分自編碼器(Variational Autoencode,VAE)和生成對抗網絡(GAN)對給定數據分布進行建模,並且對比了這些模型的性能。你可能會問:我們已經有了數百萬張圖像,為什麼還要從給定數據分布中生成圖像呢?正如 Ian Goodfellow 在 NIPS 2016 教程中指出的那樣,實際上有很多應用。我覺得比較有趣的一種是使用 GAN 模擬可能的未來,就像強化學習中使用策略梯度的智能體那樣。
本文組織架構:

  • 變分自編碼器(VAE)
  • 生成對抗網絡(GAN)
  • 訓練普通 GAN 的難點
  • 訓練細節
  • 在 MNIST 上進行 VAE 和 GAN 對比實驗
    • 在無標簽的情況下訓練 GAN 判別器
    • 在有標簽的情況下訓練 GAN 判別器
  • 在 CIFAR 上進行 VAE 和 GAN 實驗
  • 延伸閱讀


VAE

變分自編碼器可用於對先驗數據分布進行建模。從名字上就可以看出,它包括兩部分:編碼器和解碼器。編碼器將數據分布的高級表征映射到數據的低級表征,低級表征叫作本征向量(latent vector)。解碼器吸收數據的低級表征,然後輸出同樣數據的高級表征。
從數學上來講,讓 X 作為編碼器的輸入,z 作為本征向量,X′作為解碼器的輸出。
圖 1 是 VAE 的可視化圖。

1

這與標準自編碼器有何不同?關鍵區別在於我們對本征向量的約束。如果是標準自編碼器,那麼我們主要關注重建損失(reconstruction loss),即:

2

而在變分自編碼器的情況中,我們希望本征向量遵循特定的分布,通常是單位高斯分布(unit Gaussian distribution),使下列損失得到優化:

3

p(z′)∼N(0,I) 中 I 指單位矩陣(identity matrx),q(z∣X) 是本征向量的分布,其中。和由神經網絡來計算。KL(A,B) 是分布 B 到 A 的 KL 散度。
由於損失函數中還有其他項,因此存在模型生成圖像的精度,同本征向量的分布與單位高斯分布的接近程度之間存在權衡(trade-off)。這兩部分由兩個超參數λ_1 和λ_2 來控製。

GAN

GAN 是根據給定的先驗分布生成數據的另一種方式,包括同時進行的兩部分:判別器和生成器。
判別器用於對「真」圖像和「偽」圖像進行分類,生成器從隨機噪聲中生成圖像(隨機噪聲通常叫作本征向量或代碼,該噪聲通常從均勻分布(uniform distribution)或高斯分布中獲取)。生成器的任務是生成可以以假亂真的圖像,令判別器也無法區分出來。也就是說,生成器和判別器是互相對抗的。判別器非常努力地嚐試區分真偽圖像,同時生成器盡力生成更加逼真的圖像,目的是使判別器將這些圖像也分類為「真」圖像。
圖 2 是 GAN 的典型結構。

4

生成器包括利用代碼輸出圖像的解卷積層。圖 3 是生成器的架構圖。

5


訓練 GAN 的難點

訓練 GAN 時我們會遇到一些挑戰,我認為其中最大的挑戰在於本征向量/代碼的采樣。代碼隻是從先驗分布中對本征變量的噪聲采樣。有很多種方法可以克服該挑戰,包括:使用 VAE 對本征變量進行編碼,學習數據的先驗分布。這聽起來要好一些,因為編碼器能夠學習數據分布,現在我們可以從分布中進行采樣,而不是生成隨機噪聲。

訓練細節

我們知道兩個分布 p(真實分布)和 q(估計分布)之間的交叉熵通過以下公式計算:

6

  • 對於二元分類:


7

  • 對於 GAN,我們假設分布的一半來自真實數據分布,一半來自估計分布,因此:

    8

    訓練 GAN 需要同時優化兩個損失函數。
    按照極小極大值算法:

    9

    這裏,判別器需要區分圖像的真偽,不管圖像是否包含真實物體,都沒有注意力。當我們在 CIFAR 上檢查 GAN 生成的圖像時會明顯看到這一點。
    我們可以重新定義判別器損失目標,使之包含標簽。這被證明可以提高主觀樣本的質量。如:在 MNIST 或 CIFAR-10(兩個數據集都有 10 個類別)。
    上述 Python 損失函數在 TensorFlow 中的實現:

    def VAE_loss(true_images, logits, mean, std):
      """
        Args:
          true_images : batch of input images
          logits      : linear output of the decoder network (the constructed images)
          mean        : mean of the latent code
          std         : standard deviation of the latent code
      """
      imgs_flat    = tf.reshape(true_images, [-1, img_h*img_w*img_d])
      encoder_loss = 0.5 * tf.reduce_sum(tf.square(mean)+tf.square(std)
                     -tf.log(tf.square(std))-1, 1)
      decoder_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                     logits=logits, labels=img_flat), 1)
      return tf.reduce_mean(encoder_loss + decoder_loss)
    def GAN_loss_without_labels(true_logit, fake_logit):
      """
        Args:
          true_logit : Given data from true distribution,
                      `true_logit` is the output of Discriminator (a column vector)
          fake_logit : Given data generated from Generator,
                      `fake_logit` is the output of Discriminator (a column vector)
      """
    
      true_prob = tf.nn.sigmoid(true_logit)
      fake_prob = tf.nn.sigmoid(fake_logit)
      d_loss = tf.reduce_mean(-tf.log(true_prob)-tf.log(1-fake_prob))
      g_loss = tf.reduce_mean(-tf.log(fake_prob))
      return d_loss, g_loss  
    def GAN_loss_with_labels(true_logit, fake_logit):
      """
        Args:
          true_logit : Given data from true distribution,
                      `true_logit` is the output of Discriminator (a matrix now)
          fake_logit : Given data generated from Generator,
                      `fake_logit` is the output of Discriminator (a matrix now)
      """
      d_true_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.labels, logits=self.true_logit, dim=1)
      d_fake_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=1-self.labels, logits=self.fake_logit, dim=1)
      g_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.labels, logits=self.fake_logit, dim=1)
    
      d_loss = d_true_loss + d_fake_loss      return tf.reduce_mean(d_loss), tf.reduce_mean(g_loss)
    


在 MNIST 上進行 VAE 與 GAN 對比實驗

1. 不使用標簽訓練判別器
我在 MNIST 上訓練了一個 VAE。代碼地址:https://github.com/kvmanohar22/Generative-Models
實驗使用了 MNIST 的 28×28 圖像,下圖中:

  • 左側:數據分布的 64 張原始圖像
  • 中間:VAE 生成的 64 張圖像
  • 右側:GAN 生成的 64 張圖像

第 1 次迭代:

10

第 2 次迭代:

11

第 3 次迭代:

12

第 4 次迭代:

13

第 100 次迭代:

14

VAE(125)和 GAN(368)訓練的最終結果:

15

根據GAN迭代次數生成的gif圖:

16

顯然,VAE 生成的圖像與 GAN 生成的圖像相比,前者更加模煳。這個結果在預料之中,因為 VAE 模型生成的所有輸出都是分布平均。為了減少圖像的模煳度,我們可以使用 L1 損失來代替 L2 損失。
在第一個實驗後,作者還將在近期研究使用標簽訓練判別器,並在 CIFAR 數據集上測試 VAE 與 GAN 的性能。
使用
下載 MNIST 和 CIFAR 數據集
使用 MNIST 訓練 VAE 請運行:

python main.py --train --model vae --dataset mnist

使用 MNIST 訓練 GAN 請運行:

python main.py --train --model gan --dataset mnist

想要獲取完整的命令行選項,請運行:

python main.py --help

該模型由 generate_frq 決定生成圖片的頻率,默認值為 1。

GAN 在 MNIST 上的訓練結果

MNIST 數據集中的樣本圖像:

17

上方是 VAE 生成的圖像,下方的圖展示了 GAN 生成圖像的過程:

18

原文發布時間為:2017-10-29
本文來自雲棲社區合作夥伴“數據派THU”,了解相關信息可以關注“數據派THU”微信公眾號

最後更新:2017-10-30 17:04:17

  上一篇:go  Enterprise Library深入解析與靈活應用(3):倘若將Unity、PIAB、Exception Handling引入MVP模式.. .. ..
  下一篇:go  [原創]WCF後續之旅(12): 線程關聯性(Thread Affinity)對WCF並發訪問的影響