
1. Introduction

This chapter specifies the steps required to implement the pruning method that enables a User to produce an Up-sampling network with reduced complexity, starting from a Fine-Tuned Model produced with the procedure specified in Design Procedure.

The Pruning Algorithm, specified below in natural language and in pseudo-code, is applied to the Fine-Tuned or Pre-trained Model (hereafter the “Model”).

The Pruning Algorithm also requires a Pruning Dataset with a training and a validation subset; it can be the dataset used in the Design Procedure.

2. Pruning Algorithm in natural language

  1. Set the value of the target Performance Criterion, i.e. the maximum allowed ratio between the MSE of the Pruned Model and the MSE of the Model.
  2. Set the Pruning Step Size as the percentage of parameters to be removed at each pruning iteration.
  3. Compute the Dependency Graph and Pruning Groups.
  4. Starting from the non-pruned Model, execute up to N pruning iterations (Steps), stopping when the Pruned Model no longer satisfies the Performance Criterion or exceeds the Maximum Pruning Ratio. Each Step consists of:
    1. Increase the Iteration Pruning Target by the Pruning Step Size.
    2. Apply Sparsity Learning to the Model.
    3. Evaluate the Importance of each Channel in each Layer of the Model.
    4. Remove Channels in increasing order of Importance until the Iteration Pruning Target is met.
    5. For a predefined number of epochs E:
      1. Train the pruned Model over the training dataset for 1 epoch.
      2. Evaluate the MSE of the retrained Model on the validation dataset.
      3. If this MSE is the smallest achieved so far, save the current Model.
    6. Select the Model with the smallest MSE.
  5. Save the current Pruned Model.
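
Steps 3, 4.3 and 4.4 can be illustrated on a toy network. The sketch below assumes a plain chain of fully connected layers, so the Dependency Graph reduces to a single rule: output Channel k of layer i and input Channel k of layer i+1 form one Pruning Group and are removed together. It further assumes, following the Norm2 criterion named in the pseudo-code, that a Channel's Importance is the L2 norm of its weights, that Importance is evaluated once per iteration, and that the output Channels of the last layer are never pruned. The function name mirrors `channel_prune_norm2` from the pseudo-code, but its simplified signature and body are hypothetical, not the Reference Software.

```python
import math
from typing import List

Matrix = List[List[float]]  # weights[out_channel][in_channel]


def count_params(layers: List[Matrix]) -> int:
    """Number of weights in a chain of fully connected layers (no biases)."""
    return sum(len(w) * len(w[0]) for w in layers)


def channel_prune_norm2(layers: List[Matrix], target: float) -> List[Matrix]:
    """Hypothetical sketch: remove hidden Channels in increasing order of L2 norm
    until at least `target` (a fraction) of the original weights are removed."""
    keep = [[True] * len(w) for w in layers]  # keep[i][k]: output Channel k of layer i

    def remaining() -> int:
        # Removing output Channel k of layer i also removes input k of layer i+1.
        total, in_dim = 0, len(layers[0][0])
        for i in range(len(layers)):
            out_dim = sum(keep[i])
            total += out_dim * in_dim
            in_dim = out_dim
        return total

    original = remaining()
    # Importance is evaluated once on the unpruned weights; the last layer's
    # output Channels are excluded from the ranking.
    ranking = sorted(
        (math.sqrt(sum(x * x for x in layers[i][k])), i, k)
        for i in range(len(layers) - 1)
        for k in range(len(layers[i]))
    )
    for _, i, k in ranking:
        if 1 - remaining() / original >= target:
            break
        if sum(keep[i]) > 1:  # never empty a layer completely
            keep[i][k] = False

    # Rebuild the weight matrices, dropping each pruned group on both sides.
    pruned, in_keep = [], [True] * len(layers[0][0])
    for i, w in enumerate(layers):
        pruned.append([
            [x for x, kept in zip(w[k], in_keep) if kept]
            for k in range(len(w)) if keep[i][k]
        ])
        in_keep = keep[i]
    return pruned


layers = [
    [[1.0, 0.0], [0.1, 0.1], [2.0, 2.0], [0.0, 0.5]],  # layer 0: 2 -> 4 Channels
    [[1.0, 1.0, 1.0, 1.0], [0.5, 0.5, 0.5, 0.5]],      # layer 1: 4 -> 2 Channels
]
pruned = channel_prune_norm2(layers, target=0.4)
print(count_params(layers), count_params(pruned))  # → 16 8
```

With a target of 0.4, the two hidden Channels with the smallest norms (0.14 and 0.5) are removed together with the matching input columns of layer 1, halving the weight count. Real up-sampling networks (convolutions, residual connections, pixel-shuffle layers) need a genuine dependency graph, which this sketch does not model.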

3. Pruning Algorithm in pseudo-code

# Dependency Graph and Pruning Groups computation
graph = dependency_graph(MODEL)

# reference error of the non-pruned Model
original_error = calc_error_mse(MODEL, DATASET.val)

model = copy(MODEL)
model_error = original_error
current_pruning_target = 0

# starting pruning process: continue while the error satisfies the Performance
# Criterion and the Maximum Pruning Ratio has not been exceeded
while (model_error <= original_error * PERFORMANCE_CRITERION) and (1 - model.param / MODEL.param < MAX_PRUNING_RATIO):
    current_pruning_target = current_pruning_target + PRUNING_STEP_SIZE

    # Sparsity Learning
    model = sparsity_pruning(model, DATASET.train, GROWING_REG_EPOCHS)

    # remove the channels which have the lowest Norm2 (L2-norm) value
    model = channel_prune_norm2(model, current_pruning_target, graph)

    # retrain the pruned model and keep the one with the best validation error
    best_model = copy(model)
    best_error = calc_error_mse(model, DATASET.val)
    for e in range(RETRAIN_EPOCHS):
        model = train_one_epoch(model, DATASET.train)
        curr_error = calc_error_mse(model, DATASET.val)
        if curr_error < best_error:
            best_model = copy(model)
            best_error = curr_error
    model = best_model
    model_error = best_error

# Output
model_output = model
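
The control flow of this loop can be exercised without a neural network by passing each operation as a callback. The sketch below is a hypothetical driver, `iterative_prune`, in which the model is an opaque object; the toy "model" (a dictionary holding the kept parameter fraction and an accumulated pruning damage) and its callbacks are invented purely so the example runs. As a design choice, and unlike the pseudo-code, which outputs the Model of the last iteration, this sketch returns the last Model that still met the Performance Criterion.

```python
import copy
from typing import Callable, TypeVar

M = TypeVar("M")


def iterative_prune(
    model: M,
    evaluate_mse: Callable[[M], float],  # validation MSE of a model
    sparsify: Callable[[M], M],          # Sparsity Learning (Step 4.2)
    prune_to: Callable[[M, float], M],   # Channel removal up to a target (Steps 4.3-4.4)
    train_one_epoch: Callable[[M], M],   # one retraining epoch (Step 4.5.1)
    step_size: float,
    performance_criterion: float,
    max_pruning_ratio: float,
    retrain_epochs: int,
    max_steps: int,
) -> M:
    original_mse = evaluate_mse(model)
    accepted = model  # last model that met the Performance Criterion
    target = 0.0      # Iteration Pruning Target (fraction of parameters removed)

    for _ in range(max_steps):
        target += step_size                            # Step 4.1
        if target > max_pruning_ratio + 1e-12:         # Maximum Pruning Ratio reached
            break
        candidate = sparsify(copy.deepcopy(accepted))  # Step 4.2
        candidate = prune_to(candidate, target)        # Steps 4.3-4.4

        # Steps 4.5-4.6: retrain and keep the snapshot with the lowest MSE.
        best, best_mse = copy.deepcopy(candidate), evaluate_mse(candidate)
        for _ in range(retrain_epochs):
            candidate = train_one_epoch(candidate)
            mse = evaluate_mse(candidate)
            if mse < best_mse:
                best, best_mse = copy.deepcopy(candidate), mse

        # Stop as soon as the retrained model violates the Performance Criterion.
        if best_mse > original_mse * performance_criterion:
            break
        accepted = best

    return accepted


# Hypothetical toy model: pruning adds damage, each retraining epoch halves it.
def toy_prune_to(m, target):
    new_kept = 1.0 - target
    m["damage"] += 4.0 * (m["kept"] - new_kept)
    m["kept"] = new_kept
    return m


def toy_train(m):
    m["damage"] *= 0.5
    return m


result = iterative_prune(
    {"kept": 1.0, "damage": 0.0},
    evaluate_mse=lambda m: 1.0 + m["damage"],
    sparsify=lambda m: m,
    prune_to=toy_prune_to,
    train_one_epoch=toy_train,
    step_size=0.1,
    performance_criterion=1.25,
    max_pruning_ratio=0.5,
    retrain_epochs=1,
    max_steps=10,
)
print(round(result["kept"], 2))  # → 0.9
```

In the toy run, the first iteration ends with an MSE of 1.2 (within 1.25 times the original 1.0) and is accepted; the second ends at 1.3 and stops the loop, so the returned model keeps 90% of its parameters.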

The Reference Software of the Pruning Algorithm will be available at the MPAI Git. 
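
The Sparsity Learning step is named in this chapter but not defined; the pseudo-code only passes it a `GROWING_REG_EPOCHS` parameter. One common reading of that name is a growing regularisation scheme, in which an L2 penalty on the Channels selected for removal is increased epoch after epoch so that their weights are driven towards zero before they are cut. The sketch below assumes exactly that and, to stay self-contained, replaces the real training loss with a hypothetical quadratic term that pulls each weight back to its fine-tuned value. The function `growing_reg_sparsify` and its `delta`, `lr` and `steps_per_epoch` values are illustrative assumptions, not parameters of this specification.

```python
from typing import List, Sequence


def growing_reg_sparsify(
    channels: List[List[float]],  # one weight vector per Channel
    penalised: Sequence[int],     # Channels selected for removal
    epochs: int,                  # plays the role of GROWING_REG_EPOCHS
    delta: float = 0.5,           # hypothetical penalty increment per epoch
    lr: float = 0.1,              # must satisfy lr * (1 + lam) < 1 to converge
    steps_per_epoch: int = 50,
) -> List[List[float]]:
    """Gradient descent on sum_c ||w_c - w0_c||^2 + lam * sum_{c in penalised} ||w_c||^2,
    with lam increased by `delta` at the start of every epoch."""
    target = set(penalised)
    anchor = [w[:] for w in channels]  # stand-in for the optimum of the data term
    weights = [w[:] for w in channels]
    lam = 0.0
    for _ in range(epochs):
        lam += delta  # the penalty grows every epoch
        for _ in range(steps_per_epoch):
            for c, w in enumerate(weights):
                reg = lam if c in target else 0.0
                for j in range(len(w)):
                    grad = 2 * (w[j] - anchor[c][j]) + 2 * reg * w[j]
                    w[j] -= lr * grad
    return weights


w = growing_reg_sparsify([[1.0, 1.0], [1.0, 1.0]], penalised=[0], epochs=10)
print([round(x, 3) for x in w[0]], w[1])  # → [0.167, 0.167] [1.0, 1.0]
```

For a penalised weight the minimum of this toy objective is w0 / (1 + lam), so after ten epochs (lam = 5) the selected Channel has shrunk to one sixth of its value while the unpenalised Channel is untouched, which makes its subsequent removal by the Norm2 criterion cheap.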
