flops

paddle. flops ( net, input_size, custom_ops=None, print_detail=False ) [源代码]

打印网络的基础结构和参数信息。

参数

  • net (paddle.nn.Layer|paddle.static.Program) - 网络实例,必须是 paddle.nn.Layer 的子类或者静态图下的 paddle.static.Program。

  • input_size (list) - 输入 Tensor 的大小。注意:仅支持 batch_size=1。

  • custom_ops (dict,可选) - 字典,用于实现对自定义网络层的统计。字典的 key 为自定义网络层的 class,value 为统计网络层 flops 的函数,函数实现方法见示例代码。此参数仅在 net 为 paddle.nn.Layer 时生效。默认值:None。

  • print_detail (bool,可选) - bool 值,用于控制是否打印每个网络层的细节。默认值:False。

返回

int,网络模型的计算量。

代码示例

>>> import paddle >>> import paddle.nn as nn >>> class LeNet(nn.Layer): ...  def __init__(self, num_classes=10): ...  super().__init__() ...  self.num_classes = num_classes ...  self.features = nn.Sequential( ...  nn.Conv2D(1, 6, 3, stride=1, padding=1), ...  nn.ReLU(), ...  nn.MaxPool2D(2, 2), ...  nn.Conv2D(6, 16, 5, stride=1, padding=0), ...  nn.ReLU(), ...  nn.MaxPool2D(2, 2)) ... ...  if num_classes > 0: ...  self.fc = nn.Sequential( ...  nn.Linear(400, 120), ...  nn.Linear(120, 84), ...  nn.Linear(84, 10)) ... ...  def forward(self, inputs): ...  x = self.features(inputs) ... ...  if self.num_classes > 0: ...  x = paddle.flatten(x, 1) ...  x = self.fc(x) ...  return x ... >>> lenet = LeNet() >>> # m is the instance of nn.Layer, x is the input of layer, y is the output of layer. >>> def count_leaky_relu(m, x, y): ...  x = x[0] ...  nelements = x.numel() ...  m.total_ops += int(nelements) ... >>> FLOPs = paddle.flops(lenet, ...  [1, 1, 28, 28], ...  custom_ops= {nn.LeakyReLU: count_leaky_relu}, ...  print_detail=True) <class 'paddle.nn.layer.conv.Conv2D'>'s flops has been counted <class 'paddle.nn.layer.activation.ReLU'>'s flops has been counted Cannot find suitable count function for <class 'paddle.nn.layer.pooling.MaxPool2D'>. Treat it as zero FLOPs. <class 'paddle.nn.layer.common.Linear'>'s flops has been counted +--------------+-----------------+-----------------+--------+--------+ | Layer Name | Input Shape | Output Shape | Params | Flops | +--------------+-----------------+-----------------+--------+--------+ | conv2d_0 | [1, 1, 28, 28] | [1, 6, 28, 28] | 60 | 47040 | | re_lu_0 | [1, 6, 28, 28] | [1, 6, 28, 28] | 0 | 0 | | max_pool2d_0 | [1, 6, 28, 28] | [1, 6, 14, 14] | 0 | 0 | | conv2d_1 | [1, 6, 14, 14] | [1, 16, 10, 10] | 2416 | 241600 | | re_lu_1 | [1, 16, 10, 10] | [1, 16, 10, 10] | 0 | 0 | | max_pool2d_1 | [1, 16, 10, 10] | [1, 16, 5, 5] | 0 | 0 | | linear_0 | [1, 400] | [1, 120] | 48120 | 48000 | | linear_1 | [1, 120] | [1, 84] | 10164 | 10080 | | linear_2 | [1, 84] | [1, 10] | 850 | 840 | +--------------+-----------------+-----------------+--------+--------+ Total Flops: 347560 Total Params: 61610 >>> print(FLOPs) 347560