894
京東網上商城
深度學習界明星:生成對抗網絡與Improving GAN
2014年,深度學習三巨頭之一IanGoodfellow提出了生成對抗網絡(Generative Adversarial Networks, GANs)這一概念,剛開始並沒有引起轟動,直到2016年,學界、業界對它的興趣如“井噴”一樣爆發,多篇重磅文章陸續發表。2016年12月NIPS大會上,Goodfellow做了關於GANs的專題報告,使得GANs成為了當今最熱門的研究領域之一,本文將介紹如今深度學習界的明星——生成對抗網絡。
1何為生成對抗網絡
生成對抗網絡,根據它的名字,可以推斷這個網絡由兩部分組成:第一部分是生成,第二部分是對抗。這個網絡的第一部分是生成模型,就像之前介紹的自動編碼器的解碼部分;第二部分是對抗模型,嚴格來說它是一個判斷真假圖片的判別器。生成對抗網絡最大的創新在此,這也是生成對抗網絡與自動編碼器最大的區別。簡單來說,生成對抗網絡就是讓兩個網絡相互競爭,通過生成網絡來生成假的數據,對抗網絡通過判別器判別真偽,最後希望生成網絡生成的數據能夠以假亂真騙過判別器。過程如圖1所示。
圖1 生成對抗網絡生成數據過程
下麵依次介紹生成模型和對抗模型。
1. 生成模型
首先看看生成模型,前一節自動編碼器其實已經給出了一般的生成模型。
在生成對抗網絡中,不再是將圖片輸入編碼器得到隱含向量然後生成圖片,而是隨機初始化一個隱含向量,根據變分自動編碼器的特點,初始化一個正態分布的隱含向量,通過類似解碼的過程,將它映射到一個更高的維度,最後生成一個與輸入數據相似的數據,這就是假的圖片。這時自動編碼器是通過對比兩張圖片之間每個像素點的差異計算損失函數的,而生成對抗網絡會通過對抗過程來計算出這個損失函數,如圖2所示。
圖2 生成模型
2. 對抗模型
重點來介紹對抗過程,這個過程是生成對抗網絡相對於之前的生成模型如自動編碼器等最大的創新。
對抗過程簡單來說就是一個判斷真假的判別器,相當於一個二分類問題,輸入一張真的圖片希望判別器輸出的結果是1,輸入一張假的圖片希望判別器輸出的結果是0。
這跟原圖片的label 沒有關係,不管原圖片到底是一個多少類別的圖片,它們都統一稱為真的圖片,輸出的label 是1,則表示是真實的;而生成圖片的label 是0,則表示是假的。
在訓練的時候,先訓練判別器,將假的數據和真的數據都輸入給判別模型,這個時候優化這個判別模型,希望它能夠正確地判斷出真的數據和假的數據,這樣就能夠得到一個比較好的判別器。
然後開始訓練生成器,希望它生成的假的數據能夠騙過現在這個比較好的判別器。
具體做法就是將判別器的參數固定,通過反向傳播優化生成器的參數,希望生成器得到的數據在經過判別器之後得到的結果能盡可能地接近1,這時隻需要調整一下損失函數就可以了,之前在優化判別器的時候損失函數是讓假的數據盡可能接近0,而現在訓練生成器的損失函數是讓假的數據盡可能接近1。
這其實就是一個簡單的二分類問題,這個問題可以用前麵介紹過的很多方法去處理,比如Logistic 回歸、多層感知器、卷積神經網絡、循環神經網絡等。
上麵是生成對抗網絡的簡單解釋,可以通過代碼更清晰地展示整個過程。
跟自動編碼器一樣,先使用簡單的多層感知器來實現:
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.dis = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
x = self.dis(x)
return x
上麵是判別器的結構,中間使用了斜率設為0.2 的LeakyReLU 激活函數,最後需要使用nn.Sigmoid() 將結果映射到0 s 1 之間概率進行真假的二分類。這裏之所以用LeakyReLU 激活函數而不使用ReLU 激活函數,是因為經過實驗,LeakyReLU 的表現更好。
class generator(nn.Module):
def __init__(self, input_size):
super(generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(input_size, 256),
nn.ReLU(True),
nn.Linear(256, 256),
nn.ReLU(True),
nn.Linear(256, 784),
nn.Tanh()
)
def forward(self, x):
x = self.gen(x)
return x
這就是生成器的結構,跟自動編碼器中的解碼器是類似的,最後需要使用nn.Tanh(),將數據分布到-1 ~1 之間,這是因為輸入的圖片會規範化到-1 ~1之間。
接著需要定義損失函數和優化函數:
criterion = nn.BCELoss() # Binary Cross Entropy
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
這裏使用二分類的損失函數nn.BCELoss(),使用Adam 優化函數,學習率設置為0.0003。
接著是最為重要的訓練過程,這個過程分為兩個部分:一個是判別器的訓練,一個是生成器的訓練。
首先來看看判別器的訓練。
img = img.view(num_img, -1)
real_img = Variable(img).cuda()
real_label = Variable(torch.ones(num_img)).cuda()
fake_label = Variable(torch.zeros(num_img)).cuda()
# compute loss of real_img
real_out = D(real_img)
d_loss_real = criterion(real_out, real_label)
real_scores = real_out
# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()
fake_img = G(z)
fake_out = D(fake_img)
d_loss_fake = criterion(fake_out, fake_label)
fake_scores = fake_out
# bp and optimize
d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
開始需要自己創建label,真實的數據是1,生成的假的數據是0,然後將真實的數據輸入判別器得到loss,將假的數據輸入判別器得到loss,將這兩個loss 加起來得到總的loss,然後反向傳播去更新參數就能夠得到一個優化好的判別器。
接下來是生成模型的訓練:
# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到隨機噪聲
fake_img = G(z) # 生成假的圖片
output = D(fake_img) # 經過判別器得到結果
g_loss = criterion(output, real_label) # 得到假的圖片與真實圖片label的loss
# bp and optimize
g_optimizer.zero_grad() # 歸0梯度
g_loss.backward() # 反向傳播
g_optimizer.step() # 更新生成網絡的參數
一個隨機隱含向量通過生成網絡得到了一個假的數據,然後希望假的數據經過判別模型後盡可能和真實label 接近,通過g_loss = criterion(output, real_label)實現,然後反向傳播去優化生成器的參數,在這個過程中,判別器的參數不再發生變化,否則生成器永遠無法騙過優化的判別器。
除了使用簡單的多層感知器外,也可以在生成模型和對抗模型中使用更加複雜的卷積神經網絡,定義十分簡單。
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28
nn.LeakyReLU(0.2, True),
nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14
)
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14
nn.LeakyReLU(0.2, True),
nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7
)
self.fc = nn.Sequential(
nn.Linear(6477, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, 1),
nn.Sigmoid()
)
def forward(self, x):
'''
x: batch, width, height, channel=1
'''
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
class generator(nn.Module):
def __init__(self, input_size, num_feature):
super(generator, self).__init__()
self.fc = nn.Linear(input_size, num_feature) # batch, 3136=1x56x56
self.br = nn.Sequential(
nn.BatchNorm2d(1),
nn.ReLU(True)
)
self.downsample1 = nn.Sequential(
nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, 56, 56
nn.BatchNorm2d(50),
nn.ReLU(True)
)
self.downsample2 = nn.Sequential(
nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, 56, 56
nn.BatchNorm2d(25),
nn.ReLU(True)
)
self.downsample3 = nn.Sequential(
nn.Conv2d(25, 1, 2, stride=2), # batch, 1, 28, 28
nn.Tanh()
)
def forward(self, x):
x = self.fc(x)
x = x.view(x.size(0), 1, 56, 56)
x = self.br(x)
x = self.downsample1(x)
x = self.downsample2(x)
x = self.downsample3(x)
return x
圖3 左邊是多層感知器的生成對抗網絡,右邊是卷積生成對抗網絡,右邊的圖片比左邊的圖片噪聲明顯更少。在卷積神經網絡裏引入了批標準化(Batchnormalization)來穩定訓練,同時使用了LeakyReLU 和平均池化來進行訓練。生成對抗網絡的訓練其實是很困難的,因為這是兩個對偶網絡在相互學習,所以需要增加一些訓練技巧才能使訓練更加穩定。
圖3生成對抗網絡對比結果
以上介紹了生成對抗網絡的簡單原理和訓練流程,但是對生成對抗網絡而言,它其實並沒有真正地學習到它要表示的物體,通過對抗的過程,它隻是生成了一張盡可能真的圖片,這就意味著沒辦法決定用哪種噪聲能夠生成想要的圖片,除非把初始分布都試一遍。所以在生成對抗網絡提出之後,有很多基於標準生成對抗網絡的變式來解決各種各樣的問題。
2 Improving GAN
這一節將介紹改善的生成對抗網絡,因為生成對抗網絡存在很多問題,所以人們研究能否通過改善網絡結構或者損害函數來解決這些問題。
1 Wasserstein GAN
Wasserstein GAN 是GAN 的一種變式,我們知道GAN 的訓練是非常麻煩的,需要很多訓練技巧,而且在不同的數據集上,由於數據的分布會發生變化,也需要重新調整參數,不僅需要小心地平衡生成器和判別器的訓練進程,同時生成的樣本還缺乏多樣性。除此之外最大的問題是沒辦法衡量這個生成器到底好不好,因為沒辦法通過判別器的loss 去判斷這個事情。雖然DC GAN 依靠對生成器和判別器的結構進行枚舉,最終找到了一個比較好的網絡設置,但還是沒有從根本上解決訓練的問題。
WGAN 的出現,徹底解決了下麵這些難點:
(1)徹底解決了訓練不穩定的問題,不再需要設計參數去平衡判別器和生成器;
(2)基本解決了collapse mode 的問題,確保了生成樣本的多樣性;
(3)訓練中有一個向交叉熵、準確率的數值指標來衡量訓練的進程,數值越小代表GAN 訓練得越好,同時也就代表著生成的圖片質量越高;
(4)不需要精心設計網絡結構,用簡單的多層感知器就能夠取得比較好的效果。
下麵先介紹為什麼GAN 會有這些缺點,然後解釋WGAN是通過什麼辦法解決這些問題的。
① GAN 的局限性
從式(1)我們知道原始的GAN 是通過最優判別器下的JS Divergence 來衡量兩種分布之間的差異的,而且最優判別器下JS Divergence 越小,就說明兩種分布越接近,但是JS Divergence 有一個嚴重的問題,那就是如果兩種分布完全沒有重疊部分,或者說重疊部分可忽略,那麼JS Divergence 將恒等於常數log2。換句話說,就算兩種分布很接近,但是隻要它們沒有重疊,那麼JS Divergence 就是一個常數,這就使得網絡沒辦法通過這個損失函數去學習,因為它沒辦法知道它是否做得好,這就會導致梯度消失,同時這也使得我們沒有辦法衡量這兩種分布到底有多靠近。
而真實分布與生成的分布沒有重疊部分的概率有多大呢?其實是非常大的,直觀來講,真實分布是一個高維分布,而生成的分布來自於一個低維分布,所以其實很有可能生成分布和真實分布之間就沒有重疊的部分。除此之外,不可能真正去計算兩個分布,隻能近似取樣,所以也導致了兩種分布沒有重疊部分。如果判別器訓練得太好,那麼生成的分布和原來分布基本沒有重疊部分,這就導致了梯度消失;如果判別器訓練得不好,這樣生成器的梯度又不準,就會出現錯誤的優化方向。如果要使得GAN 能夠完美地收斂,那麼需要判別器的訓練不好也不壞,而這個度是很難把握的,況且這還依賴數據的分布等條件,所以GAN 才這麼難訓練。
②Wasserstein 距離
既然GAN 存在的問題都是由於JS Divergence 引起的,那麼能不能換一種度量方式去衡量兩種分布之間的差異,而不使用JS Divergence?答案是肯定的,這就是WGAN中提出的解決辦法。
首先介紹一種新的度量方式去度量兩種分布之間的差異——Wasserstein 距離,也稱為Earth Mover 距離,定義如下:
看上去可能比較複雜,數學解釋如下:對於兩種分布Pr 和Pg,它們的聯合分布是II(Pr,Pg),換句話說II(Pr,Pg) 中每一個聯合分布的邊緣分布就是Pr 或者Pg。那麼對每一個聯合分布而言,從裏麵取樣x 和y,並計算x 和y 的距離,然後取遍所有的x 和y 計算一下期望,接著取這些期望裏麵最小的作為W 距離的定義。
如果上麵的解釋不夠清楚,也可以通俗地解釋,因為它還有一個別名叫Earth mover距離,也就是推土機距離,這是什麼意思呢?可以把兩種分布想象成兩堆土,然後想想如何用推土機將一種分布變成另外一種分布的樣子,會有很多種移動方案,裏麵最小消耗的那種方案就是最優的方案,也就是這個距離的定義。
W 距離與JS Divergence 相比有什麼好處呢?最大的好處就是不管兩種分布是否有重疊,它都是連續變換的而不是突變的,可以用下麵這個例子來說明一下,如圖4所示。
圖4 W 距離例子
通過上麵這個演示可以發現,雖然兩種分布更接近,但JS Divergence 仍然是log2,W 距離就能夠連續而有效地衡量兩種分布之間的差異。
③WGAN
W 距離有很好的優越性,把它拿來作為兩種分布的度量優化生成器,但是W 距離裏麵有一個是沒辦法求解的。作者Martin 在論文附錄裏麵通過定理將這個問題轉變成了一個新的問題,有著如下形式:
這裏引入了一個新的概念——Lipschitz 連續。如果函數f 滿足Lipschitz 連續條件,那麼它就滿足下麵的式子:
我們不希望函數的變化太快,希望函數f 變化能比較平緩。
也就是說構建一個神經網絡D 作為判別器,希望D 輸出的變化比較平緩,在實際計算中限製D 中的參數大小不超過某個範圍,這樣就使得關於輸入的樣本,D 的輸出變化基本不會超過某個範圍,所以就能夠基本滿足Lipschitz 連續條件。
盡可能取到最大,同時D 還要滿足Lipschitz 連續條件,得到的L 可以近似為真實分布和生成分布的Wasserstein 距離。原始的GAN 做的是二分類的任務,也就是對於真假圖片進行二分類,而WGAN 做的是回歸問題,相當於近似擬合Wasserstein 距離。
最後優化生成器的時候希望最小化L,這時候需要滿足Lipschitz 連續條件,所以需要做權重的裁剪,由於W 距離的優越性,不再需要擔心梯度消失的問題,這樣就能夠得到WGAN 的整個訓練過程。
總結一下,WGAN 與原始GAN 相比,隻改了以下四點:
(1)判別器最後一層去掉sigmoid;
(2)生成器和判別器的loss 不取log;
(3)每次更新判別器的參數之後把它們的絕對值裁剪到不超過一個固定常數的數;
(4)不要用基於動量的優化算法(比如momentuem 和Adam),推薦使用RMSProp。
前三點都是從理論分析得到的結果,第(4)點是作者從實驗中發現的。對於WGAN,論文作者做了不少實驗,得到了幾個結論:第一,WGAN 如果使用類似DCGAN 的結構,那麼和DCGAN 生成的圖片差不多,但是WGAN 的優勢就在於不用DCGAN 的結構,也能生成效果比較好的圖片,但是把DCGAN 的Batch Normalization 拿掉的話,DCGAN 就不能生成圖片了;第二,WGAN 和原始的GAN 都是用多層全連接網絡的話,WGAN 生成的圖片質量會變得差一些,但是原始的GAN 不僅質量很差,還有多樣性不足的問題。
2 Improving WGAN
WGAN 的提出成功地解決了GAN 的很多問題,最後需要滿足一階Lipschitz 連續性條件,所以在訓練的時候加了一個限製——權重裁剪。
然而權重的裁剪隻是一種簡單的做法,不是最好的做法,所以隨後有人提出了一些新的辦法來解決這個問題。
首先提出一個定理:一個可微函數如果滿足1 階Lipschitz 連續,等價於它的梯度範數處小於1。用式子來表示就是:
不需要在整個分布上都滿足Lipschitz 條件,隻需要沿著一些直線上的點滿足這些,結果就已經很好了,同時在實際中采用的策略也不是取max,因為不希望太小,所以做的是最小化
,最後改進的WGAN 就是:
改進後的WGAN 和改進前的WGAN 相比,訓練更加穩定,生成的圖片效果也更好。
以上內容節選自《深度學習入門之PyTorch》,點此鏈接可在博文視點官網查看此書。
想及時獲得更多精彩文章,可在微信中搜索“博文視點”或者掃描下方二維碼並關注。
最後更新:2017-11-13 15:34:11