閱讀103 返回首頁    go 阿裏雲 go 技術社區[雲棲]


用GAN來做圖像生成,這是最好的方法

前言

在我們之前的文章中,我們學習了如何構造一個簡單的 GAN 來生成 MNIST 手寫圖片。對於圖像問題,卷積神經網絡相比於簡單地全連接的神經網絡更具優勢,因此,我們這一節我們將繼續深入 GAN,通過融合卷積神經網絡來對我們的 GAN 進行改進,實現一個深度卷積 GAN。如果還沒有親手實踐過 GAN 的小夥伴可以先去學習一下上一篇專欄:生成對抗網絡(GAN)之 MNIST 數據生成

專欄中的所有代碼都在我的 GitHub中,歡迎 star 與 fork。

本次代碼在 NELSONZHAO/zhihu/dcgan,裏麵包含了兩個文件:

  • dcgan_mnist:基於 MNIST 手寫數據集構造深度卷積 GAN 模型

  • dcgan_cifar:基於 CIFAR 數據集構造深度卷積 GAN 模型

本文主要以 MNIST 為例進行介紹,兩者在本質上沒有差別,隻在細微的參數上有所調整。由於窮學生資源有限,沒有對模型增加迭代次數,也沒有構造更深的模型。並且也沒有選取像素很高的圖像,高像素非常消耗計算量。本節隻是一個拋磚引玉的作用,讓大家了解 DCGAN 的結構,如果有資源的小夥伴可以自己去嚐試其他更清晰的圖片以及更深的結構,相信會取得很不錯的結果。

工具

  • Python3

  • TensorFlow 1.0

  • Jupyter notebook

正文

整個正文部分將包括以下部分:

- 數據加載

- 模型輸入

- Generator

- Discriminator

- Loss

- Optimizer

- 訓練模型

- 可視化

數據加載

數據加載部分采用 TensorFlow 中的 input_data 接口來進行加載。關於加載細節在前麵的文章中已經寫了很多次啦,相信看過我文章的小夥伴對 MNIST 加載也非常熟悉,這裏不再贅述。

模型輸入

在 GAN 中,我們的輸入包括兩部分,一個是真實圖片,它將直接輸入給 discriminator 來獲得一個判別結果;另一個是隨機噪聲,隨機噪聲將作為 generator 來生成圖片的材料,generator 再將生成圖片傳遞給 discriminator 獲得一個判別結果。

用GAN來做圖像生成,這是最好的方法

上麵的函數定義了輸入圖片與噪聲圖片兩個 tensor。

Generator

生成器接收一個噪聲信號,基於該信號生成一個圖片輸入給判別器。在上一篇專欄文章生成對抗網絡(GAN)之 MNIST 數據生成中,我們的生成器是一個全連接層的神經網絡,而本節我們將生成器改造為包含卷積結構的網絡,使其更加適合處理圖片輸入。整個生成器結構如下:

用GAN來做圖像生成,這是最好的方法

我們采用了 transposed convolution 將我們的噪聲圖片轉換為了一個與輸入圖片具有相同 shape 的生成圖像。我們來看一下具體的實現代碼:

用GAN來做圖像生成,這是最好的方法

上麵的代碼是整個生成器的實現細節,裏麵包含了一些 trick,我們來一步步地看一下。

首先我們通過一個全連接層將輸入的噪聲圖像轉換成了一個 1 x 4*4*512 的結構,再將其 reshape 成一個 [batch_size, 4, 4, 512] 的形狀,至此我們其實完成了第一步的轉換。接下來我們使用了一個對加速收斂及提高卷積神經網絡性能中非常有效的方法——加入 BN(batch normalization),它的思想是歸一化當前層輸入,使它們的均值為 0 和方差為 1,類似於我們歸一化網絡輸入的方法。它的好處在於可以加速收斂,並且加入 BN 的卷積神經網絡受權重初始化影響非常小,具有非常好的穩定性,對於提升卷積性能有很好的效果。關於 batch normalization,我會在後麵專欄中進行一個詳細的介紹。

完成 BN 後,我們使用 Leaky ReLU 作為激活函數,在上一篇專欄中我們已經提過這個函數,這裏不再贅述。最後加入 dropout 正則化。剩下的 transposed convolution 結構層與之類似,隻不過在最後一層中,我們不采用 BN,直接采用 tanh 激活函數輸出生成的圖片。

在上麵的 transposed convolution 中,很多小夥伴肯定會對每一層 size 的變化疑惑,在這裏來講一下在 TensorFlow 中如何來計算每一層 feature map 的 size。首先,在卷積神經網絡中,假如我們使用一個 k x k 的 filter 對 m x m x d 的圖片進行卷積操作,strides 為 s,在 TensorFlow 中,當我們設置 padding='same'時,卷積以後的每一個 feature map 的 height 和 width 為用GAN來做圖像生成,這是最好的方法;當設置 padding='valid'時,每一個 feature map 的 height 和 width 為用GAN來做圖像生成,這是最好的方法。那麼反過來,如果我們想要進行 transposed convolution 操作,比如將 7 x 7 的形狀變為 14 x 14,那麼此時,我們可以設置 padding='same',strides=2 即可,與 filter 的 size 沒有關係;而如果將 4 x 4 變為 7 x 7 的話,當設置 padding='valid'時,即用GAN來做圖像生成,這是最好的方法,此時 s=1,k=4 即可實現我們的目標。

上麵的代碼中我也標注了每一步 shape 的變化。

Discriminator

Discriminator 接收一個圖片,輸出一個判別結果(概率)。其實 Discriminator 完全可以看做一個包含卷積神經網絡的圖片二分類器。結構如下:

用GAN來做圖像生成,這是最好的方法

實現代碼如下:

用GAN來做圖像生成,這是最好的方法

上麵代碼其實就是一個簡單的卷積神經網絡圖像識別問題,最終返回 logits(用來計算 loss)與 outputs。這裏沒有加入池化層的原因在於圖片本身經過多層卷積以後已經非常小了,並且我們加入了 batch normalization 加速了訓練,並不需要通過 max pooling 來進行特征提取加速訓練。

Loss Function

用GAN來做圖像生成,這是最好的方法

Loss 部分分別計算 Generator 的 loss 與 Discriminator 的 loss,和之前一樣,我們加入 label smoothing 防止過擬合,增強泛化能力。

Optimizer

GAN 中實際包含了兩個神經網絡,因此對於這兩個神經網絡要分開進行優化。代碼如下:

用GAN來做圖像生成,這是最好的方法

這裏的 Optimizer 和我們之前不同,由於我們使用了 TensorFlow 中的 batch normalization 函數,這個函數中有很多 trick 要注意。首先我們要知道,batch normalization 在訓練階段與非訓練階段的計算方式是有差別的,這也是為什麼我們在使用 batch normalization 過程中需要指定 training 這個參數。上麵使用 tf.control_dependencies 是為了保證在訓練階段能夠一直更新 moving averages。具體參考A Gentle Guide to Using Batch Normalization in Tensorflow - Rui Shu

訓練

到此為止,我們就完成了深度卷積 GAN 的構造,接著我們可以對我們的 GAN 來進行訓練,並且定義一些輔助函數來可視化迭代的結果。代碼太長就不放上來了,可以直接去我的 GitHub 下載。

我這裏隻設置了 5 輪 epochs,每隔 100 個 batch 打印一次結果,每一行代表同一個 epoch 下的 25 張圖:

用GAN來做圖像生成,這是最好的方法

我們可以看出僅僅經過了少部分的迭代就已經生成非常清晰的手寫數字,並且訓練速度是非常快的。

用GAN來做圖像生成,這是最好的方法

上麵的圖是最後幾次迭代的結果。我們可以回顧一下上一篇的一個簡單的全連接層的 GAN,收斂速度明顯不如深度卷積 GAN。

總結

到此為止,我們學習了一個深度卷積 GAN,並且看到相比於之前簡單的 GAN 來說,深度卷積 GAN 的性能更加優秀。當然除了 MNST 數據集以外,小夥伴兒們還可以嚐試很多其他圖片,比如我們之前用到過的 CIFAR 數據集,我在這裏也實現了一個 CIFAR 數據集的圖片生成,我隻選取了馬的圖片進行訓練:

剛開始訓練時:

用GAN來做圖像生成,這是最好的方法

訓練 50 個 epochs:

用GAN來做圖像生成,這是最好的方法

這裏我隻設置了 50 次迭代,可以看到最後已經生成了非常明顯的馬的圖像,可見深度卷積 GAN 的優勢。


我的 GitHub:NELSONZHAO (Nelson Zhao)

上麵包含了我的專欄中所有的代碼實現,歡迎 star,歡迎 fork。


====================================分割線================================



本文作者:AI研習社

本文轉自雷鋒網禁止二次轉載,原文鏈接

最後更新:2017-08-23 10:33:32

  上一篇:go  在 Mac OS X 裝不上 TensorFlow?看了這篇就會裝
  下一篇:go  怡海軟件:SaaS是什麼?