99999久久久久久亚洲,欧美人与禽猛交狂配,高清日韩av在线影院,一个人在线高清免费观看,啦啦啦在线视频免费观看www

熱線電話:13121318867

登錄
首頁大數(shù)據(jù)時代在tensorFlow中使用模型剪枝將機器學(xué)習(xí)模型變得更小
在tensorFlow中使用模型剪枝將機器學(xué)習(xí)模型變得更小
2020-08-10
收藏

文章來源:DeepHub IMBA

作者: P**nHub兄弟網(wǎng)站

學(xué)習(xí)如何通過剪枝來使你的模型變得更小

剪枝是一種模型優(yōu)化技術(shù),這種技術(shù)可以消除權(quán)重張量中不必要的值。這將會得到更小的模型,并且模型精度非常接近標(biāo)準(zhǔn)模型。

在本文中,我們將通過一個例子來觀察剪枝技術(shù)對最終模型大小和預(yù)測誤差的影響。

導(dǎo)入常見問題

我們的第一步導(dǎo)入一些工具、包:

  • Os和Zi    pfile可以幫助我們評估模型的大小。
  • tensorflow_model_optimization用來修剪模型。
  • load_model用于加載保存的模型。
  • 當(dāng)然還有tensorflow和keras。

最后,初始化TensorBoard,這樣就可以將模型可視化:

import os
import zipfile
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.models import load_model
from tensorflow import keras
%load_ext tensorboard

數(shù)據(jù)集生成

在這個實驗中,我們將使用scikit-learn生成一個回歸數(shù)據(jù)集。之后,我們將數(shù)據(jù)集分解為訓(xùn)練集和測試集:

from sklearn.datasets import make_friedman1
X, y = make_friedman1(n_samples=10000, n_features=10, random_state=0)

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

沒有應(yīng)用剪枝技術(shù)的模型

我們將創(chuàng)建一個簡單的神經(jīng)網(wǎng)絡(luò)來預(yù)測目標(biāo)變量y,然后檢查均值平方誤差。在此之后,我們將把它與修剪過的整個模型進(jìn)行比較,然后只與修剪過的Dense層進(jìn)行比較。

接下來,在30個訓(xùn)練輪次之后,一旦模型停止改進(jìn),我們就使用回調(diào)來停止訓(xùn)練它。

early_stop = keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=30)

我們打印出模型概述,以便與運用剪枝技術(shù)的模型概述進(jìn)行比較。

model = setup_model()

model.summary()

讓我們編譯模型并訓(xùn)練它。

tf.keras.utils.plot_model(
model,
to_file=”model.png”,
show_shapes=True,
show_layer_names=True,
rankdir=”TB”,
expand_nested=True,
dpi=96,
)

現(xiàn)在檢查一下均方誤差。我們可以繼續(xù)到下一節(jié),看看當(dāng)我們修剪整個模型時,這個誤差是如何變化的。

from sklearn.metrics import mean_squared_error

predictions = model.predict(X_test)

print(‘Without Pruning MSE %.4f’ %
mean_squared_error(y_test,predictions.reshape(3300,)))

Without Pruning MSE 0.0201

當(dāng)把模型部署到資源受限的邊緣設(shè)備(如手機)時,剪枝等優(yōu)化模型技術(shù)尤其重要。

采用等稀疏修剪對整個模型進(jìn)行剪枝

我們將上面的MSE與修剪整個模型得到的MSE進(jìn)行比較。第一步是定義剪枝參數(shù)。權(quán)重剪枝是基于數(shù)量級的。這意味著在訓(xùn)練過程中一些權(quán)重被轉(zhuǎn)換為零。模型變得稀疏,這樣就更容易壓縮。由于可以跳過零,稀疏模型還可以加快推理速度。

預(yù)期的參數(shù)是剪枝計劃、塊大小和塊池類型。

  • 在本例中,我們設(shè)置了50%的稀疏度,這意味著50%的權(quán)重將歸零。
  • block_size —— 矩陣權(quán)重張量中塊稀疏模式的維度(高度,權(quán)值)。
  • block_pooling_type —— 用于對塊中的權(quán)重進(jìn)行池化的函數(shù)。必須是AVG或MAX。

   

from tensorflow_model_optimization.sparsity.keras import ConstantSparsity
pruning_params = {
  'pruning_schedule': ConstantSparsity(0.5, 0),
  'block_size': (1, 1),
  'block_pooling_type': 'AVG'
}

現(xiàn)在,我們可以應(yīng)用我們的剪枝參數(shù)來修剪整個模型。

from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude
model_to_prune = prune_low_magnitude(
keras.Sequential([
  tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)),
  tf.keras.layers.Dense(1, activation='relu')
]), **pruning_params)

我們檢查模型概述。將其與未剪枝模型的模型進(jìn)行比較。從下圖中我們可以看到整個模型已經(jīng)被剪枝 —— 我們將很快看到剪枝一個稠密層后模型概述的區(qū)別。

model_to_prune.summary()

在TF中,我們必須先編譯模型,然后才能將其用于訓(xùn)練集和測試集。

model_to_prune.compile(optimizer=’adam’,
loss=tf.keras.losses.mean_squared_error,
metrics=[‘mae’, ‘mse’])

由于我們正在使用剪枝技術(shù),所以除了早期停止回調(diào)函數(shù)之外,我們還必須定義兩個剪枝回調(diào)函數(shù)。我們定義一個記錄模型的文件夾,然后創(chuàng)建一個帶有回調(diào)函數(shù)的列表。

tfmot.sparsity.keras.UpdatePruningStep()

使用優(yōu)化器步驟更新剪枝包裝器。如果未能指定剪枝包裝器,將會導(dǎo)致錯誤。

tfmot.sparsity.keras.PruningSummaries()

將剪枝概述添加到Tensorboard。

log_dir = ‘.models’
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep(),
# Log sparsity and other metrics in Tensorboard.
tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir),
keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=10)
]

有了這些,我們現(xiàn)在就可以將模型與訓(xùn)練集相匹配了。

model_to_prune.fit(X_train,y_train,epochs=100,validation_split=0.2,callbacks=callbacks,verbose=0)

在檢查這個模型的均方誤差時,我們注意到它比未剪枝模型的均方誤差略高。

prune_predictions = model_to_prune.predict(X_test)

print(‘Whole Model Pruned MSE %.4f’ %

mean_squared_error(y_test,prune_predictions.reshape(3300,)))

Whole Model Pruned MSE 0.1830

用多項式剪枝計劃對稠密層進(jìn)行剪枝

現(xiàn)在讓我們實現(xiàn)相同的模型,但這一次,我們將只剪枝稠密層。請注意在剪枝計劃中使用多項式衰退函數(shù)。

from tensorflow_model_optimization.sparsity.keras import PolynomialDecay
layer_pruning_params = {
  'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
  final_sparsity=0.8, begin_step=1000, end_step=2000),
  'block_size': (2, 3),
  'block_pooling_type': 'MAX'
}

model_layer_prunning = keras.Sequential([
prune_low_magnitude(tf.keras.layers.Dense(128, activation='relu',input_shape=(X_train.shape[1],)),
**layer_pruning_params),
    tf.keras.layers.Dense(1, activation='relu')
  ])

從概述中我們可以看到只有第一個稠密層將被剪枝。

model_layer_prunning.summary()

然后我們編譯并擬合模型。

model_layer_prunning.compile(optimizer=’adam’,
loss=tf.keras.losses.mean_squared_error,
metrics=[‘mae’, ‘mse’])
model_layer_prunning.fit(X_train,y_train,epochs=300,validation_split=0.1,callbacks=callbacks,verbose=0)

現(xiàn)在,讓我們檢查均方誤差。

layer_prune_predictions = model_layer_prunning.predict(X_test)

print(‘Layer Prunned MSE %.4f’ %
mean_squared_error(y_test,layer_prune_predictions.reshape(3300,)))

Layer Prunned MSE 0.1388

由于我們使用了不同的剪枝參數(shù),所以我們無法將這里獲得的MSE與之前的MSE進(jìn)行比較。如果您想比較它們,那么請確保剪枝參數(shù)是相同的。在測試時,對于這個特定情況,layer_pruning_params給出的錯誤比pruning_params要低。比較從不同的剪枝參數(shù)獲得的MSE是有用的,這樣你就可以選擇一個不會使模型性能變差的MSE。

比較模型大小

現(xiàn)在讓我們比較一下有剪枝和沒有剪枝模型的大小。我們從訓(xùn)練和保存模型權(quán)重開始,以便以后使用。

def train_save_weights():
  model = setup_model()
  model.compile(optimizer='adam',
    loss=tf.keras.losses.mean_squared_error,
    metrics=['mae', 'mse'])
  model.fit(X_train,y_train,epochs=300,validation_split=0.2,callbacks=callbacks,verbose=0)
  model.save_weights('.models/friedman_model_weights.h5')
 
train_save_weights()

我們將建立我們的基礎(chǔ)模型,并加載保存的權(quán)重。然后我們對整個模型進(jìn)行剪枝。我們編譯、擬合模型,并在Tensorboard上將結(jié)果可視化。

base_model = setup_model()
base_model.load_weights('.models/friedman_model_weights.h5') # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

model_for_pruning.compile(
  loss=tf.keras.losses.mean_squared_error,
  optimizer='adam',
  metrics=['mae', 'mse']
)

model_for_pruning.fit(
  X_train,
  y_train,
  callbacks=callbacks,
  epochs=300,
  validation_split = 0.2,
  verbose=0
)

%tensorboard --logdir={log_dir}

以下是TensorBoard的剪枝概述的快照。

在TensorBoard上也可以看到其它剪枝模型概述

現(xiàn)在讓我們定義一個計算模型大小函數(shù)

def get_gzipped_model_size(model,mode_name,zip_name):
  # Returns size of gzipped model, in bytes.
   
  model.save(mode_name, include_optimizer=False)
   
  with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as f:
  f.write(mode_name)
   
  return os.path.getsize(zip_name)

現(xiàn)在我們定義導(dǎo)出模型,然后計算大小。

對于剪枝過的模型,tfmot.sparsity.keras.strip_pruning()用來恢復(fù)帶有稀疏權(quán)重的原始模型。請注意剝離模型和未剝離模型在尺寸上的差異。

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning,'.models/model_for_pruning.h5','.models/model_for_pruning.zip')))

print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export,'.models/model_for_export.h5','.models/model_for_export.zip')))
Size of gzipped pruned model without stripping: 6101.00 bytes
Size of gzipped pruned model with stripping: 5140.00 bytes

對這兩個模型進(jìn)行預(yù)測,我們發(fā)現(xiàn)它們具有相同的均方誤差。

model_for_prunning_predictions = model_for_pruning.predict(X_test)
print('Model for Prunning Error %.4f' % mean_squared_error(y_test,model_for_prunning_predictions.reshape(3300,)))
model_for_export_predictions = model_for_export.predict(X_test)
print('Model for Export Error %.4f' % mean_squared_error(y_test,model_for_export_predictions.reshape(3300,)))
Model for Prunning Error 0.0264
Model for Export Error 0.0264

最終想法

您可以繼續(xù)測試不同的剪枝計劃如何影響模型的大小。顯然這里的觀察結(jié)果不具有普遍性。也可以嘗試不同的剪枝參數(shù),并了解它們?nèi)绾斡绊懩哪P痛笮?、預(yù)測誤差/精度,這將取決于您要解決的問題。

為了進(jìn)一步優(yōu)化模型,您可以將其量化。如果您想了解更多,請查看下面的回購和參考資料。

作者:Derrick Mwiti

deephub翻譯組:錢三一

數(shù)據(jù)分析咨詢請掃描二維碼

若不方便掃碼,搜微信號:CDAshujufenxi

數(shù)據(jù)分析師資訊
更多

OK
客服在線
立即咨詢
客服在線
立即咨詢
') } function initGt() { var handler = function (captchaObj) { captchaObj.appendTo('#captcha'); captchaObj.onReady(function () { $("#wait").hide(); }).onSuccess(function(){ $('.getcheckcode').removeClass('dis'); $('.getcheckcode').trigger('click'); }); window.captchaObj = captchaObj; }; $('#captcha').show(); $.ajax({ url: "/login/gtstart?t=" + (new Date()).getTime(), // 加隨機數(shù)防止緩存 type: "get", dataType: "json", success: function (data) { $('#text').hide(); $('#wait').show(); // 調(diào)用 initGeetest 進(jìn)行初始化 // 參數(shù)1:配置參數(shù) // 參數(shù)2:回調(diào),回調(diào)的第一個參數(shù)驗證碼對象,之后可以使用它調(diào)用相應(yīng)的接口 initGeetest({ // 以下 4 個配置參數(shù)為必須,不能缺少 gt: data.gt, challenge: data.challenge, offline: !data.success, // 表示用戶后臺檢測極驗服務(wù)器是否宕機 new_captcha: data.new_captcha, // 用于宕機時表示是新驗證碼的宕機 product: "float", // 產(chǎn)品形式,包括:float,popup width: "280px", https: true // 更多配置參數(shù)說明請參見:http://docs.geetest.com/install/client/web-front/ }, handler); } }); } function codeCutdown() { if(_wait == 0){ //倒計時完成 $(".getcheckcode").removeClass('dis').html("重新獲取"); }else{ $(".getcheckcode").addClass('dis').html("重新獲取("+_wait+"s)"); _wait--; setTimeout(function () { codeCutdown(); },1000); } } function inputValidate(ele,telInput) { var oInput = ele; var inputVal = oInput.val(); var oType = ele.attr('data-type'); var oEtag = $('#etag').val(); var oErr = oInput.closest('.form_box').next('.err_txt'); var empTxt = '請輸入'+oInput.attr('placeholder')+'!'; var errTxt = '請輸入正確的'+oInput.attr('placeholder')+'!'; var pattern; if(inputVal==""){ if(!telInput){ errFun(oErr,empTxt); } return false; }else { switch (oType){ case 'login_mobile': pattern = /^1[3456789]\d{9}$/; if(inputVal.length==11) { $.ajax({ url: '/login/checkmobile', type: "post", dataType: "json", data: { mobile: inputVal, etag: oEtag, page_ur: window.location.href, page_referer: document.referrer }, success: function (data) { } }); } break; case 'login_yzm': pattern = /^\d{6}$/; break; } if(oType=='login_mobile'){ } if(!!validateFun(pattern,inputVal)){ errFun(oErr,'') if(telInput){ $('.getcheckcode').removeClass('dis'); } }else { if(!telInput) { errFun(oErr, errTxt); }else { $('.getcheckcode').addClass('dis'); } return false; } } return true; } function errFun(obj,msg) { obj.html(msg); if(msg==''){ $('.login_submit').removeClass('dis'); }else { $('.login_submit').addClass('dis'); } } function validateFun(pat,val) { return pat.test(val); }