1. Introduction
This chapter specifies the steps required to implement the pruning method that enables a User to produce an Up-sampling network with reduced complexity, starting from a fine-tuned model produced using the procedure specified in Design Procedure.
The Pruning Algorithm, specified below in natural language and in pseudo-code, is applied to the Fine-Tuned or Pre-trained Model (hereafter the “Model”).
The Pruning Algorithm also requires a Pruning Dataset; this can be the one used in the Design Procedure.
2. Pruning Algorithm in natural language
- Set the value of the target Performance Criterion.
- Set the Pruning Step Size, i.e., the percentage of parameters to be removed at each pruning iteration.
- Compute the Dependency Graph and the Pruning Groups.
- Starting from the non-pruned network, execute Steps 1 to N (N being the maximum number of Steps) until the Pruned Model satisfies the Performance Criterion or exceeds the Maximum Pruning Ratio. At each Step:
  - Set the Iteration Pruning Target by adding the Pruning Step Size to the previous target.
  - Apply Sparsity Learning to the Model.
  - Evaluate the Importance of each Channel in each Layer of the Model.
  - Remove Channels, starting from the one with the lowest Importance, until the Iteration Pruning Target is met.
  - For a predefined number of epochs E:
    - Train the pruned Model on the training dataset for 1 epoch.
    - Evaluate the MSE of the retrained Model.
    - If the MSE is the smallest achieved so far, save the current Model.
  - Select the Model with the smallest MSE.
- Save the current Pruned Model.
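This chapter does not detail how the Pruning Groups are derived from the Dependency Graph. As an illustration only, the sketch below assumes that every coupling between channel dimensions is given as a pair of hypothetical node names such as "conv1.out" and "conv2.in" (an output feeding the next input, or two branches summed by a residual connection), and computes the Pruning Groups as the connected components of that graph with a union-find structure. The function name `pruning_groups` and the node naming are assumptions, not part of the specification.

```python
from collections import defaultdict


def pruning_groups(couplings):
    """Return the Pruning Groups as sorted lists of node names.

    couplings: iterable of (node_a, node_b) pairs whose channel dimensions
    must stay equal, so a channel removed from one must also be removed
    from the other (hypothetical naming such as "conv1.out", "conv2.in").
    """
    parent = {}

    def find(node):
        # Register unseen nodes, then walk to the root with path halving.
        parent.setdefault(node, node)
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    # Merge the components of every coupled pair.
    for a, b in couplings:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Collect nodes by root; sort for a deterministic result.
    groups = defaultdict(list)
    for node in list(parent):
        groups[find(node)].append(node)
    return sorted(sorted(group) for group in groups.values())
```

For a residual block in which the output of conv1 is added to the output of conv3 before an up-sampling layer "up", the couplings ("conv1.out", "conv2.in"), ("conv2.out", "conv3.in"), ("conv3.out", "up.in") and ("conv1.out", "conv3.out") yield one group containing conv1.out, conv2.in, conv3.out and up.in, and a separate group containing conv2.out and conv3.in.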
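The Importance evaluation and Channel removal steps can be sketched as follows, assuming (as the name `channel_prune_norm2` in the pseudo-code suggests) that the Importance of a Channel is the L2 norm of its weights and that the Iteration Pruning Target is a fraction of the total number of parameters. The weight layout (a dict mapping a hypothetical layer name to a list of per-channel flattened weight lists), the function names and the rule of keeping at least one Channel per Layer are assumptions; for brevity the sketch ranks Channels individually rather than by Pruning Group.

```python
import math


def channel_importance(weights):
    """Importance of each (layer, channel) as the L2 norm of its weights."""
    return {
        (layer, c): math.sqrt(sum(w * w for w in channel))
        for layer, channels in weights.items()
        for c, channel in enumerate(channels)
    }


def select_channels_to_prune(weights, pruning_target):
    """Pick Channels, lowest Importance first, until the fraction of removed
    parameters reaches pruning_target (0.0 to 1.0).

    At least one Channel is always kept in every Layer.
    """
    total = sum(len(ch) for channels in weights.values() for ch in channels)
    remaining = {layer: len(channels) for layer, channels in weights.items()}
    importance = channel_importance(weights)
    removed, removed_params = [], 0
    # Sort by Importance; ties are broken by (layer, channel) for determinism.
    for (layer, c), _ in sorted(importance.items(), key=lambda kv: (kv[1], kv[0])):
        if removed_params >= pruning_target * total:
            break  # Iteration Pruning Target met
        if remaining[layer] == 1:
            continue  # never empty a Layer
        removed.append((layer, c))
        remaining[layer] -= 1
        removed_params += len(weights[layer][c])
    return removed
```

Once the target cannot be reached without emptying a Layer, the function simply returns the Channels removed so far, which is one way the Maximum Pruning Ratio can be enforced at the Layer level.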
3. Pruning Algorithm in pseudo-code
```python
import copy

# Compute the dependency graph
graph = dependency_graph(MODEL)
original_error = calc_error_mse(MODEL, DATASET.val)

# Start from the non-pruned network
model = copy.deepcopy(MODEL)
current_error = original_error
current_pruning_target = 0

# Pruning process
while (current_error > original_error * PERFORMANCE_CRITERION) or \
        (model.param / MODEL.param > MAX_PRUNING_RATE):
    current_pruning_target = current_pruning_target + PRUNING_STEP_SIZE

    # Sparsity Learning
    model = sparsity_pruning(model, DATASET.train, GROWING_REG_EPOCHS)

    # Remove the channels whose weights have the lowest L2-norm value
    model = channel_prune_norm2(model, current_pruning_target, graph)
    best_error = calc_error_mse(model, DATASET.val)

    # Retrain the pruned model, keeping the version with the best validation error
    model_tmp = copy.deepcopy(model)
    for _ in range(RETRAIN_EPOCHS):
        model_tmp = train_one_epoch(model_tmp, DATASET.train)
        curr_error = calc_error_mse(model_tmp, DATASET.val)
        if curr_error < best_error:
            model = copy.deepcopy(model_tmp)
            best_error = curr_error
    current_error = best_error

# Output
model_output = model
```
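The retraining loop keeps the snapshot with the smallest validation MSE, so the comparison must select a lower error, and each saved snapshot must be a copy rather than a reference to the model that is still being trained. A minimal framework-agnostic sketch of that selection, assuming hypothetical `train_one_epoch` and `val_error` callables supplied by the caller (they stand in for the training and `calc_error_mse` routines of the Reference Software):

```python
import copy


def retrain_keep_best(model, train_one_epoch, val_error, epochs):
    """Retrain a pruned model for `epochs` epochs and return
    (best_model, best_error), where best_model is the snapshot with the
    lowest validation error, including the model before retraining."""
    best_model = copy.deepcopy(model)
    best_error = val_error(best_model)
    current = copy.deepcopy(model)  # train a copy; the input is untouched
    for _ in range(epochs):
        current = train_one_epoch(current)
        err = val_error(current)
        if err < best_error:  # lower MSE is better
            best_model = copy.deepcopy(current)  # snapshot, not a reference
            best_error = err
    return best_model, best_error
```

Starting `best_error` from the error of the pruned but not yet retrained model means the function never returns a model that is worse, on the validation set, than the one it received.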
The Reference Software of the Pruning Algorithm will be available at the MPAI Git.