Pytorch中获取模型摘要的3种方法

2年前 (2022) 程序员胖胖胖虎阿
351 0 0

在pytorch中获取模型的可训练和不可训练的参数,层名称,内核大小和数量。

Pytorch nn.Module 类中没有提供像与Keras那样的可以计算模型中可训练和不可训练的参数的数量并显示模型摘要的方法 。所以在这篇文章中,我将总结我知道三种方法来计算Pytorch模型中可训练和不可训练的参数的数量。

直接手写代码

最直接的办法就是我们自己手写代码代码实现这个功能,所以这里我自己实现了一个函数,函数中为了漂亮所以引入了PrettyTable的包

 from prettytable import PrettyTable
 
 def count_parameters(model):
     table = PrettyTable([“Modules”, “Parameters”])
     total_params = 0
     for name, parameter in model.named_parameters():
         if not parameter.requires_grad: continue
         params = parameter.numel()
         table.add_row([name, params])
         total_params+=params
     print(table)
     print(f”Total Trainable Params: {total_params}”)
     return total_params

我们拿RESNET18为例,以上函数的输出如下:

 +------------------------------+------------+ 
 |           Modules            | Parameters | 
 +------------------------------+------------+ 
 |         conv1.weight         |    9408    | 
 |          bn1.weight          |     64     | 
 |           bn1.bias           |     64     | 
 |    layer1.0.conv1.weight     |   36864    | 
 |     layer1.0.bn1.weight      |     64     | 
 |      layer1.0.bn1.bias       |     64     |
 .
 .
 .
 |          fc.weight           |   512000   | 
 |           fc.bias            |    1000    | 
 +------------------------------+------------+ 
 Total Trainable Params: 11689512

输出以参数为单位,可以看到模型中存在的每个参数的可训练参数,是不是和keras的基本一样。

torchsummary

torchsummary出现的时候的目标就是为了让torch有类似keras一样的打印模型参数的功能,它非常友好并且十分简单。当前版本为1.5.1,可以直接使用pip安装:

 pip install torchsummary

安装完成后即可使用,我们还是以resnet18为例

 from torchsummary import summary
 model = torchvision.models.resnet18().cuda()

在使用时,我们需要生成一个模型的输入变量,也就是模拟模型的前向传播的过程:

 summary(model, input_size = (3, 64, 64), batch_size = -1)

结果如下:

 — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — 
 Layer (type)               Output Shape                  Param # ================================================================ 
 Conv2d-1               [-1, 64, 112, 112]                  9,408 
 BatchNorm2d-2          [-1, 64, 112, 112]                    128 
 ReLU-3                 [-1, 64, 112, 112]                      0 
 MaxPool2d-4              [-1, 64, 56, 56]                      0 
 Conv2d-5                 [-1, 64, 56, 56]                 36,864
 .
 .
 .
 AdaptiveAvgPool2d-67      [-1, 512, 1, 1]                      0
 Linear-68                      [-1, 1000]                513,000 ================================================================
 Total params: 11,689,512 
 Trainable params: 11,689,512 
 Non-trainable params: 0 
 ----------------------------------------------------------------
 Input size (MB): 0.57 
 Forward/backward pass size (MB): 62.79 
 Params size (MB): 44.59 
 Estimated Total Size (MB): 107.96 
 ----------------------------------------------------------------

现在,如果你的基本模型有多个分支,每个分支都有不同的输入,例如

 class Model(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.resnet1 = torchvision.models.resnet18().cuda()
         self.resnet2 = torchvision.models.resnet18().cuda()
         self.resnet3 = torchvision.models.resnet18().cuda()
     
     def forward(self, *x):
         out1 = self.resnet1(x[0])
         out2 = self.resnet2(x[1])
         out3 = self.resnet3(x[2])
         out = torch.cat([out1, out2, out3], dim = 0)
         return out

那么就需要这样:

 summary(Model().cuda(), input_size = [(3, 64, 64)]*3)

该输出将与前一个相似,但会有点混乱,因为torchsummary将每个组成的ResNet模块的信息压缩到一个摘要中,而在两个连续模块的摘要之间没有任何适当的可区分边界。

torchinfo

它看起来可能与torchsummary类似。但在我看来,它是我找到这三种方法中最好的。torchinfo当前版本是1.7.0,还是可以使用pip安装:

 pip install torchinfo

这个包也有一个名为summary的函数。但它有更多的参数。他的使用参数为model (nn.module)、input_size (Sequence of Sizes)、input_data (Sequence of Tensors)、batch_dim (int)、cache_forward_pass (bool)、col_names (Iterable[str])、col_width (int)、depth (int)、device (torch.Device)、dtypes (List[torch.dtype])、mode (str)、row_settings (Iterable[str])、verbose (int)和**kwargs。

参数很多,但是可以直接通过(" input_size ", " output_size ", " num_params ", " kernel_size ", " mult_add ", " trainable ")作为col_names参数来获取信息。

 import torchinfo
 torchinfo.summary(model, (3, 224, 224), batch_dim = 0, col_names = (“input_size”, “output_size”, “num_params”, “kernel_size”, “mult_adds”), verbose = 0)

需要说明的是,如果不使用Jupyter或Google Colab,需要将verbose 更改为1。

上述代码段的输出看起来像这样

 =============================================================================================
 Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
 =============================================================================================
 ResNet                                   [1, 3, 224, 224]          [1, 1000]                 --                        --                        --
 ├─Conv2d: 1-1                            [1, 3, 224, 224]          [1, 64, 112, 112]         9,408                     [7, 7]                    118,013,952
 ├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         [1, 64, 112, 112]         128                       --                        128
 ├─ReLU: 1-3                              [1, 64, 112, 112]         [1, 64, 112, 112]         --                        --                        --
 ├─MaxPool2d: 1-4                         [1, 64, 112, 112]         [1, 64, 56, 56]           --                        3                         --
 ├─Sequential: 1-5                        [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    └─BasicBlock: 2-1                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128
 │    │    └─ReLU: 3-3                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128
 │    │    └─ReLU: 3-6                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    └─BasicBlock: 2-2                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128
 │    │    └─ReLU: 3-9                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128
 │    │    └─ReLU: 3-12                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 ├─Sequential: 1-6                        [1, 64, 56, 56]           [1, 128, 28, 28]          --                        --                        --
 │    └─BasicBlock: 2-3                   [1, 64, 56, 56]           [1, 128, 28, 28]          --                        --                        --
 │    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           [1, 128, 28, 28]          73,728                    [3, 3]                    57,802,752
 │    │    └─BatchNorm2d: 3-14            [1, 128, 28, 28]          [1, 128, 28, 28]          256                       --                        256
 .
 .
 .
 │    │    └─Conv2d: 3-49                 [1, 512, 7, 7]            [1, 512, 7, 7]            2,359,296                 [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-50            [1, 512, 7, 7]            [1, 512, 7, 7]            1,024                     --                        1,024
 │    │    └─ReLU: 3-51                   [1, 512, 7, 7]            [1, 512, 7, 7]            --                        --                        --
 ├─AdaptiveAvgPool2d: 1-9                 [1, 512, 7, 7]            [1, 512, 1, 1]            --                        --                        --
 ├─Linear: 1-10                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000
 =============================================================================================
 Total params: 11,689,512
 Trainable params: 11,689,512
 Non-trainable params: 0
 Total mult-adds (G): 1.81
 =============================================================================================
 Input size (MB): 0.60
 Forward/backward pass size (MB): 39.75
 Params size (MB): 46.76
 Estimated Total Size (MB): 87.11
 =============================================================================================

再继续查看多分支模型

 torchinfo.summary(Model().cuda(), [(3, 64, 64)]*3, batch_dim = 0, col_names = (“input_size”, “output_size”, “num_params”, “kernel_size”, “mult_adds”), verbose = 0)

产生以下输出

 =============================================================================================
 Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
 =============================================================================================
 Model                                         [1, 3, 64, 64]            [1, 1000]                 --                        --                        --
 ├─ResNet: 1-1                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        --
 │    └─Conv2d: 2-1                            [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792
 │    └─BatchNorm2d: 2-2                       [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128
 │    └─ReLU: 2-3                              [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        --
 │    └─MaxPool2d: 2-4                         [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         --
 │    └─Sequential: 2-5                        [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        --
 │    │    └─BasicBlock: 3-1                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    │    └─BasicBlock: 3-2                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    └─Sequential: 2-6                        [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        --
 │    │    └─BasicBlock: 3-3                   [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832
 │    │    └─BasicBlock: 3-4                   [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880
 │    └─Sequential: 2-7                        [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        --
 │    │    └─BasicBlock: 3-5                   [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600
 │    │    └─BasicBlock: 3-6                   [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392
 │    └─Sequential: 2-8                        [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        --
 │    │    └─BasicBlock: 3-7                   [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136
 │    │    └─BasicBlock: 3-8                   [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416
 │    └─AdaptiveAvgPool2d: 2-9                 [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        --
 │    └─Linear: 2-10                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000
 ├─ResNet: 1-2                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        --
 │    └─Conv2d: 2-11                           [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792
 │    └─BatchNorm2d: 2-12                      [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128
 │    └─ReLU: 2-13                             [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        --
 │    └─MaxPool2d: 2-14                        [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         --
 │    └─Sequential: 2-15                       [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        --
 │    │    └─BasicBlock: 3-9                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    │    └─BasicBlock: 3-10                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    └─Sequential: 2-16                       [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        --
 │    │    └─BasicBlock: 3-11                  [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832
 │    │    └─BasicBlock: 3-12                  [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880
 │    └─Sequential: 2-17                       [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        --
 │    │    └─BasicBlock: 3-13                  [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600
 │    │    └─BasicBlock: 3-14                  [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392
 │    └─Sequential: 2-18                       [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        --
 │    │    └─BasicBlock: 3-15                  [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136
 │    │    └─BasicBlock: 3-16                  [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416
 │    └─AdaptiveAvgPool2d: 2-19                [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        --
 │    └─Linear: 2-20                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000
 ├─ResNet: 1-3                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        --
 │    └─Conv2d: 2-21                           [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792
 │    └─BatchNorm2d: 2-22                      [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128
 │    └─ReLU: 2-23                             [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        --
 │    └─MaxPool2d: 2-24                        [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         --
 │    └─Sequential: 2-25                       [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        --
 │    │    └─BasicBlock: 3-17                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    │    └─BasicBlock: 3-18                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    └─Sequential: 2-26                       [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        --
 │    │    └─BasicBlock: 3-19                  [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832
 │    │    └─BasicBlock: 3-20                  [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880
 │    └─Sequential: 2-27                       [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        --
 │    │    └─BasicBlock: 3-21                  [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600
 │    │    └─BasicBlock: 3-22                  [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392
 │    └─Sequential: 2-28                       [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        --
 │    │    └─BasicBlock: 3-23                  [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136
 │    │    └─BasicBlock: 3-24                  [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416
 │    └─AdaptiveAvgPool2d: 2-29                [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        --
 │    └─Linear: 2-30                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000
 =============================================================================================
 Total params: 35,068,536
 Trainable params: 35,068,536
 Non-trainable params: 0
 Total mult-adds (M): 445.71
 =============================================================================================
 Input size (MB): 0.15
 Forward/backward pass size (MB): 9.76
 Params size (MB): 140.27
 Estimated Total Size (MB): 150.18
 =============================================================================================

可以看到depth 参数的默认值为3。并且在可视化方向上,多分支被重新进行了组织并且以层次结构方式呈现,所以很容易区分,所以他的效果要比torchsummary好很多。

https://avoid.overfit.cn/post/bfed756d1d5147a89f6d8d911b6d29dd

作者:Siladittya Manna

版权声明:程序员胖胖胖虎阿 发表于 2022年9月28日 上午1:32。
转载请注明:Pytorch中获取模型摘要的3种方法 | 胖虎的工具箱-编程导航

相关文章

暂无评论

暂无评论...