神经网络剪枝初探


剪枝

image-20231030111206019

如上图,我们发现函数的曲线其实已经有点过拟合了,但是如果我们去掉一些参数很小的项,那么它的泛化性就更好:

image-20231030111316278

即剪枝的目的:去掉一些不重要的参数,让我们的模型变得更小,计算效率更高。

如何剪枝

image-20231030111543398

例如我们要将前一个网络剪枝成后一个网络,有以下步骤:

  1. 用矩阵表示系数,例如上图中最后一层全连接层,输入是4个,输出是3个,那么我们就可以用一个4 * 3的矩阵进行表示。
  2. 得到系数矩阵后,我们通过计算可以得出哪些系数是接近于0的,生成一个掩码矩阵——接近0的系数在此矩阵中置为0,否则为1。
  3. 再使用系数矩阵点乘掩码矩阵(不是矩阵乘法),就得到了一个新的系数矩阵,也就是剪枝后的系数矩阵。

综上所述,我们的目标就是:如何求出一个掩码矩阵?

numpy为我们提供了一个简单的方法:

image-20231030192019969

即先求出系数矩阵的绝对值,然后告诉该函数要减掉百分之多少的枝,最好函数返回一个阈值:即小于该阈值的数都被剪掉了。

代码实现

  1. 设置掩码,同时考虑掩码是否为空的情况:

    def to_var(x, requires_grad=False):
        """
        Automatically choose cpu or cuda
        """
        if torch.cuda.is_available():
            x = x.cuda()
        return x.clone().detach().requires_grad_(requires_grad)
    
    
    class MaskedLinear(nn.Linear):
        def __init__(self, in_features, out_features, bias=True):
            super(MaskedLinear, self).__init__(in_features, out_features, bias)
            self.register_buffer('mask', None)
    
        def set_mask(self, mask):
            self.mask = to_var(mask, requires_grad=False)
            self.weight.data = self.weight.data * self.mask.data
    
        def get_mask(self):
            return self.mask
    
        def forward(self, x):
            if self.mask is not None:
                weight = self.weight * self.mask
                return F.linear(x, weight, self.bias)
            else:
                return F.linear(x, self.weight, self.bias)
  2. 关键点在于掩码矩阵如何计算:

    def weight_prune(model, pruning_perc):
        '''
        Prune pruning_perc % weights layer-wise
        '''
        threshold_list = [] #  存储每一层的阈值
        for p in model.parameters(): # 遍历模型所有参数
            if len(p.data.size()) != 1: # 如果参数为维度不为1,即不是偏置项
                weight = p.cpu().data.abs().numpy().flatten() # 将参数的绝对值转换为numpy数组,并展平
                threshold = np.percentile(weight, pruning_perc) # 计算阈值
                threshold_list.append(threshold) # 保存每一层的阈值
    
        # generate mask
        masks = [] # 存储每一层的掩码
        idx = 0 # 索引
        for p in model.parameters():
            if len(p.data.size()) != 1:
                pruned_inds = p.data.abs() > threshold_list[idx] # 根据之前计算的阈值,将参数的绝对值和阈值进行比较,得到一个布尔类型的掩码
                masks.append(pruned_inds.float()) # 将掩码转换为浮点型,添加到masks列表中
                idx += 1 # 索引向前
        return masks

    当然上述代码只是线性剪枝,即MLP,其他类型的以后再说。


文章作者: QT-7274
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 QT-7274 !
评论
  目录