A simple PyTorch utility that estimates the number of FLOPs for a given network. For now, only some basic operations are supported (basically the ones I needed for my models). More will be added soon.
All contributions are welcome.
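To give a sense of what a per-layer estimate involves, here is a minimal, hand-rolled count for a single 2D convolution. It assumes the common convention of counting one multiply-accumulate (MAC) as one operation, plus one addition per output element when a bias is present. pthflops' own convention may differ, so treat this as an illustration rather than a reproduction of its numbers. `conv2d_output_size` and `conv2d_ops` are hypothetical helpers, not part of pthflops.

```python
def conv2d_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a Conv2d along one axis (PyTorch's formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_ops(in_channels, out_channels, kernel_size, out_h, out_w,
               groups=1, bias=False):
    """Hypothetical estimate: one MAC per kernel tap, one add per output for bias."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("channels must be divisible by groups")
    k_h, k_w = kernel_size
    # Each output element sums over (in_channels / groups) * k_h * k_w products
    macs_per_output = (in_channels // groups) * k_h * k_w
    num_outputs = out_channels * out_h * out_w
    ops = macs_per_output * num_outputs
    if bias:
        ops += num_outputs  # one extra addition per output element
    return ops


if __name__ == "__main__":
    # First layer of ResNet-18: Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    out = conv2d_output_size(224, 7, stride=2, padding=3)  # 112
    print(conv2d_ops(3, 64, (7, 7), out, out))  # 118013952
```

Note that this only covers one layer type; a full estimator has to walk every operation in the network, which is the part pthflops automates.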
You can install the package using pip:
    pip install pthflops
or directly from the github repository:
    git clone https://github.com/1adrianb/pytorch-estimate-flops && cd pytorch-estimate-flops && python setup.py install
Note: PyTorch 1.8 or newer is recommended.
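If you want to check the recommendation programmatically (for example, to know whether the fx-based naming described further down applies), a minimal sketch is to parse the major and minor numbers out of `torch.__version__` and compare them as integers. `parse_torch_version` and `meets_recommended_torch` are hypothetical helpers written for this illustration, not part of pthflops or PyTorch.

```python
import re


def parse_torch_version(version):
    """Hypothetical helper: extract (major, minor) from e.g. '1.8.0+cu111'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_recommended_torch(version, minimum=(1, 8)):
    """Compare as integer tuples so that '1.10' correctly counts as newer than '1.8'."""
    return parse_torch_version(version) >= minimum


# Usage (requires PyTorch to be installed):
#   import torch
#   meets_recommended_torch(torch.__version__)
```

Comparing tuples of integers rather than raw strings matters here: as strings, `"1.10" >= "1.8"` is `False`.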
    import torch
    from torchvision.models import resnet18

    from pthflops import count_ops

    # Create a network and a corresponding input
    # (falls back to the CPU when no GPU is available)
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = resnet18().to(device)
    inp = torch.rand(1, 3, 224, 224).to(device)

    # Count the number of FLOPs
    count_ops(model, inp)
Ignoring certain layers:
    import torch
    from torch import nn

    from pthflops import count_ops

    class CustomLayer(nn.Module):
        def __init__(self):
            super(CustomLayer, self).__init__()
            self.conv1 = nn.Conv2d(5, 5, 1, 1, 0)
            # ... other layers present inside will also be ignored

        def forward(self, x):
            return self.conv1(x)

    # Create a network and a corresponding input
    inp = torch.rand(1, 5, 7, 7)
    net = nn.Sequential(
        nn.Conv2d(5, 5, 1, 1, 0),
        nn.ReLU(inplace=True),
        CustomLayer()
    )

    # Count the number of FLOPs, jit mode:
    count_ops(net, inp, ignore_layers=['CustomLayer'])

    # Note: if you are using PyTorch 1.8 or newer with fx instead of jit, the naming
    # convention changed. As such, you will have to pass ['_2_conv1'].
    # Please check your model definition to account for this.
    # Count the number of FLOPs, fx mode:
    count_ops(net, inp, ignore_layers=['_2_conv1'])
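The fx-mode name `'_2_conv1'` can be traced back to the module's qualified name: `CustomLayer` is entry `2` of the `nn.Sequential`, and its convolution is the attribute `conv1`, so `net.named_modules()` lists it as `2.conv1`. The sketch below approximates how torch.fx turns such a path into a node name. It is a hypothetical helper, not a pthflops or torch.fx API, and it mirrors only the two sanitising rules relevant here (illegal identifier characters become `_`, and a leading digit gets an `_` prefix). To see the actual names for your model, you can print the nodes of `torch.fx.symbolic_trace(net).graph`.

```python
import re


def fx_style_name(qualified_name):
    """Hypothetical helper approximating torch.fx node naming for a module path."""
    # Replace each run of characters that are illegal in identifiers (e.g. ".") with "_"
    name = re.sub(r"[^0-9a-zA-Z_]+", "_", qualified_name)
    # Python identifiers cannot start with a digit, so prefix an underscore
    if name[0].isdigit():
        name = "_" + name
    return name


if __name__ == "__main__":
    print(fx_style_name("2.conv1"))  # _2_conv1
```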