閱讀730 返回首頁    go 阿裏雲 go 技術社區[雲棲]


深度卷積對抗生成網絡(DCGAN)實戰

更多深度文章,請關注雲計算頻道:https://yq.aliyun.com/cloud


生成式對抗網絡(GANs)的概念在四年前由Ian Goodfellow創造。古德費洛(Goodfellow)認為鑒別器(Discriminator)是藝術評論家而相對的藝術家則是生成器(Generator),它們兩個就組成了GAN。藝術評論家(Discriminator)看著一幅圖像,試圖確定它是真的還是偽造的。一個想欺騙藝術評論家的藝術家(Generator)試圖製造一個看起來盡可能真實的偽造圖像。這兩種模式相互戰鬥” ,鑒別器使用生成器的輸出作為訓練數據,並且生成器也可以從鑒別器獲得反饋。在這個過程中,我們想要的模型正在變得更強大,GAN也可以基於一定數量的已知輸入數據(在這種情況下是圖像)生成新的複雜數據。
創建一個GAN可能聽起來很困難,但,在本教程中,我們將使用TensorFlow來構建一個簡單的能夠生成人臉圖像的GAN
1.深度卷積對抗生成網絡(DCGAN)的架構
在本教程中,我們不是試圖模仿簡單的數字數據,而是我們試圖模仿一個圖像,甚至它可以去欺騙一個人。生成器將隨機生成的噪聲向量作為輸入數據,然後使用稱為反卷積的技術將數據轉換為圖像。
鑒別器是經典的卷積神經網絡,其分類真實和假圖像。

75da164107fd4c943e9677b1ba2fe437c8fac198
我們將使用原來的非深度卷積生成對抗網絡的無監督表示學習DCGAN體係結構它由四個卷積層作為鑒別器,四個解卷積層(反卷積層)作為發生器。
2.創建
GitHub上訪問本教程的代碼和Jupyter Notebook。所有的指令都在GitHub倉庫的README文件中。一個幫手指令將自動為你下載CelebA數據集,讓你快速啟動並運行。在這個過程中一定要安裝matplotlib才能看到真正的圖像和另外一定要下載數據集。如果你不想自己安裝它,存儲庫中將包含一個Docker映像。
3.CelebA數據集
CelebFaces數據集包含超過20萬個名人圖像,每個圖像具有40個屬性注釋。由於我們隻是想生成隨機麵的圖像,所以我們將忽略注釋。而且數據集包括超過10,000個不同的身份,這對我們的需要來說是最佳的。

013abdaca76274d70ebbe5ecdea6f5f48da0c4a9
不過,盡管如此我們也要定義一個批量生成的函數。這個函數將加載我們的圖像,並根據我們稍後將要設置的批量大小給我們一個圖像陣列。為了獲得更好的效果,我們將裁剪圖像,以便隻顯示臉部。我們還將圖像歸一化,使得它們的像素值在-0.5+0.5的範圍內。最後,我們打算將圖像縮小到28x28。這困難會使我們失去了一些圖像質量,但它大大減少了訓練時間。
4.定義網絡輸入
在我們開始定義我們的兩個網絡之前,我們首先要定義我們的輸入。我們這樣做是為了不讓雜亂的訓練過程變得比現在更加混亂。在這裏,我們隻是簡單地定義TensorFlow占位符,用於我們真實和虛假的圖像輸入以及為了保存我們的學習率的值。

    inputs_real = tf.placeholder(tf.float32, shape=(None, image_width, image_height, image_channels), name='input_real')
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
   
    return inputs_real, inputs_z, learning_rate

TensorFlow分配變量占位符特別容易。在完成這些之後,我們可以通過稍後指定一個Feed字典來使用我們網絡中的占位符。
5.創建鑒別器網絡(The discriminator network
接著,我們來創建我們最重要的網絡。鑒別器是藝術評論家,試圖區分真實和虛假的圖像。簡單地說,這是一個用於圖像分類的卷積神經網絡。如果你已經有了一些深度學習的經驗,那麼你有可能已經建立了一個非常類似於這個網絡的網絡。

17bcc5f446ad8af89373e516e9d4441536e349c9
定義這個網絡時,我們要使用一個TensorFlow變量作用域。這有助於我們稍後的訓練過程,所以我們可以重複使用我們鑒別器和發生器的變量名。

def discriminator(images, reuse=False):
    """
    Create the discriminator network
    """
   
    with tf.variable_scope('discriminator', reuse=reuse):
        # … the model

鑒別器網絡由三個卷積層組成,相對於原始架構中的四個卷積層。我們將刪除最後一層來簡化模型。通過這種方式,訓練會進行得更快,而且不會損失太多的質量。對於網絡中的每一層,我們要進行卷積,然後我們還要進行批標準化,以使網絡更快,更準確,接著,我們要進行Leaky RELU進一步加快訓練。最後,我們將最後一層的輸出變平,並使用sigmoid激活函數來獲得分類。這樣我們就會獲得一個可以預測圖像是否是真實的網絡。
6.發生器網絡(The generator network
發生器是以另一種方式存在於GAN中:試圖欺騙鑒別器的是藝術家。發生器利用去卷積層(deconvolutional),它們與卷積圖層完全相反:除了將圖像轉換為簡單的數值數據(如分類)之外,我們還將執行反卷積以將數字數據轉換為圖像,而不是執行卷積。正如我們在設置鑒別器網絡中所做的那樣,我們也將其設置在一個可變範圍內。

344297c3061f8416e90b8fc5036d9154382b1d73
首先,我們接受我們的輸入,稱為Z,並將其輸入到我們的第一個解卷積層。每個解卷積層執行解卷積,然後執行批量歸一化和 Leaky ReLu。然後,我們返回tanh激活函數。
注意:先訓練!
在我們真正開始訓練過程之前,我們需要做一些其他的事情。首先,我們需要定義所有幫助我們計算損失的變量。其次,我們需要定義我們的優化功能。最後,我們將建立一個小函數來輸出生成的圖像,然後訓練網絡。
7.損失函數
我們需要定義三個損失函數,而不是僅具有單個損失函數:發生器的損失函數,使用真實圖像時鑒別器的損失函數,以及使用假圖像時鑒別器的損失函數。假圖像和真實圖像損失的總和理應是整體鑒別器損失。
首先,我們先定義我們對真實圖像的損失。為此,我們在處理真實圖像時要傳遞鑒別器的輸出,並將其與標簽全部進行比較。我們在這裏使用一種稱為標簽平滑的技術,通過將0.9乘以1來幫助我們的網絡更加準確。
然後,我們為我們的假圖像定義損失。這次我們在處理偽造的圖像時將鑒別器的輸出傳遞給我們的標簽,如果這些標簽都是0,這意味著它們是假的。
最後,對於發生器定義損失器。
8.優化和可視化
在優化的步驟中,我們正在尋找所有可以通過使用tf.trainable_variables函數進行訓練的變量。既然我們之前使用了變量作用域,我們可以非常舒適地檢索這些變量。然後我們使用Adam優化器來幫助我們減少損失。

def model_opt(d_loss, g_loss, learning_rate, beta1):
    """
    Get optimization operations
    """
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if var.name.startswith('discriminator')]
    g_vars = [var for var in t_vars if var.name.startswith('generator')]

在我們準備的最後一步中,我們將編寫一個小段程序,使用matplotlib庫在筆記本上顯示生成的圖像。
9.訓練
我們正在進行我們的最後一步!現在,我們隻獲取我們之前定義的輸入,損失和優化器,調用一個TensorFlow會話並運行批處理。每400一個批次,我們通過顯示生成的圖像和生成器以及鑒別器的損失來輸出當前的進度。現在向後看,看到臉部緩緩而穩定地出現。根據你的設置,此進度可能需要一個小時或更長時間。

ae7ce03cb62aa22b8e351b85517e315487e46820
10.結論
恭喜你!你現在知道GAN做什麼,甚至知道如何使用它們生成人臉圖像。這隻是GAN的冰山一角,GAN還有很多其他的應用。
例如:密歇根大學和德國馬克斯普朗克研究所的研究人員使用GAN從文本中生成圖像。根據論文描述,他們能夠產生非常真實的花鳥。這可以擴展到一些其他領域,比如警察素描或者平麵設計。伯克利的研究人員也設法創建了一個GAN,增強了模煳的圖像,甚至重建了損壞的圖像數據。
總之,GAN是非常強大的。

作者: Dominic Monn 

Dominic Monn目前是瑞士蘇黎世NVIDIA公司的深度實習生

本文由北郵@愛可可-愛生老師推薦,阿裏雲雲棲社組織翻譯。

文章原標題《Deep convolutional generative adversarial networks with TensorFlow》,作者:Dominic Monn譯者:虎說八道,審閱:

文章為簡譯,更為詳細的內容,請查看原文

最後更新:2017-11-06 10:33:57

  上一篇:go  MaxCompute 2.0: Evolution of Alibaba's Big Data Service
  下一篇:go  的點點滴滴多多多多多多