
文章來源:DeepHub IMBA
作者: P**nHub兄弟網(wǎng)站
學(xué)習(xí)如何通過剪枝來使你的模型變得更小
剪枝是一種模型優(yōu)化技術(shù),這種技術(shù)可以消除權(quán)重張量中不必要的值。這將會得到更小的模型,并且模型精度非常接近標(biāo)準(zhǔn)模型。
在本文中,我們將通過一個(gè)例子來觀察剪枝技術(shù)對最終模型大小和預(yù)測誤差的影響。
我們的第一步導(dǎo)入一些工具、包:
最后,初始化TensorBoard,這樣就可以將模型可視化:
import os import zipfile import tensorflow as tf import tensorflow_model_optimization as tfmot from tensorflow.keras.models import load_model from tensorflow import keras %load_ext tensorboard
在這個(gè)實(shí)驗(yàn)中,我們將使用scikit-learn生成一個(gè)回歸數(shù)據(jù)集。之后,我們將數(shù)據(jù)集分解為訓(xùn)練集和測試集:
from sklearn.datasets import make_friedman1 X, y = make_friedman1(n_samples=10000, n_features=10, random_state=0) from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
我們將創(chuàng)建一個(gè)簡單的神經(jīng)網(wǎng)絡(luò)來預(yù)測目標(biāo)變量y,然后檢查均值平方誤差。在此之后,我們將把它與修剪過的整個(gè)模型進(jìn)行比較,然后只與修剪過的Dense層進(jìn)行比較。
接下來,在30個(gè)訓(xùn)練輪次之后,一旦模型停止改進(jìn),我們就使用回調(diào)來停止訓(xùn)練它。
early_stop = keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=30)
我們打印出模型概述,以便與運(yùn)用剪枝技術(shù)的模型概述進(jìn)行比較。
model = setup_model() model.summary()
讓我們編譯模型并訓(xùn)練它。
tf.keras.utils.plot_model( model, to_file=”model.png”, show_shapes=True, show_layer_names=True, rankdir=”TB”, expand_nested=True, dpi=96, )
現(xiàn)在檢查一下均方誤差。我們可以繼續(xù)到下一節(jié),看看當(dāng)我們修剪整個(gè)模型時(shí),這個(gè)誤差是如何變化的。
from sklearn.metrics import mean_squared_error predictions = model.predict(X_test) print(‘Without Pruning MSE %.4f’ % mean_squared_error(y_test,predictions.reshape(3300,))) Without Pruning MSE 0.0201
當(dāng)把模型部署到資源受限的邊緣設(shè)備(如手機(jī))時(shí),剪枝等優(yōu)化模型技術(shù)尤其重要。
我們將上面的MSE與修剪整個(gè)模型得到的MSE進(jìn)行比較。第一步是定義剪枝參數(shù)。權(quán)重剪枝是基于數(shù)量級的。這意味著在訓(xùn)練過程中一些權(quán)重被轉(zhuǎn)換為零。模型變得稀疏,這樣就更容易壓縮。由于可以跳過零,稀疏模型還可以加快推理速度。
預(yù)期的參數(shù)是剪枝計(jì)劃、塊大小和塊池類型。
from tensorflow_model_optimization.sparsity.keras import ConstantSparsity pruning_params = { 'pruning_schedule': ConstantSparsity(0.5, 0), 'block_size': (1, 1), 'block_pooling_type': 'AVG' }
現(xiàn)在,我們可以應(yīng)用我們的剪枝參數(shù)來修剪整個(gè)模型。
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude model_to_prune = prune_low_magnitude( keras.Sequential([ tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)), tf.keras.layers.Dense(1, activation='relu') ]), **pruning_params)
我們檢查模型概述。將其與未剪枝模型的模型進(jìn)行比較。從下圖中我們可以看到整個(gè)模型已經(jīng)被剪枝 —— 我們將很快看到剪枝一個(gè)稠密層后模型概述的區(qū)別。
model_to_prune.summary()
在TF中,我們必須先編譯模型,然后才能將其用于訓(xùn)練集和測試集。
model_to_prune.compile(optimizer=’adam’, loss=tf.keras.losses.mean_squared_error, metrics=[‘mae’, ‘mse’])
由于我們正在使用剪枝技術(shù),所以除了早期停止回調(diào)函數(shù)之外,我們還必須定義兩個(gè)剪枝回調(diào)函數(shù)。我們定義一個(gè)記錄模型的文件夾,然后創(chuàng)建一個(gè)帶有回調(diào)函數(shù)的列表。
tfmot.sparsity.keras.UpdatePruningStep()
使用優(yōu)化器步驟更新剪枝包裝器。如果未能指定剪枝包裝器,將會導(dǎo)致錯(cuò)誤。
tfmot.sparsity.keras.PruningSummaries()
將剪枝概述添加到Tensorboard。
log_dir = ‘.models’ callbacks = [ tfmot.sparsity.keras.UpdatePruningStep(), # Log sparsity and other metrics in Tensorboard. tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir), keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=10) ]
有了這些,我們現(xiàn)在就可以將模型與訓(xùn)練集相匹配了。
model_to_prune.fit(X_train,y_train,epochs=100,validation_split=0.2,callbacks=callbacks,verbose=0)
在檢查這個(gè)模型的均方誤差時(shí),我們注意到它比未剪枝模型的均方誤差略高。
prune_predictions = model_to_prune.predict(X_test) print(‘Whole Model Pruned MSE %.4f’ % mean_squared_error(y_test,prune_predictions.reshape(3300,))) Whole Model Pruned MSE 0.1830
現(xiàn)在讓我們實(shí)現(xiàn)相同的模型,但這一次,我們將只剪枝稠密層。請注意在剪枝計(jì)劃中使用多項(xiàng)式衰退函數(shù)。
from tensorflow_model_optimization.sparsity.keras import PolynomialDecay layer_pruning_params = { 'pruning_schedule': PolynomialDecay(initial_sparsity=0.2, final_sparsity=0.8, begin_step=1000, end_step=2000), 'block_size': (2, 3), 'block_pooling_type': 'MAX' } model_layer_prunning = keras.Sequential([ prune_low_magnitude(tf.keras.layers.Dense(128, activation='relu',input_shape=(X_train.shape[1],)), **layer_pruning_params), tf.keras.layers.Dense(1, activation='relu') ])
從概述中我們可以看到只有第一個(gè)稠密層將被剪枝。
model_layer_prunning.summary()
然后我們編譯并擬合模型。
model_layer_prunning.compile(optimizer=’adam’, loss=tf.keras.losses.mean_squared_error, metrics=[‘mae’, ‘mse’]) model_layer_prunning.fit(X_train,y_train,epochs=300,validation_split=0.1,callbacks=callbacks,verbose=0)
現(xiàn)在,讓我們檢查均方誤差。
layer_prune_predictions = model_layer_prunning.predict(X_test) print(‘Layer Prunned MSE %.4f’ % mean_squared_error(y_test,layer_prune_predictions.reshape(3300,))) Layer Prunned MSE 0.1388
由于我們使用了不同的剪枝參數(shù),所以我們無法將這里獲得的MSE與之前的MSE進(jìn)行比較。如果您想比較它們,那么請確保剪枝參數(shù)是相同的。在測試時(shí),對于這個(gè)特定情況,layer_pruning_params給出的錯(cuò)誤比pruning_params要低。比較從不同的剪枝參數(shù)獲得的MSE是有用的,這樣你就可以選擇一個(gè)不會使模型性能變差的MSE。
現(xiàn)在讓我們比較一下有剪枝和沒有剪枝模型的大小。我們從訓(xùn)練和保存模型權(quán)重開始,以便以后使用。
def train_save_weights(): model = setup_model() model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error, metrics=['mae', 'mse']) model.fit(X_train,y_train,epochs=300,validation_split=0.2,callbacks=callbacks,verbose=0) model.save_weights('.models/friedman_model_weights.h5') train_save_weights()
我們將建立我們的基礎(chǔ)模型,并加載保存的權(quán)重。然后我們對整個(gè)模型進(jìn)行剪枝。我們編譯、擬合模型,并在Tensorboard上將結(jié)果可視化。
base_model = setup_model() base_model.load_weights('.models/friedman_model_weights.h5') # optional but recommended for model accuracy model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model) model_for_pruning.compile( loss=tf.keras.losses.mean_squared_error, optimizer='adam', metrics=['mae', 'mse'] ) model_for_pruning.fit( X_train, y_train, callbacks=callbacks, epochs=300, validation_split = 0.2, verbose=0 ) %tensorboard --logdir={log_dir}
以下是TensorBoard的剪枝概述的快照。
在TensorBoard上也可以看到其它剪枝模型概述
現(xiàn)在讓我們定義一個(gè)計(jì)算模型大小函數(shù)
def get_gzipped_model_size(model,mode_name,zip_name): # Returns size of gzipped model, in bytes. model.save(mode_name, include_optimizer=False) with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as f: f.write(mode_name) return os.path.getsize(zip_name)
現(xiàn)在我們定義導(dǎo)出模型,然后計(jì)算大小。
對于剪枝過的模型,tfmot.sparsity.keras.strip_pruning()用來恢復(fù)帶有稀疏權(quán)重的原始模型。請注意剝離模型和未剝離模型在尺寸上的差異。
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning,'.models/model_for_pruning.h5','.models/model_for_pruning.zip'))) print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export,'.models/model_for_export.h5','.models/model_for_export.zip')))
Size of gzipped pruned model without stripping: 6101.00 bytes Size of gzipped pruned model with stripping: 5140.00 bytes
對這兩個(gè)模型進(jìn)行預(yù)測,我們發(fā)現(xiàn)它們具有相同的均方誤差。
model_for_prunning_predictions = model_for_pruning.predict(X_test) print('Model for Prunning Error %.4f' % mean_squared_error(y_test,model_for_prunning_predictions.reshape(3300,))) model_for_export_predictions = model_for_export.predict(X_test) print('Model for Export Error %.4f' % mean_squared_error(y_test,model_for_export_predictions.reshape(3300,)))
Model for Prunning Error 0.0264 Model for Export Error 0.0264
您可以繼續(xù)測試不同的剪枝計(jì)劃如何影響模型的大小。顯然這里的觀察結(jié)果不具有普遍性。也可以嘗試不同的剪枝參數(shù),并了解它們?nèi)绾斡绊懩哪P痛笮?、預(yù)測誤差/精度,這將取決于您要解決的問題。
為了進(jìn)一步優(yōu)化模型,您可以將其量化。如果您想了解更多,請查看下面的回購和參考資料。
作者:Derrick Mwiti
deephub翻譯組:錢三一
數(shù)據(jù)分析咨詢請掃描二維碼
若不方便掃碼,搜微信號:CDAshujufenxi
SQL Server 中 CONVERT 函數(shù)的日期轉(zhuǎn)換:從基礎(chǔ)用法到實(shí)戰(zhàn)優(yōu)化 在 SQL Server 的數(shù)據(jù)處理中,日期格式轉(zhuǎn)換是高頻需求 —— 無論 ...
2025-09-18MySQL 大表拆分與關(guān)聯(lián)查詢效率:打破 “拆分必慢” 的認(rèn)知誤區(qū) 在 MySQL 數(shù)據(jù)庫管理中,“大表” 始終是性能優(yōu)化繞不開的話題。 ...
2025-09-18CDA 數(shù)據(jù)分析師:表結(jié)構(gòu)數(shù)據(jù) “獲取 - 加工 - 使用” 全流程的賦能者 表結(jié)構(gòu)數(shù)據(jù)(如數(shù)據(jù)庫表、Excel 表、CSV 文件)是企業(yè)數(shù)字 ...
2025-09-18DSGE 模型中的 Et:理性預(yù)期算子的內(nèi)涵、作用與應(yīng)用解析 動(dòng)態(tài)隨機(jī)一般均衡(Dynamic Stochastic General Equilibrium, DSGE)模 ...
2025-09-17Python 提取 TIF 中地名的完整指南 一、先明確:TIF 中的地名有哪兩種存在形式? 在開始提取前,需先判斷 TIF 文件的類型 —— ...
2025-09-17CDA 數(shù)據(jù)分析師:解鎖表結(jié)構(gòu)數(shù)據(jù)特征價(jià)值的專業(yè)核心 表結(jié)構(gòu)數(shù)據(jù)(以 “行 - 列” 規(guī)范存儲的結(jié)構(gòu)化數(shù)據(jù),如數(shù)據(jù)庫表、Excel 表、 ...
2025-09-17Excel 導(dǎo)入數(shù)據(jù)含缺失值?詳解 dropna 函數(shù)的功能與實(shí)戰(zhàn)應(yīng)用 在用 Python(如 pandas 庫)處理 Excel 數(shù)據(jù)時(shí),“缺失值” 是高頻 ...
2025-09-16深入解析卡方檢驗(yàn)與 t 檢驗(yàn):差異、適用場景與實(shí)踐應(yīng)用 在數(shù)據(jù)分析與統(tǒng)計(jì)學(xué)領(lǐng)域,假設(shè)檢驗(yàn)是驗(yàn)證研究假設(shè)、判斷數(shù)據(jù)差異是否 “ ...
2025-09-16CDA 數(shù)據(jù)分析師:掌控表格結(jié)構(gòu)數(shù)據(jù)全功能周期的專業(yè)操盤手 表格結(jié)構(gòu)數(shù)據(jù)(以 “行 - 列” 存儲的結(jié)構(gòu)化數(shù)據(jù),如 Excel 表、數(shù)據(jù) ...
2025-09-16MySQL 執(zhí)行計(jì)劃中 rows 數(shù)量的準(zhǔn)確性解析:原理、影響因素與優(yōu)化 在 MySQL SQL 調(diào)優(yōu)中,EXPLAIN執(zhí)行計(jì)劃是核心工具,而其中的row ...
2025-09-15解析 Python 中 Response 對象的 text 與 content:區(qū)別、場景與實(shí)踐指南 在 Python 進(jìn)行 HTTP 網(wǎng)絡(luò)請求開發(fā)時(shí)(如使用requests ...
2025-09-15CDA 數(shù)據(jù)分析師:激活表格結(jié)構(gòu)數(shù)據(jù)價(jià)值的核心操盤手 表格結(jié)構(gòu)數(shù)據(jù)(如 Excel 表格、數(shù)據(jù)庫表)是企業(yè)最基礎(chǔ)、最核心的數(shù)據(jù)形態(tài) ...
2025-09-15Python HTTP 請求工具對比:urllib.request 與 requests 的核心差異與選擇指南 在 Python 處理 HTTP 請求(如接口調(diào)用、數(shù)據(jù)爬取 ...
2025-09-12解決 pd.read_csv 讀取長浮點(diǎn)數(shù)據(jù)的科學(xué)計(jì)數(shù)法問題 為幫助 Python 數(shù)據(jù)從業(yè)者解決pd.read_csv讀取長浮點(diǎn)數(shù)據(jù)時(shí)的科學(xué)計(jì)數(shù)法問題 ...
2025-09-12CDA 數(shù)據(jù)分析師:業(yè)務(wù)數(shù)據(jù)分析步驟的落地者與價(jià)值優(yōu)化者 業(yè)務(wù)數(shù)據(jù)分析是企業(yè)解決日常運(yùn)營問題、提升執(zhí)行效率的核心手段,其價(jià)值 ...
2025-09-12用 SQL 驗(yàn)證業(yè)務(wù)邏輯:從規(guī)則拆解到數(shù)據(jù)把關(guān)的實(shí)戰(zhàn)指南 在業(yè)務(wù)系統(tǒng)落地過程中,“業(yè)務(wù)邏輯” 是連接 “需求設(shè)計(jì)” 與 “用戶體驗(yàn) ...
2025-09-11塔吉特百貨孕婦營銷案例:數(shù)據(jù)驅(qū)動(dòng)下的精準(zhǔn)零售革命與啟示 在零售行業(yè) “流量紅利見頂” 的當(dāng)下,精準(zhǔn)營銷成為企業(yè)突圍的核心方 ...
2025-09-11CDA 數(shù)據(jù)分析師與戰(zhàn)略 / 業(yè)務(wù)數(shù)據(jù)分析:概念辨析與協(xié)同價(jià)值 在數(shù)據(jù)驅(qū)動(dòng)決策的體系中,“戰(zhàn)略數(shù)據(jù)分析”“業(yè)務(wù)數(shù)據(jù)分析” 是企業(yè) ...
2025-09-11Excel 數(shù)據(jù)聚類分析:從操作實(shí)踐到業(yè)務(wù)價(jià)值挖掘 在數(shù)據(jù)分析場景中,聚類分析作為 “無監(jiān)督分組” 的核心工具,能從雜亂數(shù)據(jù)中挖 ...
2025-09-10統(tǒng)計(jì)模型的核心目的:從數(shù)據(jù)解讀到?jīng)Q策支撐的價(jià)值導(dǎo)向 統(tǒng)計(jì)模型作為數(shù)據(jù)分析的核心工具,并非簡單的 “公式堆砌”,而是圍繞特定 ...
2025-09-10