模型的计算量和参数
模型的计算量和参数
根据llm的scaling laws,影响输出的三个重要因素,计算量、参数量以及数据大小. 只要这三个因素大起来,一切都会好起来的. 如何评估一个模型在实际部署或者推理时的效果,计算量和参数都是重要因素,一个限制GPU/CPU,一个限制内存. 此外对于图像或视频,fps也是重要因素,对于大模型而言就是输出token的速度,说白了就是输出时间,影响了实时性和使用效果,个人认为这个因素也跟计算量和参数相关. 那么如何计算模型计算量和参数量呢? 根据字面意思,参数量好理解,无非几种网络结构,每种网络结构都或多或少有参数,每个参数当作f32或者其他类型计算即可. 当然实际推理或者部署完全可以使用量化的方法得到整数,甚至是i8类型.这样就把参数量降低了,此外浮点数变成了整数,事实上计算也降低了复杂度,因为浮点数计算在CPU/GPU上往往更复杂.
这样一想,计算参数量就比较简单了
def calc_param(model):
"""
Calculate the number of parameters in the model.
:param model: model
:return: number of parameters
"""
param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(param_num*4/1024/1024, 'MB') # assume dtype float32
那如何计算计算量呢,由于浮点数和整数运算上在CPU计算的不同,我们只考虑更为复杂的浮点数计算,就有了FLOPs. FLOPS计算如下,卷积层和全连接计算不同

还有些地方使用的是MACs和MAdds,
MAC:Multiply Accumulate,乘加运算。乘积累加运算(英语:Multiply Accumulate, MAC)是在数字信号处理器或一些微处理器中的特殊运算。实现此运算操作的硬件电路单元,被称为“乘数累加器”。这种运算的操作,是将乘法的乘积结果和累加器的值相加,再存入累加器: $$ a ← a + b × c $$ 使用MAC可以将原本需要的两个指令操作减少到一个指令操作,从而提高运算效率。
MAdds 本质上与 MACs 相同,都是指一次乘法和一次加法的组合。术语 MAdds 更常见于一些文献中,尤其是早期的文献。实际上,在大多数情况下,MACs 和 MAdds 可以互换使用.
1个 MACs 包含一个乘法操作与一个加法操作,大约包含2个 FLOPs。因此,通常 MACs 与 FLOPs 存在一个2倍的关系。(但是很多时候又会把它们会混淆合在一起)
抽象地高度来说
- 计算量是指网络模型需要计算的运算次数,参数量是指网络模型自带的参数数量多少
- 计算量对应时间复杂度,参数量对应于空间复杂度
- 计算量决定了网络执行时间的长短,参数量决定了占用显存的量
5种方法获取Torch网络模型参数量计算量等信息_查看模型参数量-CSDN博客
这下就可以好好调库了.下面介绍可以直接使用的库
thop
pip install thop
Lyken17/pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model. (github.com)
from torchvision.models import resnet50
import torch
import torch.nn as nn
from thop import profile
def calc_param(model):
"""
Calculate the number of parameters in the model.
:param model: model
:return: number of parameters
"""
param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(param_num*4/1024/1024, 'MB') # assume dtype float32
if __name__ == '__main__':
# n = nn.Conv2d(3, 3, 3)
# calc_param(n)
model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input,))
# print(macs, params)
print('FLOPs = ' + str(macs / 1000 ** 3) + 'G')
print('Params = ' + str(params / 1000 ** 2) + 'M')
ptflops
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info
with torch.cuda.device(0):
net = models.densenet161()
macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='pytorch'
print_per_layer_stat=True, verbose=True)
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='aten'
print_per_layer_stat=True, verbose=True)
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
torchstat
pip install torchstat
from torchstat import stat
import torchvision.models as models
model = models.resnet18()
stat(model, (3, 224, 224))
个人觉得thop比较好
此外还有torchsummary,不过更多用于计算参数量和查看网络结构的.这里不赘述.