ColossalAI/tests/test_analyzer/test_subclasses/test_flop_tensor.py
Super Daniel fff98f06ed
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device.

* [siu] add experimental submodules to main branch.

* [siu]

* [siu]

* [analyzer] init.

* [analyzer] readme.

* [analyzer] readme.

* [analyzer] readme.

* [analyzer] readme.

* [test] add test.

* Update symbolic_trace.py

* mark skip tests.

* try except.

* try except.

* try except.

* s

* init

* init

* fix

* skip

* skip

---------

Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com>
Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
2023-03-10 13:21:05 +08:00

51 lines
1.7 KiB
Python

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tm
from .zoo import tm_models, tmm_models
# Best-effort import of the analyzer under test. If it cannot be imported
# (presumably on torch releases older than the 1.12 gate used by the skipif
# marks in this file), module collection must still succeed so those tests can
# be reported as skipped.
try:
    from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
except Exception:
    # Catch ``Exception`` rather than using a bare ``except:`` so that
    # KeyboardInterrupt / SystemExit still propagate. The names are left
    # undefined on failure: any test that is not skipped then fails loudly
    # with a NameError instead of passing silently.
    pass
# Compare (major, minor) numerically: a plain string comparison would rank
# '1.9.0' above '1.12.0' and fail to skip old torch versions.
@pytest.mark.skipif(tuple(map(int, str(torch.__version__).split('.')[:2])) < (1, 12),
                    reason='torch version < 1.12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_flop_count_module(m):
    """Check that ``flop_count`` reports positive fwd and bwd flops for a model.

    Args:
        m: a zero-argument model constructor taken from the ``.zoo`` lists
            (``tm_models`` + ``tmm_models``).
    """
    x = torch.rand(2, 3, 224, 224)
    # Construct the model under MetaTensorMode to save time for testing.
    with MetaTensorMode():
        module = m()
    rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
    assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}'
    assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}'
def _image_batch(requires_grad=False):
    """Return a fresh random float tensor of shape ``(2, 3, 224, 224)``."""
    return torch.rand(2, 3, 224, 224, requires_grad=requires_grad)


# (callable, positional args, keyword args) triples exercised by the
# function-level flop-count test. Inputs are drawn in a fixed order:
# relu input, max_pool2d input, then where's condition / x / y.
odd_cases = [
    (F.relu, (_image_batch(requires_grad=True),), {'inplace': True}),
    (F.max_pool2d, (_image_batch(requires_grad=True),),
     dict(kernel_size=3, stride=2, padding=1, dilation=2)),
    (torch.where,
     (_image_batch() > 0.5, _image_batch(requires_grad=True), _image_batch(requires_grad=True)),
     {}),
]
# Compare (major, minor) numerically: a plain string comparison would rank
# '1.9.0' above '1.12.0' and fail to skip old torch versions.
@pytest.mark.skipif(tuple(map(int, str(torch.__version__).split('.')[:2])) < (1, 12),
                    reason='torch version < 1.12')
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
    """Check that ``flop_count`` reports positive fwd and bwd flops for one op.

    Args:
        func: the callable being profiled (e.g. ``F.relu``, ``torch.where``).
        args: positional arguments passed to ``flop_count`` after ``func``
            (presumably forwarded to ``func``).
        kwargs: keyword arguments passed to ``flop_count`` alongside
            ``verbose=True`` (presumably forwarded to ``func``).
    """
    rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
    assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
    assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}'
if __name__ == '__main__':
    # Run one representative case of each test directly, outside pytest.
    # pytest's skipif marks are not evaluated here. Because this file uses a
    # relative import (``from .zoo import ...``), it must be run as a module
    # (``python -m ...``), not as a plain script.
    # test_flop_count_module takes only the model constructor; the input
    # tensor is created inside the test.
    test_flop_count_module(tm.resnet18)
    test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})