閱讀98 返回首頁    go 阿裏雲 go 技術社區[雲棲]


NVIDIA新作解讀:用GAN生成前所未有的高清圖像(附PyTorch複現) | PaperDaily #15

今天要介紹的文章是 NVIDIA 投稿 ICLR 2018 的一篇文章,Progressive Growing of GANs for Improved Quality, Stability, and Variation[1],姑且稱它為 PG-GAN。

從行文可以看出文章是臨時趕出來的,畢竟這麼大的實驗,用 P100 都要跑 20 天,更不用說調參時間了,不過人家在 NVIDIA,不缺卡。作者放出了基於 Lasagna 的代碼,今天我也會簡單解讀一下代碼。另外,我也在用 PyTorch 做複現。

在 PG-GAN 出來以前,訓練高分辨率圖像生成的 GAN 方法主要就是 LAPGAN[2] 和 BEGAN[6]。後者主要是針對人臉的,生成的人臉逼真而不會是鬼臉。

這裏也提一下,生成鬼臉的原因是 Discriminator 不再更新,它不能再給予 Generator 其他指導,Generator 找到了一種騙過 Discriminator 的方法,也就是生成鬼臉,而且很大可能會 mode collapse。

下圖是我用 PyTorch 做的 BEGAN 複現,當時沒有跑很高的分辨率,但是效果確實比其他 GAN 好基本沒有鬼臉。

a11d32d72aa7891e926879814d84509f2c467cbf

PG-GAN 能夠穩定地訓練生成高分辨率的 GAN。我們來看一下 PG-GAN 跟別的 GAN 不同在哪裏。

1. 訓練方式

作者采用 progressive growing 的訓練方式,先訓一個小分辨率的圖像生成,訓好了之後再逐步過渡到更高分辨率的圖像。然後穩定訓練當前分辨率,再逐步過渡到下一個更高的分辨率。

ab5376e507b284de3bf741db9873c2c5015ea801

如上圖所示。更具體點來說,當處於 fade in(或者說 progressive growing)階段的時候,上一分辨率(4x4)會通過 resize+conv 操作得到跟下一分辨率(8x8)同樣大小的輸出,然後兩部分做加權,再通過 to_rgb 操作得到最終的輸出。

這樣做的一個好處是它可以充分利用上個分辨率訓練的結果,通過緩慢的過渡(w 逐漸增大),使得訓練生成下一分辨率的網絡更加穩定。

上麵展示的是 Generator 的 growing 階段。下圖是 Discriminator 的 growing,它跟 Generator 的類似,差別在於一個是上采樣,一個是下采樣。這裏就不再贅述。

b1528b2f1eb484b0cea3c9634fe42a666adf1d06

不難想象,網絡在 growing 的時候,如果不引入 progressive (fade in),那麼有可能因為比較差的初始化,導致原來訓練的進度功虧一簣,模型不得不從新開始學習,如此一來就沒有充分利用以前學習的成果,甚至還可能誤導。我們知道 GAN的訓練不穩定,這樣的突變有時候是致命的。所以 fade in 對訓練的穩定性來說至關重要。

說到 growing 的訓練方式,我們很容易想到 autoencoder 也有一種類似的訓練方式:先訓各一層的 encoder 和 decoder,訓好了以後再過渡到訓練各兩層的 encoder 和 decoder,這樣的好處是避免梯度消失,導致離 loss 太遠的層更新不夠充分。PG-GAN 的做法可以說是這種 autoencoder 訓練方式在 GAN 訓練上的應用。

此外,訓練 GAN 生成高分辨率圖像,還有一種方法,叫 LAPGAN[2]。LAPGAN 借助 CGAN,高分辨率圖像的生成是以低分辨率圖像作為條件去生成殘差,然後低分辨率圖上采樣跟殘差求和得到高分辨率圖,通過不斷堆疊 CGAN 得到我們想要的分辨率。

077599d3eef33e0e6a23e8f03324721c0e051dc3

LAPGAN 是多個 CGAN 堆疊一起訓練,當然可以拆分成分階段訓練,但是它們本質上是不同的,LAPGAN 學的是殘差,而 PG-GAN 存在 stabilize 訓練階段,學的不是殘差,而直接是圖像。

作者在代碼中設計了一個 LODSelectLayer 來實現 progressive growing。對於 Generator,每一層插入一個 LODSelectLayer,它實際上就是一個輸出分支,實現在特定層的輸出。

從代碼來看,作者應該是這樣訓練的(參見這裏的 train_gan 函數),先構建 4x4 分辨率的網絡,訓練,然後把網絡存出去。再構建 8x8 分辨率的網絡,導入原來 4x4 的參數,然後訓 fade in,再訓 stabilize,再存出去。我在複現的時候,根據文章的意思,修改了 LODSelectLayer 層,因為 PyTorch 是動態圖,能夠很方便地寫 if-else 邏輯語句。

借助這種 growing 的方式,PG-GAN 的效果超級好。另外,我認為這種 progressive growing 的方法比較適合 GAN 的訓練,GAN 訓練不穩定可以通過 growing 的方式可以緩解。

不隻是在噪聲生成圖像的任務中可以這麼做,在其他用到 GAN 的任務中都可以引入這種訓練方式。我打算將 progressive growing 引入到 CycleGAN 中,希望能夠得到更好的結果。

2. 增加生成多樣性

增加生成樣本的多樣性有兩種可行的方法:通過 loss 讓網絡自己調整、通過設計判別多樣性的特征人為引導。

WGAN 屬於前者,它采用更好的分布距離的估計(Wasserstein distance)。模型收斂意味著生成的分布和真實分布一致,能夠有多樣性的保證。PG-GAN 則屬於後者。

作者沿用 improved GAN 的思路,通過人為地給 Discriminator 構造判別多樣性的特征來引導 Generator 生成更多樣的樣本。Discriminator 能探測到 mode collapse 是否產生了,一旦產生,Generator 的 loss 就會增大,通過優化 Generator 就會往遠離 mode collapse 的方向走,而不是一頭栽進坑裏。

Improved GAN 引入了 minibatch discrimination 層,構造一個 minibatch 內的多樣性衡量指標。它引入了新的參數。

2d5f7e6d2b7e6fde41bf93ba8488951a26d3d07c

而 PG-GAN 不引入新的參數,利用特征的標準差作為衡量標準。

84150e86ed99fe9f40758caefee4152cb3cd6fbb

這裏囉嗦地說明上麵那張圖做了什麼。我們有 N 個樣本的 feature maps(為了畫圖方便,不妨假設每個樣本隻有一個 feature map),我們對每個空間位置求標準差,用 numpy 的 std 函數來說就是沿著樣本的維度求 std。這樣就得到一張新的 feature map(如果樣本的 feature map 不止一個,那麼這樣構造得到的 feature map 數量應該是一致的),接著 feature map 求平均得到一個數。

這個過程簡單來說就是求 mean std,作者把這個數複製成一張 feature map 的大小,跟原來的 feature map 拚在一起送給 Discriminator。

從作者放出來的代碼來看,這對應 averaging=“all”的情況。作者還嚐試了其他的統計量:“spatial”,“gpool”,“flat”等。它們的主要差別在於沿著哪些維度求標準差。至於它們的作用,等我的代碼複現完成了會做一個測試。估計作者調參發現“all”的效果最好。

3. Normalization

從 DCGAN[3]開始,GAN 的網絡使用 batch (or instance) normalization 幾乎成為慣例。使用 batch norm 可以增加訓練的穩定性,大大減少了中途崩掉的情況。作者采用了兩種新的 normalization 方法,不引入新的參數(不引入新的參數似乎是 PG-GAN 各種 tricks 的一個賣點)。

第一種 normalization 方法叫 pixel norm,它是 local response normalization 的變種。Pixel norm 沿著 channel 維度做歸一化,這樣歸一化的一個好處在於,feature map 的每個位置都具有單位長度。這個歸一化策略與作者設計的 Generator 輸出有較大關係,注意到 Generator 的輸出層並沒有 Tanh 或者 Sigmoid 激活函數,後麵我們針對這個問題進行探討。

48abe70aba74ac035c7ccd01da9b30d9d8b7d326

第二種 normalization 方法跟凱明大神的初始化方法[4]掛鉤。He 的初始化方法能夠確保網絡初始化的時候,隨機初始化的參數不會大幅度地改變輸入信號的強度。

ce7af75de8064cfcadd9f8c13dd1f3383b463c7b

根據這個式子,我們可以推導出網絡每一層的參數應該怎樣初始化。可以參考 PyTorch 提供的接口。

作者走得比這個要遠一點,他不隻是初始化的時候對參數做了調整,而是動態調整。初始化采用標準高斯分布,但是每次迭代都會對 weights 按照上麵的式子做歸一化。作者 argue 這樣的歸一化的好處在於它不用再擔心參數的 scale 問題,起到均衡學習率的作用(euqalized learning rate)。

4. 有針對性地給樣本加噪聲

通過給真實樣本加噪聲能夠起到均衡 Generator 和 Discriminator 的作用,起到緩解 mode collapse 的作用,這一點在 WGAN 的前傳中就已經提到[5]。盡管使用 LSGAN 會比原始的 GAN 更容易訓練,然而它在 Discriminator 的輸出接近 1 的適合,梯度就消失,不能給 Generator 起到引導作用。

針對 D 趨近 1 的這種特性,作者提出了下麵這種添加噪聲的方式:

7d463232346e686eb1bd53732134b014eb698b33

其中,?tp=webp&wxfrom=5&wx_lazy=1分別為第 t 次迭代判別器輸出的修正值、第 t-1 次迭代真樣本的判別器輸出。 

從式子可以看出,當真樣本的判別器輸出越接近 1 的時候,噪聲強度就越大,而輸出太小(<=0.5)的時候,不引入噪聲,這是因為 0.5 是 LSGAN 收斂時,D 的合理輸出(無法判斷真假樣本),而小於 0.5 意味著 D 的能力太弱。

文章還有其他很多 tricks,有些 tricks 不是作者提出的,如 Layer norm,還有一些比較細微的 tricks,比如每個分辨率訓練好做 sample 的時候學習率怎麼 decay,每個分辨率的訓練迭代多少次等等,我們就不再詳細展開。具體可以參見官方代碼,也可以看我複現的代碼。

目前複現的結果還在跑,現在訓練到了 16x16 分辨率的 fade in 階段,放一張當前的結果圖,4 個方格的每個方格左邊 4 列是生成的圖,右邊 4 列是真實樣本。現在還處於訓練早期,分辨率太低,過幾天看一下高分辨率的結果。

7518816bea365273c824fe84896ee9fdd106b1c2

5. 相關代碼

官方 Lasagna 代碼:

https://github.com/tkarras/progressive_growing_of_gans

作者 PyTorch 複現:

https://github.com/github-pengge/PyTorch-progressive_growing_of_gans

6. 參考文獻

[1]. Karras T, Aila T, Laine S, et al. Progressive Growing of GANs for Improved Quality, Stability, and Variation[J]. arXiv preprint arXiv:1710.10196, 2017.

[2]. Denton E L, Chintala S, Fergus R. Deep Generative Image Models using a Laplacian Pyramid of Adversarial Networks[C]//Advances in neural information processing systems. 2015: 1486-1494.

[3]. Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.

[4]. He K, Zhang X, Ren S, et al. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification[C]//Proceedings of the IEEE international conference on computer vision. 2015: 1026-1034.

[5]. Arjovsky M, Bottou L. Towards principled methods for training generative adversarial networks[J]. arXiv preprint arXiv:1701.04862, 2017.

[6]. Berthelot D, Schumm T, Metz L. Began: Boundary equilibrium generative adversarial networks[J]. arXiv preprint arXiv:1703.10717, 2017.



本文來自雲棲社區合作夥伴“PaperWeekly”,了解相關信息可以關注“PaperWeekly”微信公眾號

最後更新:2017-11-16 15:05:25

  上一篇:go  WiFi萬能鑰匙發布《2017年上半年中國公共WiFi安全報告》:國內風險熱點占比0.81%
  下一篇:go  WiFi萬能鑰匙首席安全官龔蔚:對公共WiFi不必談虎色變