
文章來源:DeepHub IMBA
作者: P**nHub兄弟網(wǎng)站
學(xué)習(xí)如何通過剪枝來使你的模型變得更小
剪枝是一種模型優(yōu)化技術(shù),這種技術(shù)可以消除權(quán)重張量中不必要的值。這將會得到更小的模型,并且模型精度非常接近標(biāo)準(zhǔn)模型。
在本文中,我們將通過一個例子來觀察剪枝技術(shù)對最終模型大小和預(yù)測誤差的影響。
我們的第一步導(dǎo)入一些工具、包:
最后,初始化TensorBoard,這樣就可以將模型可視化:
import os import zipfile import tensorflow as tf import tensorflow_model_optimization as tfmot from tensorflow.keras.models import load_model from tensorflow import keras %load_ext tensorboard
在這個實驗中,我們將使用scikit-learn生成一個回歸數(shù)據(jù)集。之后,我們將數(shù)據(jù)集分解為訓(xùn)練集和測試集:
from sklearn.datasets import make_friedman1 X, y = make_friedman1(n_samples=10000, n_features=10, random_state=0) from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
我們將創(chuàng)建一個簡單的神經(jīng)網(wǎng)絡(luò)來預(yù)測目標(biāo)變量y,然后檢查均值平方誤差。在此之后,我們將把它與修剪過的整個模型進(jìn)行比較,然后只與修剪過的Dense層進(jìn)行比較。
接下來,在30個訓(xùn)練輪次之后,一旦模型停止改進(jìn),我們就使用回調(diào)來停止訓(xùn)練它。
early_stop = keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=30)
我們打印出模型概述,以便與運用剪枝技術(shù)的模型概述進(jìn)行比較。
model = setup_model() model.summary()
讓我們編譯模型并訓(xùn)練它。
tf.keras.utils.plot_model( model, to_file=”model.png”, show_shapes=True, show_layer_names=True, rankdir=”TB”, expand_nested=True, dpi=96, )
現(xiàn)在檢查一下均方誤差。我們可以繼續(xù)到下一節(jié),看看當(dāng)我們修剪整個模型時,這個誤差是如何變化的。
from sklearn.metrics import mean_squared_error predictions = model.predict(X_test) print(‘Without Pruning MSE %.4f’ % mean_squared_error(y_test,predictions.reshape(3300,))) Without Pruning MSE 0.0201
當(dāng)把模型部署到資源受限的邊緣設(shè)備(如手機)時,剪枝等優(yōu)化模型技術(shù)尤其重要。
我們將上面的MSE與修剪整個模型得到的MSE進(jìn)行比較。第一步是定義剪枝參數(shù)。權(quán)重剪枝是基于數(shù)量級的。這意味著在訓(xùn)練過程中一些權(quán)重被轉(zhuǎn)換為零。模型變得稀疏,這樣就更容易壓縮。由于可以跳過零,稀疏模型還可以加快推理速度。
預(yù)期的參數(shù)是剪枝計劃、塊大小和塊池類型。
from tensorflow_model_optimization.sparsity.keras import ConstantSparsity pruning_params = { 'pruning_schedule': ConstantSparsity(0.5, 0), 'block_size': (1, 1), 'block_pooling_type': 'AVG' }
現(xiàn)在,我們可以應(yīng)用我們的剪枝參數(shù)來修剪整個模型。
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude model_to_prune = prune_low_magnitude( keras.Sequential([ tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)), tf.keras.layers.Dense(1, activation='relu') ]), **pruning_params)
我們檢查模型概述。將其與未剪枝模型的模型進(jìn)行比較。從下圖中我們可以看到整個模型已經(jīng)被剪枝 —— 我們將很快看到剪枝一個稠密層后模型概述的區(qū)別。
model_to_prune.summary()
在TF中,我們必須先編譯模型,然后才能將其用于訓(xùn)練集和測試集。
model_to_prune.compile(optimizer=’adam’, loss=tf.keras.losses.mean_squared_error, metrics=[‘mae’, ‘mse’])
由于我們正在使用剪枝技術(shù),所以除了早期停止回調(diào)函數(shù)之外,我們還必須定義兩個剪枝回調(diào)函數(shù)。我們定義一個記錄模型的文件夾,然后創(chuàng)建一個帶有回調(diào)函數(shù)的列表。
tfmot.sparsity.keras.UpdatePruningStep()
使用優(yōu)化器步驟更新剪枝包裝器。如果未能指定剪枝包裝器,將會導(dǎo)致錯誤。
tfmot.sparsity.keras.PruningSummaries()
將剪枝概述添加到Tensorboard。
log_dir = ‘.models’ callbacks = [ tfmot.sparsity.keras.UpdatePruningStep(), # Log sparsity and other metrics in Tensorboard. tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir), keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=10) ]
有了這些,我們現(xiàn)在就可以將模型與訓(xùn)練集相匹配了。
model_to_prune.fit(X_train,y_train,epochs=100,validation_split=0.2,callbacks=callbacks,verbose=0)
在檢查這個模型的均方誤差時,我們注意到它比未剪枝模型的均方誤差略高。
prune_predictions = model_to_prune.predict(X_test) print(‘Whole Model Pruned MSE %.4f’ % mean_squared_error(y_test,prune_predictions.reshape(3300,))) Whole Model Pruned MSE 0.1830
現(xiàn)在讓我們實現(xiàn)相同的模型,但這一次,我們將只剪枝稠密層。請注意在剪枝計劃中使用多項式衰退函數(shù)。
from tensorflow_model_optimization.sparsity.keras import PolynomialDecay layer_pruning_params = { 'pruning_schedule': PolynomialDecay(initial_sparsity=0.2, final_sparsity=0.8, begin_step=1000, end_step=2000), 'block_size': (2, 3), 'block_pooling_type': 'MAX' } model_layer_prunning = keras.Sequential([ prune_low_magnitude(tf.keras.layers.Dense(128, activation='relu',input_shape=(X_train.shape[1],)), **layer_pruning_params), tf.keras.layers.Dense(1, activation='relu') ])
從概述中我們可以看到只有第一個稠密層將被剪枝。
model_layer_prunning.summary()
然后我們編譯并擬合模型。
model_layer_prunning.compile(optimizer=’adam’, loss=tf.keras.losses.mean_squared_error, metrics=[‘mae’, ‘mse’]) model_layer_prunning.fit(X_train,y_train,epochs=300,validation_split=0.1,callbacks=callbacks,verbose=0)
現(xiàn)在,讓我們檢查均方誤差。
layer_prune_predictions = model_layer_prunning.predict(X_test) print(‘Layer Prunned MSE %.4f’ % mean_squared_error(y_test,layer_prune_predictions.reshape(3300,))) Layer Prunned MSE 0.1388
由于我們使用了不同的剪枝參數(shù),所以我們無法將這里獲得的MSE與之前的MSE進(jìn)行比較。如果您想比較它們,那么請確保剪枝參數(shù)是相同的。在測試時,對于這個特定情況,layer_pruning_params給出的錯誤比pruning_params要低。比較從不同的剪枝參數(shù)獲得的MSE是有用的,這樣你就可以選擇一個不會使模型性能變差的MSE。
現(xiàn)在讓我們比較一下有剪枝和沒有剪枝模型的大小。我們從訓(xùn)練和保存模型權(quán)重開始,以便以后使用。
def train_save_weights(): model = setup_model() model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error, metrics=['mae', 'mse']) model.fit(X_train,y_train,epochs=300,validation_split=0.2,callbacks=callbacks,verbose=0) model.save_weights('.models/friedman_model_weights.h5') train_save_weights()
我們將建立我們的基礎(chǔ)模型,并加載保存的權(quán)重。然后我們對整個模型進(jìn)行剪枝。我們編譯、擬合模型,并在Tensorboard上將結(jié)果可視化。
base_model = setup_model() base_model.load_weights('.models/friedman_model_weights.h5') # optional but recommended for model accuracy model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model) model_for_pruning.compile( loss=tf.keras.losses.mean_squared_error, optimizer='adam', metrics=['mae', 'mse'] ) model_for_pruning.fit( X_train, y_train, callbacks=callbacks, epochs=300, validation_split = 0.2, verbose=0 ) %tensorboard --logdir={log_dir}
以下是TensorBoard的剪枝概述的快照。
在TensorBoard上也可以看到其它剪枝模型概述
現(xiàn)在讓我們定義一個計算模型大小函數(shù)
def get_gzipped_model_size(model,mode_name,zip_name): # Returns size of gzipped model, in bytes. model.save(mode_name, include_optimizer=False) with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as f: f.write(mode_name) return os.path.getsize(zip_name)
現(xiàn)在我們定義導(dǎo)出模型,然后計算大小。
對于剪枝過的模型,tfmot.sparsity.keras.strip_pruning()用來恢復(fù)帶有稀疏權(quán)重的原始模型。請注意剝離模型和未剝離模型在尺寸上的差異。
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning,'.models/model_for_pruning.h5','.models/model_for_pruning.zip'))) print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export,'.models/model_for_export.h5','.models/model_for_export.zip')))
Size of gzipped pruned model without stripping: 6101.00 bytes Size of gzipped pruned model with stripping: 5140.00 bytes
對這兩個模型進(jìn)行預(yù)測,我們發(fā)現(xiàn)它們具有相同的均方誤差。
model_for_prunning_predictions = model_for_pruning.predict(X_test) print('Model for Prunning Error %.4f' % mean_squared_error(y_test,model_for_prunning_predictions.reshape(3300,))) model_for_export_predictions = model_for_export.predict(X_test) print('Model for Export Error %.4f' % mean_squared_error(y_test,model_for_export_predictions.reshape(3300,)))
Model for Prunning Error 0.0264 Model for Export Error 0.0264
您可以繼續(xù)測試不同的剪枝計劃如何影響模型的大小。顯然這里的觀察結(jié)果不具有普遍性。也可以嘗試不同的剪枝參數(shù),并了解它們?nèi)绾斡绊懩哪P痛笮?、預(yù)測誤差/精度,這將取決于您要解決的問題。
為了進(jìn)一步優(yōu)化模型,您可以將其量化。如果您想了解更多,請查看下面的回購和參考資料。
作者:Derrick Mwiti
deephub翻譯組:錢三一
數(shù)據(jù)分析咨詢請掃描二維碼
若不方便掃碼,搜微信號:CDAshujufenxi
LSTM 模型輸入長度選擇技巧:提升序列建模效能的關(guān)鍵? 在循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)家族中,長短期記憶網(wǎng)絡(luò)(LSTM)憑借其解決長序列 ...
2025-07-11CDA 數(shù)據(jù)分析師報考條件詳解與準(zhǔn)備指南? ? 在數(shù)據(jù)驅(qū)動決策的時代浪潮下,CDA 數(shù)據(jù)分析師認(rèn)證愈發(fā)受到矚目,成為眾多有志投身數(shù) ...
2025-07-11數(shù)據(jù)透視表中兩列相乘合計的實用指南? 在數(shù)據(jù)分析的日常工作中,數(shù)據(jù)透視表憑借其強大的數(shù)據(jù)匯總和分析功能,成為了 Excel 用戶 ...
2025-07-11尊敬的考生: 您好! 我們誠摯通知您,CDA Level I和 Level II考試大綱將于 2025年7月25日 實施重大更新。 此次更新旨在確保認(rèn) ...
2025-07-10BI 大數(shù)據(jù)分析師:連接數(shù)據(jù)與業(yè)務(wù)的價值轉(zhuǎn)化者? ? 在大數(shù)據(jù)與商業(yè)智能(Business Intelligence,簡稱 BI)深度融合的時代,BI ...
2025-07-10SQL 在預(yù)測分析中的應(yīng)用:從數(shù)據(jù)查詢到趨勢預(yù)判? ? 在數(shù)據(jù)驅(qū)動決策的時代,預(yù)測分析作為挖掘數(shù)據(jù)潛在價值的核心手段,正被廣泛 ...
2025-07-10數(shù)據(jù)查詢結(jié)束后:分析師的收尾工作與價值深化? ? 在數(shù)據(jù)分析的全流程中,“query end”(查詢結(jié)束)并非工作的終點,而是將數(shù) ...
2025-07-10CDA 數(shù)據(jù)分析師考試:從報考到取證的全攻略? 在數(shù)字經(jīng)濟(jì)蓬勃發(fā)展的今天,數(shù)據(jù)分析師已成為各行業(yè)爭搶的核心人才,而 CDA(Certi ...
2025-07-09【CDA干貨】單樣本趨勢性檢驗:捕捉數(shù)據(jù)背后的時間軌跡? 在數(shù)據(jù)分析的版圖中,單樣本趨勢性檢驗如同一位耐心的偵探,專注于從單 ...
2025-07-09year_month數(shù)據(jù)類型:時間維度的精準(zhǔn)切片? ? 在數(shù)據(jù)的世界里,時間是最不可或缺的維度之一,而year_month數(shù)據(jù)類型就像一把精準(zhǔn) ...
2025-07-09CDA 備考干貨:Python 在數(shù)據(jù)分析中的核心應(yīng)用與實戰(zhàn)技巧? ? 在 CDA 數(shù)據(jù)分析師認(rèn)證考試中,Python 作為數(shù)據(jù)處理與分析的核心 ...
2025-07-08SPSS 中的 Mann-Kendall 檢驗:數(shù)據(jù)趨勢與突變分析的有力工具? ? ? 在數(shù)據(jù)分析的廣袤領(lǐng)域中,準(zhǔn)確捕捉數(shù)據(jù)的趨勢變化以及識別 ...
2025-07-08備戰(zhàn) CDA 數(shù)據(jù)分析師考試:需要多久?如何規(guī)劃? CDA(Certified Data Analyst)數(shù)據(jù)分析師認(rèn)證作為國內(nèi)權(quán)威的數(shù)據(jù)分析能力認(rèn)證 ...
2025-07-08LSTM 輸出不確定的成因、影響與應(yīng)對策略? 長短期記憶網(wǎng)絡(luò)(LSTM)作為循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)的一種變體,憑借獨特的門控機制,在 ...
2025-07-07統(tǒng)計學(xué)方法在市場調(diào)研數(shù)據(jù)中的深度應(yīng)用? 市場調(diào)研是企業(yè)洞察市場動態(tài)、了解消費者需求的重要途徑,而統(tǒng)計學(xué)方法則是市場調(diào)研數(shù) ...
2025-07-07CDA數(shù)據(jù)分析師證書考試全攻略? 在數(shù)字化浪潮席卷全球的當(dāng)下,數(shù)據(jù)已成為企業(yè)決策、行業(yè)發(fā)展的核心驅(qū)動力,數(shù)據(jù)分析師也因此成為 ...
2025-07-07剖析 CDA 數(shù)據(jù)分析師考試題型:解鎖高效備考與答題策略? CDA(Certified Data Analyst)數(shù)據(jù)分析師考試作為衡量數(shù)據(jù)專業(yè)能力的 ...
2025-07-04SQL Server 字符串截取轉(zhuǎn)日期:解鎖數(shù)據(jù)處理的關(guān)鍵技能? 在數(shù)據(jù)處理與分析工作中,數(shù)據(jù)格式的規(guī)范性是保證后續(xù)分析準(zhǔn)確性的基礎(chǔ) ...
2025-07-04CDA 數(shù)據(jù)分析師視角:從數(shù)據(jù)迷霧中探尋商業(yè)真相? 在數(shù)字化浪潮席卷全球的今天,數(shù)據(jù)已成為企業(yè)決策的核心驅(qū)動力,CDA(Certifie ...
2025-07-04CDA 數(shù)據(jù)分析師:開啟數(shù)據(jù)職業(yè)發(fā)展新征程? ? 在數(shù)據(jù)成為核心生產(chǎn)要素的今天,數(shù)據(jù)分析師的職業(yè)價值愈發(fā)凸顯。CDA(Certified D ...
2025-07-03