You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
22 lines
664 B
Python
22 lines
664 B
Python
9 months ago
|
import paddleslim
|
||
|
import paddle
|
||
|
import numpy as np
|
||
|
|
||
|
from paddleslim.dygraph import FPGMFilterPruner
|
||
|
|
||
|
|
||
|
def prune_model(model, input_shape, prune_ratio=0.1):
    """Run an FPGM filter pruner over ``model`` and return the model.

    Every parameter whose name contains neither "transpose" nor "linear"
    is registered for pruning with the ratio ``prune_ratio``; all other
    parameters are left out of the pruning request.

    Args:
        model: Network passed to ``paddle.flops`` and ``FPGMFilterPruner``.
        input_shape: Input shape forwarded to both of those calls.
        prune_ratio: Ratio assigned to each selected parameter. The larger
            the value, the more convolution weights will be cropped.

    Returns:
        The same ``model`` object that was passed in.
        NOTE(review): presumably the pruner modifies it in place -- confirm
        against the PaddleSlim documentation.
    """
    # FLOPs query on the model before pruning; the returned value is unused.
    paddle.flops(model, input_shape)

    fpgm_pruner = FPGMFilterPruner(model, input_shape)

    # Map parameter name -> prune ratio, skipping any parameter whose name
    # mentions "transpose" or "linear".
    ratio_by_param = {
        weight.name: prune_ratio
        for weight in model.parameters()
        if not ("transpose" in weight.name or "linear" in weight.name)
    }

    # The second argument is passed through as-is; presumably it is the list
    # of axes to prune along -- TODO confirm. The returned plan is unused.
    fpgm_pruner.prune_vars(ratio_by_param, [0])

    # Same FLOPs query again, now after the pruning call.
    paddle.flops(model, input_shape)
    return model
|