NVIDIA新作解讀:用GAN生成前所未有的高清圖像(附PyTorch複現) | PaperDaily #15
今天要介紹的文章是 NVIDIA 投稿 ICLR 2018 的一篇文章,Progressive Growing of GANs for Improved Quality, Stability, and Variation[1],姑且稱它為 PG-GAN。
從行文可以看出文章是臨時趕出來的,畢竟這麼大的實驗,用 P100 都要跑 20 天,更不用說調參時間了,不過人家在 NVIDIA,不缺卡。作者放出了基於 Lasagna 的代碼,今天我也會簡單解讀一下代碼。另外,我也在用 PyTorch 做複現。
在 PG-GAN 出來以前,訓練高分辨率圖像生成的 GAN 方法主要就是 LAPGAN[2] 和 BEGAN[6]。後者主要是針對人臉的,生成的人臉逼真而不會是鬼臉。
這裏也提一下,生成鬼臉的原因是 Discriminator 不再更新,它不能再給予 Generator 其他指導,Generator 找到了一種騙過 Discriminator 的方法,也就是生成鬼臉,而且很大可能會 mode collapse。
下圖是我用 PyTorch 做的 BEGAN 複現,當時沒有跑很高的分辨率,但是效果確實比其他 GAN 好基本沒有鬼臉。
PG-GAN 能夠穩定地訓練生成高分辨率的 GAN。我們來看一下 PG-GAN 跟別的 GAN 不同在哪裏。
1. 訓練方式
作者采用 progressive growing 的訓練方式,先訓一個小分辨率的圖像生成,訓好了之後再逐步過渡到更高分辨率的圖像。然後穩定訓練當前分辨率,再逐步過渡到下一個更高的分辨率。
如上圖所示。更具體點來說,當處於 fade in(或者說 progressive growing)階段的時候,上一分辨率(4x4)會通過 resize+conv 操作得到跟下一分辨率(8x8)同樣大小的輸出,然後兩部分做加權,再通過 to_rgb 操作得到最終的輸出。
這樣做的一個好處是它可以充分利用上個分辨率訓練的結果,通過緩慢的過渡(w 逐漸增大),使得訓練生成下一分辨率的網絡更加穩定。
上麵展示的是 Generator 的 growing 階段。下圖是 Discriminator 的 growing,它跟 Generator 的類似,差別在於一個是上采樣,一個是下采樣。這裏就不再贅述。
不難想象,網絡在 growing 的時候,如果不引入 progressive (fade in),那麼有可能因為比較差的初始化,導致原來訓練的進度功虧一簣,模型不得不從新開始學習,如此一來就沒有充分利用以前學習的成果,甚至還可能誤導。我們知道 GAN的訓練不穩定,這樣的突變有時候是致命的。所以 fade in 對訓練的穩定性來說至關重要。
說到 growing 的訓練方式,我們很容易想到 autoencoder 也有一種類似的訓練方式:先訓各一層的 encoder 和 decoder,訓好了以後再過渡到訓練各兩層的 encoder 和 decoder,這樣的好處是避免梯度消失,導致離 loss 太遠的層更新不夠充分。PG-GAN 的做法可以說是這種 autoencoder 訓練方式在 GAN 訓練上的應用。
此外,訓練 GAN 生成高分辨率圖像,還有一種方法,叫 LAPGAN[2]。LAPGAN 借助 CGAN,高分辨率圖像的生成是以低分辨率圖像作為條件去生成殘差,然後低分辨率圖上采樣跟殘差求和得到高分辨率圖,通過不斷堆疊 CGAN 得到我們想要的分辨率。
LAPGAN 是多個 CGAN 堆疊一起訓練,當然可以拆分成分階段訓練,但是它們本質上是不同的,LAPGAN 學的是殘差,而 PG-GAN 存在 stabilize 訓練階段,學的不是殘差,而直接是圖像。
作者在代碼中設計了一個 LODSelectLayer 來實現 progressive growing。對於 Generator,每一層插入一個 LODSelectLayer,它實際上就是一個輸出分支,實現在特定層的輸出。
從代碼來看,作者應該是這樣訓練的(參見這裏的 train_gan 函數),先構建 4x4 分辨率的網絡,訓練,然後把網絡存出去。再構建 8x8 分辨率的網絡,導入原來 4x4 的參數,然後訓 fade in,再訓 stabilize,再存出去。我在複現的時候,根據文章的意思,修改了 LODSelectLayer 層,因為 PyTorch 是動態圖,能夠很方便地寫 if-else 邏輯語句。
借助這種 growing 的方式,PG-GAN 的效果超級好。另外,我認為這種 progressive growing 的方法比較適合 GAN 的訓練,GAN 訓練不穩定可以通過 growing 的方式可以緩解。
不隻是在噪聲生成圖像的任務中可以這麼做,在其他用到 GAN 的任務中都可以引入這種訓練方式。我打算將 progressive growing 引入到 CycleGAN 中,希望能夠得到更好的結果。
2. 增加生成多樣性
增加生成樣本的多樣性有兩種可行的方法:通過 loss 讓網絡自己調整、通過設計判別多樣性的特征人為引導。
WGAN 屬於前者,它采用更好的分布距離的估計(Wasserstein distance)。模型收斂意味著生成的分布和真實分布一致,能夠有多樣性的保證。PG-GAN 則屬於後者。
作者沿用 improved GAN 的思路,通過人為地給 Discriminator 構造判別多樣性的特征來引導 Generator 生成更多樣的樣本。Discriminator 能探測到 mode collapse 是否產生了,一旦產生,Generator 的 loss 就會增大,通過優化 Generator 就會往遠離 mode collapse 的方向走,而不是一頭栽進坑裏。
Improved GAN 引入了 minibatch discrimination 層,構造一個 minibatch 內的多樣性衡量指標。它引入了新的參數。
而 PG-GAN 不引入新的參數,利用特征的標準差作為衡量標準。
這裏囉嗦地說明上麵那張圖做了什麼。我們有 N 個樣本的 feature maps(為了畫圖方便,不妨假設每個樣本隻有一個 feature map),我們對每個空間位置求標準差,用 numpy 的 std 函數來說就是沿著樣本的維度求 std。這樣就得到一張新的 feature map(如果樣本的 feature map 不止一個,那麼這樣構造得到的 feature map 數量應該是一致的),接著 feature map 求平均得到一個數。
這個過程簡單來說就是求 mean std,作者把這個數複製成一張 feature map 的大小,跟原來的 feature map 拚在一起送給 Discriminator。
從作者放出來的代碼來看,這對應 averaging=“all”的情況。作者還嚐試了其他的統計量:“spatial”,“gpool”,“flat”等。它們的主要差別在於沿著哪些維度求標準差。至於它們的作用,等我的代碼複現完成了會做一個測試。估計作者調參發現“all”的效果最好。
3. Normalization
從 DCGAN[3]開始,GAN 的網絡使用 batch (or instance) normalization 幾乎成為慣例。使用 batch norm 可以增加訓練的穩定性,大大減少了中途崩掉的情況。作者采用了兩種新的 normalization 方法,不引入新的參數(不引入新的參數似乎是 PG-GAN 各種 tricks 的一個賣點)。
第一種 normalization 方法叫 pixel norm,它是 local response normalization 的變種。Pixel norm 沿著 channel 維度做歸一化,這樣歸一化的一個好處在於,feature map 的每個位置都具有單位長度。這個歸一化策略與作者設計的 Generator 輸出有較大關係,注意到 Generator 的輸出層並沒有 Tanh 或者 Sigmoid 激活函數,後麵我們針對這個問題進行探討。
第二種 normalization 方法跟凱明大神的初始化方法[4]掛鉤。He 的初始化方法能夠確保網絡初始化的時候,隨機初始化的參數不會大幅度地改變輸入信號的強度。
根據這個式子,我們可以推導出網絡每一層的參數應該怎樣初始化。可以參考 PyTorch 提供的接口。
作者走得比這個要遠一點,他不隻是初始化的時候對參數做了調整,而是動態調整。初始化采用標準高斯分布,但是每次迭代都會對 weights 按照上麵的式子做歸一化。作者 argue 這樣的歸一化的好處在於它不用再擔心參數的 scale 問題,起到均衡學習率的作用(euqalized learning rate)。
4. 有針對性地給樣本加噪聲
通過給真實樣本加噪聲能夠起到均衡 Generator 和 Discriminator 的作用,起到緩解 mode collapse 的作用,這一點在 WGAN 的前傳中就已經提到[5]。盡管使用 LSGAN 會比原始的 GAN 更容易訓練,然而它在 Discriminator 的輸出接近 1 的適合,梯度就消失,不能給 Generator 起到引導作用。
針對 D 趨近 1 的這種特性,作者提出了下麵這種添加噪聲的方式:
其中,分別為第 t 次迭代判別器輸出的修正值、第 t-1 次迭代真樣本的判別器輸出。
從式子可以看出,當真樣本的判別器輸出越接近 1 的時候,噪聲強度就越大,而輸出太小(<=0.5)的時候,不引入噪聲,這是因為 0.5 是 LSGAN 收斂時,D 的合理輸出(無法判斷真假樣本),而小於 0.5 意味著 D 的能力太弱。
文章還有其他很多 tricks,有些 tricks 不是作者提出的,如 Layer norm,還有一些比較細微的 tricks,比如每個分辨率訓練好做 sample 的時候學習率怎麼 decay,每個分辨率的訓練迭代多少次等等,我們就不再詳細展開。具體可以參見官方代碼,也可以看我複現的代碼。
目前複現的結果還在跑,現在訓練到了 16x16 分辨率的 fade in 階段,放一張當前的結果圖,4 個方格的每個方格左邊 4 列是生成的圖,右邊 4 列是真實樣本。現在還處於訓練早期,分辨率太低,過幾天看一下高分辨率的結果。
5. 相關代碼
官方 Lasagna 代碼:
https://github.com/tkarras/progressive_growing_of_gans
作者 PyTorch 複現:
https://github.com/github-pengge/PyTorch-progressive_growing_of_gans
6. 參考文獻
[1]. Karras T, Aila T, Laine S, et al. Progressive Growing of GANs for Improved Quality, Stability, and Variation[J]. arXiv preprint arXiv:1710.10196, 2017.
[2]. Denton E L, Chintala S, Fergus R. Deep Generative Image Models using a Laplacian Pyramid of Adversarial Networks[C]//Advances in neural information processing systems. 2015: 1486-1494.
[3]. Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.
[4]. He K, Zhang X, Ren S, et al. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification[C]//Proceedings of the IEEE international conference on computer vision. 2015: 1026-1034.
[5]. Arjovsky M, Bottou L. Towards principled methods for training generative adversarial networks[J]. arXiv preprint arXiv:1701.04862, 2017.
[6]. Berthelot D, Schumm T, Metz L. Began: Boundary equilibrium generative adversarial networks[J]. arXiv preprint arXiv:1703.10717, 2017.
本文來自雲棲社區合作夥伴“PaperWeekly”,了解相關信息可以關注“PaperWeekly”微信公眾號
最後更新:2017-11-16 15:05:25