
作者 | CDA數(shù)據(jù)分析師
生成對(duì)抗網(wǎng)絡(luò)(GAN)是一種用于訓(xùn)練深度卷積模型以生成合成圖像的體系結(jié)構(gòu)。
盡管非常有效,但默認(rèn)GAN無(wú)法控制生成的圖像類(lèi)型。信息最大化GAN(簡(jiǎn)稱(chēng)InfoGAN)是GAN架構(gòu)的擴(kuò)展,它引入了架構(gòu)自動(dòng)學(xué)習(xí)的控制變量,并允許控制生成的圖像,例如在生成圖像樣式的情況下,厚度和類(lèi)型手寫(xiě)的數(shù)字。
在本教程中,您將了解如何從頭開(kāi)始實(shí)現(xiàn)信息最大化生成對(duì)抗網(wǎng)絡(luò)模型。
完成本教程后,您將了解:
讓我們開(kāi)始吧。
教程概述
本教程分為四個(gè)部分; 他們是:
什么是最大化GAN的信息
Generative Adversarial Network(簡(jiǎn)稱(chēng)GAN)是一種用于訓(xùn)練生成模型的體系結(jié)構(gòu),例如用于生成合成圖像的模型。
它涉及同時(shí)訓(xùn)練生成器模型以生成具有鑒別器模型的圖像,該模型學(xué)習(xí)將圖像分類(lèi)為真實(shí)的(來(lái)自訓(xùn)練數(shù)據(jù)集)或假的(生成的)。這兩個(gè)模型在零和游戲中競(jìng)爭(zhēng),使得訓(xùn)練過(guò)程的收斂涉及在生成器生成令人信服的圖像的技能與能夠檢測(cè)它們的鑒別器之間找到平衡。
生成器模型將來(lái)自潛在空間的隨機(jī)點(diǎn)作為輸入,通常為50到100個(gè)隨機(jī)高斯變量。生成器通過(guò)訓(xùn)練對(duì)潛在空間中的點(diǎn)應(yīng)用獨(dú)特的含義,并將點(diǎn)映射到特定的輸出合成圖像。這意味著雖然潛在空間由生成器模型構(gòu)成,但是無(wú)法控制生成的圖像。
GAN公式使用簡(jiǎn)單的因子連續(xù)輸入噪聲向量z,而對(duì)發(fā)生器可以使用該噪聲的方式?jīng)]有限制。結(jié)果,發(fā)生器可能以高度糾纏的方式使用噪聲,導(dǎo)致z的各個(gè)維度不對(duì)應(yīng)于數(shù)據(jù)的語(yǔ)義特征。
可以探索潛在空間并比較生成的圖像,以試圖理解生成器模型已經(jīng)學(xué)習(xí)的映射?;蛘撸梢岳缤ㄟ^(guò)類(lèi)標(biāo)簽來(lái)調(diào)節(jié)生成過(guò)程,以便可以按需創(chuàng)建特定類(lèi)型的圖像。這是條件生成對(duì)抗網(wǎng)絡(luò)的基礎(chǔ),簡(jiǎn)稱(chēng)CGAN或cGAN。
另一種方法是提供控制變量作為發(fā)電機(jī)的輸入,以及潛在空間中的點(diǎn)(噪聲)。可以訓(xùn)練發(fā)生器以使用控制變量來(lái)影響所生成圖像的特定屬性。這是信息最大化生成對(duì)抗網(wǎng)絡(luò)(簡(jiǎn)稱(chēng)InfoGAN)所采用的方法。
InfoGAN,生成對(duì)抗網(wǎng)絡(luò)的信息理論擴(kuò)展,能夠以完全無(wú)監(jiān)督的方式學(xué)習(xí)解纏結(jié)的表示。
在訓(xùn)練過(guò)程中由發(fā)生器學(xué)習(xí)的結(jié)構(gòu)化映射有些隨機(jī)。雖然生成器模型學(xué)習(xí)在潛在空間中空間分離生成圖像的屬性,但是沒(méi)有控制。這些屬性糾纏在一起。InfoGAN的動(dòng)機(jī)是希望解開(kāi)生成圖像的屬性。
例如,在面部的情況下,可以解開(kāi)和控制生成面部的特性,例如面部的形狀,頭發(fā)顏色,發(fā)型等。
例如,對(duì)于面部的數(shù)據(jù)集,有用的解開(kāi)的表示可以為以下屬性中的每一個(gè)分配一組單獨(dú)的維度:面部表情,眼睛顏色,發(fā)型,眼鏡的存在或不存在,以及相應(yīng)人的身份。
控制變量與噪聲一起提供作為發(fā)電機(jī)的輸入,并且通過(guò)互信息丟失功能訓(xùn)練模型。
......我們對(duì)生成對(duì)抗性網(wǎng)絡(luò)目標(biāo)進(jìn)行了簡(jiǎn)單的修改,鼓勵(lì)它學(xué)習(xí)可解釋和有意義的表達(dá)。我們通過(guò)最大化GAN噪聲變量的固定小子集與觀測(cè)值之間的互信息來(lái)實(shí)現(xiàn)這一點(diǎn),結(jié)果證明是相對(duì)簡(jiǎn)單的。
相互信息是指在給定另一個(gè)變量的情況下獲得的關(guān)于一個(gè)變量的信息量。在這種情況下,我們感興趣的是有關(guān)使用噪聲和控制變量生成的圖像的控制變量的信息。
在信息論中,X和Y之間的互信息I(X; Y)測(cè)量從隨機(jī)變量Y的知識(shí)中學(xué)習(xí)的關(guān)于另一個(gè)隨機(jī)變量X 的“ 信息量 ”。
相互信息(MI)被計(jì)算為圖像的條件熵(由發(fā)生器(G)從噪聲(z)和控制變量(c)創(chuàng)建),給定控制變量(c)從邊際熵減去控制變量(c); 例如:
在實(shí)踐中,計(jì)算真實(shí)的互信息通常是難以處理的,盡管本文采用了簡(jiǎn)化,稱(chēng)為變分信息最大化,并且控制代碼的熵保持不變。
通過(guò)使用稱(chēng)為Q或輔助模型的新模型來(lái)實(shí)現(xiàn)通過(guò)互信息訓(xùn)練發(fā)電機(jī)。新模型與用于解釋輸入圖像的鑒別器模型共享所有相同的權(quán)重,但與預(yù)測(cè)圖像是真實(shí)還是假的鑒別器模型不同,輔助模型預(yù)測(cè)用于生成圖像的控制代碼。
兩種模型都用于更新生成器模型,首先是為了提高生成愚弄鑒別器模型的圖像的可能性,其次是改善用于生成圖像的控制代碼和輔助模型對(duì)控制代碼的預(yù)測(cè)之間的互信息。
結(jié)果是生成器模型通過(guò)互信息丟失而正規(guī)化,使得控制代碼捕獲所生成圖像的顯著特性,并且反過(guò)來(lái)可以用于控制圖像生成過(guò)程。
每當(dāng)我們有興趣學(xué)習(xí)從給定輸入X到保留關(guān)于原始輸入的信息的更高級(jí)別表示Y的參數(shù)化映射時(shí),可以利用互信息。[...]表明,最大化互信息的任務(wù)基本上等同于訓(xùn)練自動(dòng)編碼器以最小化重建誤差。
如何實(shí)現(xiàn)InfoGAN丟失功能
一旦熟悉模型的輸入和輸出,InfoGAN就可以相當(dāng)直接地實(shí)現(xiàn)。
唯一的絆腳石可能是互信息丟失功能,特別是如果你沒(méi)有像大多數(shù)開(kāi)發(fā)人員那樣強(qiáng)大的數(shù)學(xué)背景。
InfoGan使用兩種主要類(lèi)型的控制變量:分類(lèi)和連續(xù),連續(xù)變量可能具有不同的數(shù)據(jù)分布,這會(huì)影響相互損失的計(jì)算方式。可以基于變量類(lèi)型計(jì)算所有控制變量的相互損失并將其相加,這是OpenAI為T(mén)ensorFlow發(fā)布的InfoGAN實(shí)現(xiàn)中使用的方法。
在Keras中,將控制變量簡(jiǎn)化為分類(lèi)和高斯或均勻連續(xù)變量可能更容易,并且對(duì)于每個(gè)控制變量類(lèi)型在輔助模型上具有單獨(dú)的輸出。這樣可以使用不同的損失函數(shù),大大簡(jiǎn)化了實(shí)現(xiàn)。
有關(guān)本節(jié)中建議的更多背景信息,請(qǐng)參閱更多閱讀部分中的文章和帖子。
分類(lèi)控制變量
分類(lèi)變量可用于控制所生成圖像的類(lèi)型或類(lèi)別。
這被實(shí)現(xiàn)為一個(gè)熱編碼矢量。也就是說(shuō),如果類(lèi)具有10個(gè)值,則控制代碼將是一個(gè)類(lèi),例如6,并且輸入到生成器模型的分類(lèi)控制向量將是所有零值的10個(gè)元素向量,其中對(duì)于類(lèi)6具有一個(gè)值,例如,[0,0,0,0,0,0,1,0,0]。
訓(xùn)練模型時(shí),我們不需要選擇分類(lèi)控制變量; 相反,它們是隨機(jī)生成的,例如,每個(gè)樣本以均勻的概率選擇每個(gè)樣本。
...關(guān)于潛碼c~Cat(K = 10,p = 0.1)的統(tǒng)一分類(lèi)分布
在輔助模型中,分類(lèi)變量的輸出層也將是一個(gè)熱編碼矢量以匹配輸入控制代碼,并且使用softmax激活函數(shù)。
對(duì)于分類(lèi)潛在代碼ci,我們使用softmax非線性的自然選擇來(lái)表示Q(ci | x)。
回想一下,互信息被計(jì)算為來(lái)自控制變量的條件熵和從提供給輸入變量的控制變量的熵中減去的輔助模型的輸出。我們可以直接實(shí)現(xiàn)這一點(diǎn),但這不是必需的。
控制變量的熵是一個(gè)常數(shù),并且是一個(gè)接近于零的非常小的數(shù); 因此,我們可以從計(jì)算中刪除它。條件熵可以直接計(jì)算為控制變量輸入和輔助模型的輸出之間的交叉熵。因此,可以使用分類(lèi)交叉熵損失函數(shù),就像我們對(duì)任何多類(lèi)分類(lèi)問(wèn)題一樣。
超參數(shù)lambda用于縮放互信息丟失函數(shù)并設(shè)置為1,因此可以忽略。
即使InfoGAN引入了額外的超參數(shù)λ,它也很容易調(diào)整,簡(jiǎn)單地設(shè)置為1就足以支持離散的潛碼。
連續(xù)控制變量
連續(xù)控制變量可用于控制圖像的樣式。
連續(xù)變量從均勻分布中采樣,例如在-1和1之間,并作為輸入提供給發(fā)電機(jī)模型。
...可以捕捉連續(xù)性變化的連續(xù)代碼:c2,c3~Unif(-1,1)
輔助模型可以用高斯分布實(shí)現(xiàn)連續(xù)控制變量的預(yù)測(cè),其中輸出層被配置為具有一個(gè)節(jié)點(diǎn),平均值和一個(gè)用于高斯標(biāo)準(zhǔn)偏差的節(jié)點(diǎn),例如每個(gè)連續(xù)控制需要兩個(gè)輸出變量。
對(duì)于連續(xù)潛在代碼cj,根據(jù)什么是真正的后驗(yàn)P(cj | x),有更多選項(xiàng)。在我們的實(shí)驗(yàn)中,我們發(fā)現(xiàn)簡(jiǎn)單地將Q(cj | x)視為因式高斯是足夠的。
輸出均值的節(jié)點(diǎn)可以使用線性激活函數(shù),而輸出標(biāo)準(zhǔn)偏差的節(jié)點(diǎn)必須產(chǎn)生正值,因此可以使用諸如sigmoid的激活函數(shù)來(lái)創(chuàng)建0到1之間的值。
對(duì)于連續(xù)潛碼,我們通過(guò)對(duì)角高斯分布對(duì)近似后驗(yàn)進(jìn)行參數(shù)化,識(shí)別網(wǎng)絡(luò)輸出其均值和標(biāo)準(zhǔn)差,其中標(biāo)準(zhǔn)偏差通過(guò)網(wǎng)絡(luò)輸出的指數(shù)變換進(jìn)行參數(shù)化以確保積極性。
必須將損失函數(shù)計(jì)算為高斯控制碼的互信息,這意味著它們必須在計(jì)算損失之前從平均值和標(biāo)準(zhǔn)差重建。計(jì)算高斯分布變量的熵和條件熵可以直接實(shí)現(xiàn),但不是必需的。相反,可以使用均方誤差損失。
或者,可以將輸出分布簡(jiǎn)化為每個(gè)控制變量的均勻分布,可以使用具有線性激活的輔助模型中的每個(gè)變量的單個(gè)輸出節(jié)點(diǎn),并且模型可以使用均方誤差損失函數(shù)。
如何為MNIST開(kāi)發(fā)InfoGAN
在本節(jié)中,我們將仔細(xì)研究生成器(g),鑒別器(d)和輔助模型(q)以及如何在Keras中實(shí)現(xiàn)它們。
我們將為MNIST數(shù)據(jù)集開(kāi)發(fā)InfoGAN實(shí)現(xiàn),如InfoGAN論文中所做的那樣。
本文探討了兩個(gè)版本; 第一個(gè)僅使用分類(lèi)控制代碼,并允許模型將一個(gè)分類(lèi)變量映射到大約一個(gè)數(shù)字(盡管沒(méi)有按分類(lèi)變量排序數(shù)字)。
本文還探討了InfoGAN架構(gòu)的一個(gè)版本,其中包含一個(gè)熱編碼分類(lèi)變量(c1)和兩個(gè)連續(xù)控制變量(c2和c3)。
發(fā)現(xiàn)第一個(gè)連續(xù)變量用于控制數(shù)字的旋轉(zhuǎn),第二個(gè)連續(xù)變量用于控制數(shù)字的粗細(xì)。
我們將重點(diǎn)關(guān)注使用具有10個(gè)值的分類(lèi)控制變量的簡(jiǎn)單情況,并鼓勵(lì)模型學(xué)習(xí)讓該變量控制生成的數(shù)字。您可能希望通過(guò)更改分類(lèi)控制變量的基數(shù)或添加連續(xù)控制變量來(lái)擴(kuò)展此示例。
用于MNIST數(shù)據(jù)集訓(xùn)練的GAN模型的配置作為本文的附錄提供,轉(zhuǎn)載如下。我們將使用列出的配置作為開(kāi)發(fā)我們自己的生成器(g),鑒別器(d)和輔助(q)模型的起點(diǎn)。
讓我們從將生成器模型開(kāi)發(fā)為深度卷積神經(jīng)網(wǎng)絡(luò)(例如DCGAN)開(kāi)始。
該模型可以將噪聲向量(z)和控制向量(c)作為單獨(dú)的輸入,并在將它們用作生成圖像的基礎(chǔ)之前將它們連接起來(lái)?;蛘?,可以預(yù)先將矢量連接起來(lái)并提供給模型中的單個(gè)輸入層。方法是等價(jià)的,在這種情況下我們將使用后者來(lái)保持模型簡(jiǎn)單。
下面的define_generator()函數(shù)定義生成器模型,并將輸入向量的大小作為參數(shù)。
完全連接的層采用輸入向量并產(chǎn)生足夠數(shù)量的激活,以創(chuàng)建512個(gè)7×7特征映射,從中重新激活激活。然后,它們以1×1步幅通過(guò)正常卷積層,然后兩個(gè)隨后的上采樣將卷積層轉(zhuǎn)換為2×2步幅優(yōu)先至14×14特征映射,然后轉(zhuǎn)換為所需的1通道28×28特征映射輸出,其中像素值為通過(guò)tanh激活函數(shù)的范圍[-1,-1]。
良好的發(fā)生器配置啟發(fā)式如下,包括隨機(jī)高斯權(quán)重初始化,隱藏層中的ReLU激活以及批量歸一化的使用。
# define the standalone generator model
def define_generator(gen_input_size):
# weight initialization
init = RandomNormal(stddev=0.02)
# image generator input
in_lat = Input(shape=(gen_input_size,))
# foundation for 7x7 image
n_nodes = 512 * 7 * 7
gen = Dense(n_nodes, kernel_initializer=init)(in_lat)
gen = Activation('relu')(gen)
gen = BatchNormalization()(gen)
gen = Reshape((7, 7, 512))(gen)
# normal
gen = Conv2D(128, (4,4), padding='same', kernel_initializer=init)(gen)
gen = Activation('relu')(gen)
gen = BatchNormalization()(gen)
# upsample to 14x14
gen = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
gen = Activation('relu')(gen)
gen = BatchNormalization()(gen)
# upsample to 28x28
gen = Conv2DTranspose(1, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
# tanh output
out_layer = Activation('tanh')(gen)
# define model
model = Model(in_lat, out_layer)
return model
接下來(lái),我們可以定義鑒別器和輔助模型。
根據(jù)普通GAN,鑒別器模型以獨(dú)立方式訓(xùn)練在真實(shí)和偽造圖像上。發(fā)電機(jī)和輔助模型都不直接配合; 相反,它們適合作為復(fù)合模型的一部分。
鑒別器和輔助模型共享相同的輸入和特征提取層,但它們的輸出層不同。因此,同時(shí)定義它們是有意義的。
同樣,有許多方法可以實(shí)現(xiàn)這種架構(gòu),但是將鑒別器和輔助模型定義為單獨(dú)的模型首先允許我們稍后通過(guò)功能API直接將它們組合成更大的GAN模型。
下面的define_discriminator()函數(shù)定義了鑒別器和輔助模型,并將分類(lèi)變量的基數(shù)(例如數(shù)值,例如10)作為輸入。輸入圖像的形狀也被參數(shù)化為函數(shù)參數(shù),并設(shè)置為MNIST圖像大小的默認(rèn)值。
特征提取層涉及兩個(gè)下采樣層,而不是池化層作為最佳實(shí)踐。此外,遵循DCGAN模型的最佳實(shí)踐,我們使用LeakyReLU激活和批量標(biāo)準(zhǔn)化。
鑒別器模型(d)具有單個(gè)輸出節(jié)點(diǎn),并通過(guò)S形激活函數(shù)預(yù)測(cè)輸入圖像的實(shí)際概率。該模型被編譯,因?yàn)樗鼘⒁元?dú)立的方式使用,通過(guò)具有最佳實(shí)踐學(xué)習(xí)速率和動(dòng)量的隨機(jī)梯度下降的Adam版本來(lái)優(yōu)化二元交叉熵函數(shù)。
輔助模型(q)對(duì)分類(lèi)變量中的每個(gè)值具有一個(gè)節(jié)點(diǎn)輸出,并使用softmax激活函數(shù)。如InfoGAN論文中所使用的那樣,在特征提取層和輸出層之間添加完全連接的層。該模型未編譯,因?yàn)樗皇仟?dú)立使用或以獨(dú)立方式使用。
# define the standalone discriminator model
def define_discriminator(n_cat, in_shape=(28,28,1)):
# weight initialization
init = RandomNormal(stddev=0.02)
# image input
in_image = Input(shape=in_shape)
# downsample to 14x14
d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
d = LeakyReLU(alpha=0.1)(d)
# downsample to 7x7
d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
d = LeakyReLU(alpha=0.1)(d)
d = BatchNormalization()(d)
# normal
d = Conv2D(256, (4,4), padding='same', kernel_initializer=init)(d)
d = LeakyReLU(alpha=0.1)(d)
d = BatchNormalization()(d)
# flatten feature maps
d = Flatten()(d)
# real/fake output
out_classifier = Dense(1, activation='sigmoid')(d)
# define d model
d_model = Model(in_image, out_classifier)
# compile d model
d_model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
# create q model layers
q = Dense(128)(d)
q = BatchNormalization()(q)
q = LeakyReLU(alpha=0.1)(q)
# q model output
out_codes = Dense(n_cat, activation='softmax')(q)
# define q model
q_model = Model(in_image, out_codes)
return d_model, q_model
接下來(lái),我們可以定義復(fù)合GAN模型。
該模型使用所有子模型,并且是訓(xùn)練發(fā)電機(jī)模型權(quán)重的基礎(chǔ)。
下面的define_gan()函數(shù)實(shí)現(xiàn)了這個(gè)并定義并返回模型,將三個(gè)子模型作為輸入。
如上所述,鑒別器以獨(dú)立方式訓(xùn)練,因此鑒別器的所有權(quán)重被設(shè)置為不可訓(xùn)練(僅在此上下文中)。生成器模型的輸出連接到鑒別器模型的輸入,并連接到輔助模型的輸入。
這將創(chuàng)建一個(gè)新的復(fù)合模型,該模型將[noise + control]向量作為輸入,然后通過(guò)生成器生成圖像。然后,圖像通過(guò)鑒別器模型以產(chǎn)生分類(lèi),并通過(guò)輔助模型產(chǎn)生控制變量的預(yù)測(cè)。
該模型有兩個(gè)輸出層,需要使用不同的損失函數(shù)進(jìn)行訓(xùn)練。二進(jìn)制交叉熵?fù)p失用于鑒別器輸出,正如我們?cè)诰幾g獨(dú)立使用的鑒別器時(shí)所做的那樣,并且互信息丟失用于輔助模型,在這種情況下,輔助模型可以直接實(shí)現(xiàn)為分類(lèi)交叉熵并實(shí)現(xiàn)期望的結(jié)果。
# define the combined discriminator, generator and q network model
def define_gan(g_model, d_model, q_model):
# make weights in the discriminator (some shared with the q model) as not trainable
d_model.trainable = False
# connect g outputs to d inputs
d_output = d_model(g_model.output)
# connect g outputs to q inputs
q_output = q_model(g_model.output)
# define composite model
model = Model(g_model.input, [d_output, q_output])
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer=opt)
return model
為了使GAN模型架構(gòu)更清晰,我們可以創(chuàng)建模型和復(fù)合模型圖。
下面列出了完整的示例。
# create and plot the infogan model for mnist
from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.initializers import RandomNormal
from keras.utils.vis_utils import plot_model
# define the standalone discriminator model
def define_discriminator(n_cat, in_shape=(28,28,1)):
# weight initialization
init = RandomNormal(stddev=0.02)
# image input
in_image = Input(shape=in_shape)
# downsample to 14x14
d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
d = LeakyReLU(alpha=0.1)(d)
# downsample to 7x7
d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
d = LeakyReLU(alpha=0.1)(d)
d = BatchNormalization()(d)
# normal
d = Conv2D(256, (4,4), padding='same', kernel_initializer=init)(d)
d = LeakyReLU(alpha=0.1)(d)
d = BatchNormalization()(d)
# flatten feature maps
d = Flatten()(d)
# real/fake output
out_classifier = Dense(1, activation='sigmoid')(d)
# define d model
d_model = Model(in_image, out_classifier)
# compile d model
d_model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
# create q model layers
q = Dense(128)(d)
q = BatchNormalization()(q)
q = LeakyReLU(alpha=0.1)(q)
# q model output
out_codes = Dense(n_cat, activation='softmax')(q)
# define q model
q_model = Model(in_image, out_codes)
return d_model, q_model
# define the standalone generator model
def define_generator(gen_input_size):
# weight initialization
init = RandomNormal(stddev=0.02)
# image generator input
in_lat = Input(shape=(gen_input_size,))
# foundation for 7x7 image
n_nodes = 512 * 7 * 7
gen = Dense(n_nodes, kernel_initializer=init)(in_lat)
gen = Activation('relu')(gen)
gen = BatchNormalization()(gen)
gen = Reshape((7, 7, 512))(gen)
# normal
gen = Conv2D(128, (4,4), padding='same', kernel_initializer=init)(gen)
gen = Activation('relu')(gen)
gen = BatchNormalization()(gen)
# upsample to 14x14
gen = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
gen = Activation('relu')(gen)
gen = BatchNormalization()(gen)
# upsample to 28x28
gen = Conv2DTranspose(1, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
# tanh output
out_layer = Activation('tanh')(gen)
# define model
model = Model(in_lat, out_layer)
return model
# define the combined discriminator, generator and q network model
def define_gan(g_model, d_model, q_model):
# make weights in the discriminator (some shared with the q model) as not trainable
d_model.trainable = False
# connect g outputs to d inputs
d_output = d_model(g_model.output)
# connect g outputs to q inputs
q_output = q_model(g_model.output)
# define composite model
model = Model(g_model.input, [d_output, q_output])
# compile model
opt = Adam(lr=0.0002, beta_1=0.5)
model.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer=opt)
return model
# number of values for the categorical control code
n_cat = 10
# size of the latent space
latent_dim = 62
# create the discriminator
d_model, q_model = define_discriminator(n_cat)
# create the generator
gen_input_size = latent_dim + n_cat
g_model = define_generator(gen_input_size)
# create the gan
gan_model = define_gan(g_model, d_model, q_model)
# plot the model
plot_model(gan_model, to_file='gan_plot.png', show_shapes=True, show_layer_names=True)
運(yùn)行該示例將創(chuàng)建所有三個(gè)模型,然后創(chuàng)建復(fù)合GAN模型并保存模型體系結(jié)構(gòu)的圖。
注意:創(chuàng)建此圖假設(shè)已安裝pydot和graphviz庫(kù)。如果這是一個(gè)問(wèn)題,您可以注釋掉import語(yǔ)句和對(duì)plot_model()函數(shù)的調(diào)用。
該圖顯示了生成器模型的所有細(xì)節(jié)以及鑒別器和輔助模型的壓縮描述。重要的是,請(qǐng)注意鑒別器輸出的形狀作為預(yù)測(cè)圖像是真實(shí)還是假的單個(gè)節(jié)點(diǎn),以及輔助模型預(yù)測(cè)分類(lèi)控制代碼的10個(gè)節(jié)點(diǎn)。
回想一下,該復(fù)合模型將僅用于更新生成器和輔助模型的模型權(quán)重,并且鑒別器模型中的所有權(quán)重將保持不可約,即僅在更新獨(dú)立鑒別器模型時(shí)更新。
接下來(lái),我們將為發(fā)電機(jī)開(kāi)發(fā)輸入。
每個(gè)輸入都是由噪聲和控制代碼組成的矢量。具體地,高斯隨機(jī)數(shù)的矢量和一個(gè)熱編碼的隨機(jī)選擇的分類(lèi)值。
下面的generatelatentpoints()函數(shù)實(shí)現(xiàn)了這一點(diǎn),將潛在空間的大小,分類(lèi)值的數(shù)量以及要生成的樣本數(shù)作為參數(shù)作為輸入。該函數(shù)返回輸入連接向量作為生成器模型的輸入,以及獨(dú)立控制代碼。通過(guò)復(fù)合GAN模型更新發(fā)電機(jī)和輔助模型時(shí),將需要獨(dú)立控制代碼,專(zhuān)門(mén)用于計(jì)算輔助模型的互信息損失。
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_cat, n_samples):
# generate points in the latent space
z_latent = randn(latent_dim * n_samples)
# reshape into a batch of inputs for the network
z_latent = z_latent.reshape(n_samples, latent_dim)
# generate categorical codes
cat_codes = randint(0, n_cat, n_samples)
# one hot encode
cat_codes = to_categorical(cat_codes, num_classes=n_cat)
# concatenate latent points and control codes
z_input = hstack((z_latent, cat_codes))
return [z_input, cat_codes]
接下來(lái),我們可以生成真實(shí)和虛假的例子。
可以通過(guò)為灰度圖像添加附加維度來(lái)加載MNIST數(shù)據(jù)集,將其轉(zhuǎn)換為3D輸入,并將到范圍[-1,1]以匹配來(lái)自生成器模型的輸出。這是在下面的loadreal
數(shù)據(jù)分析咨詢請(qǐng)掃描二維碼
若不方便掃碼,搜微信號(hào):CDAshujufenxi
LSTM 模型輸入長(zhǎng)度選擇技巧:提升序列建模效能的關(guān)鍵? 在循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)家族中,長(zhǎng)短期記憶網(wǎng)絡(luò)(LSTM)憑借其解決長(zhǎng)序列 ...
2025-07-11CDA 數(shù)據(jù)分析師報(bào)考條件詳解與準(zhǔn)備指南? ? 在數(shù)據(jù)驅(qū)動(dòng)決策的時(shí)代浪潮下,CDA 數(shù)據(jù)分析師認(rèn)證愈發(fā)受到矚目,成為眾多有志投身數(shù) ...
2025-07-11數(shù)據(jù)透視表中兩列相乘合計(jì)的實(shí)用指南? 在數(shù)據(jù)分析的日常工作中,數(shù)據(jù)透視表憑借其強(qiáng)大的數(shù)據(jù)匯總和分析功能,成為了 Excel 用戶 ...
2025-07-11尊敬的考生: 您好! 我們誠(chéng)摯通知您,CDA Level I和 Level II考試大綱將于 2025年7月25日 實(shí)施重大更新。 此次更新旨在確保認(rèn) ...
2025-07-10BI 大數(shù)據(jù)分析師:連接數(shù)據(jù)與業(yè)務(wù)的價(jià)值轉(zhuǎn)化者? ? 在大數(shù)據(jù)與商業(yè)智能(Business Intelligence,簡(jiǎn)稱(chēng) BI)深度融合的時(shí)代,BI ...
2025-07-10SQL 在預(yù)測(cè)分析中的應(yīng)用:從數(shù)據(jù)查詢到趨勢(shì)預(yù)判? ? 在數(shù)據(jù)驅(qū)動(dòng)決策的時(shí)代,預(yù)測(cè)分析作為挖掘數(shù)據(jù)潛在價(jià)值的核心手段,正被廣泛 ...
2025-07-10數(shù)據(jù)查詢結(jié)束后:分析師的收尾工作與價(jià)值深化? ? 在數(shù)據(jù)分析的全流程中,“query end”(查詢結(jié)束)并非工作的終點(diǎn),而是將數(shù) ...
2025-07-10CDA 數(shù)據(jù)分析師考試:從報(bào)考到取證的全攻略? 在數(shù)字經(jīng)濟(jì)蓬勃發(fā)展的今天,數(shù)據(jù)分析師已成為各行業(yè)爭(zhēng)搶的核心人才,而 CDA(Certi ...
2025-07-09【CDA干貨】單樣本趨勢(shì)性檢驗(yàn):捕捉數(shù)據(jù)背后的時(shí)間軌跡? 在數(shù)據(jù)分析的版圖中,單樣本趨勢(shì)性檢驗(yàn)如同一位耐心的偵探,專(zhuān)注于從單 ...
2025-07-09year_month數(shù)據(jù)類(lèi)型:時(shí)間維度的精準(zhǔn)切片? ? 在數(shù)據(jù)的世界里,時(shí)間是最不可或缺的維度之一,而year_month數(shù)據(jù)類(lèi)型就像一把精準(zhǔn) ...
2025-07-09CDA 備考干貨:Python 在數(shù)據(jù)分析中的核心應(yīng)用與實(shí)戰(zhàn)技巧? ? 在 CDA 數(shù)據(jù)分析師認(rèn)證考試中,Python 作為數(shù)據(jù)處理與分析的核心 ...
2025-07-08SPSS 中的 Mann-Kendall 檢驗(yàn):數(shù)據(jù)趨勢(shì)與突變分析的有力工具? ? ? 在數(shù)據(jù)分析的廣袤領(lǐng)域中,準(zhǔn)確捕捉數(shù)據(jù)的趨勢(shì)變化以及識(shí)別 ...
2025-07-08備戰(zhàn) CDA 數(shù)據(jù)分析師考試:需要多久?如何規(guī)劃? CDA(Certified Data Analyst)數(shù)據(jù)分析師認(rèn)證作為國(guó)內(nèi)權(quán)威的數(shù)據(jù)分析能力認(rèn)證 ...
2025-07-08LSTM 輸出不確定的成因、影響與應(yīng)對(duì)策略? 長(zhǎng)短期記憶網(wǎng)絡(luò)(LSTM)作為循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)的一種變體,憑借獨(dú)特的門(mén)控機(jī)制,在 ...
2025-07-07統(tǒng)計(jì)學(xué)方法在市場(chǎng)調(diào)研數(shù)據(jù)中的深度應(yīng)用? 市場(chǎng)調(diào)研是企業(yè)洞察市場(chǎng)動(dòng)態(tài)、了解消費(fèi)者需求的重要途徑,而統(tǒng)計(jì)學(xué)方法則是市場(chǎng)調(diào)研數(shù) ...
2025-07-07CDA數(shù)據(jù)分析師證書(shū)考試全攻略? 在數(shù)字化浪潮席卷全球的當(dāng)下,數(shù)據(jù)已成為企業(yè)決策、行業(yè)發(fā)展的核心驅(qū)動(dòng)力,數(shù)據(jù)分析師也因此成為 ...
2025-07-07剖析 CDA 數(shù)據(jù)分析師考試題型:解鎖高效備考與答題策略? CDA(Certified Data Analyst)數(shù)據(jù)分析師考試作為衡量數(shù)據(jù)專(zhuān)業(yè)能力的 ...
2025-07-04SQL Server 字符串截取轉(zhuǎn)日期:解鎖數(shù)據(jù)處理的關(guān)鍵技能? 在數(shù)據(jù)處理與分析工作中,數(shù)據(jù)格式的規(guī)范性是保證后續(xù)分析準(zhǔn)確性的基礎(chǔ) ...
2025-07-04CDA 數(shù)據(jù)分析師視角:從數(shù)據(jù)迷霧中探尋商業(yè)真相? 在數(shù)字化浪潮席卷全球的今天,數(shù)據(jù)已成為企業(yè)決策的核心驅(qū)動(dòng)力,CDA(Certifie ...
2025-07-04CDA 數(shù)據(jù)分析師:開(kāi)啟數(shù)據(jù)職業(yè)發(fā)展新征程? ? 在數(shù)據(jù)成為核心生產(chǎn)要素的今天,數(shù)據(jù)分析師的職業(yè)價(jià)值愈發(fā)凸顯。CDA(Certified D ...
2025-07-03