mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 03:43:01 +00:00
* [hotfix] meta tensor default device. * [siu] add experimental submodules to main branch. * [siu] * [siu] * [analyzer] init. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [test] add test. * Update symbolic_trace.py * mark skip tests. * try except. * try except. * try except. * s * init * init * fix * skip * skip --------- Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com> Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
51 lines
1.7 KiB
Python
51 lines
1.7 KiB
Python
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision.models as tm
|
|
from .zoo import tm_models, tmm_models
|
|
|
|
# Optional dependency: the analyzer subclasses only exist when colossalai is
# installed; downstream tests then fail/skip on the missing names.
try:
    from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
except ImportError:
    # Catch only import failures — a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and mask unrelated errors raised while
    # colossalai itself is being imported.
    pass
|
|
|
|
|
|
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_flop_count_module(m):
    """Forward and backward FLOP counts of a whole model must be positive."""
    sample = torch.rand(2, 3, 224, 224)
    # Build the model under MetaTensorMode so no real parameter storage is
    # allocated — keeps the parametrized sweep over the model zoo fast.
    with MetaTensorMode():
        net = m()
    fwd_flops, bwd_flops = flop_count(net, sample, verbose=True)
    assert fwd_flops > 0, f'fwd flop count of {m.__name__} is {fwd_flops}'
    assert bwd_flops > 0, f'bwd flop count of {m.__name__} is {bwd_flops}'
|
|
|
|
|
|
# (callable, positional args, keyword args) triples for ops whose FLOP
# accounting is easy to get wrong: an in-place activation, a strided/dilated
# pooling, and a multi-input selection op.
odd_cases = [
    (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True}),
    (F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
        'kernel_size': 3,
        'stride': 2,
        'padding': 1,
        'dilation': 2,
    }),
    (torch.where, (
        torch.rand(2, 3, 224, 224) > 0.5,
        torch.rand(2, 3, 224, 224, requires_grad=True),
        torch.rand(2, 3, 224, 224, requires_grad=True),
    ), {}),
]
|
|
|
|
|
|
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
    """Forward and backward FLOP counts of a free function must be positive."""
    fwd_flops, bwd_flops = flop_count(func, *args, **kwargs, verbose=True)
    assert fwd_flops > 0, f'fwd flop count of {func.__name__} is {fwd_flops}'
    assert bwd_flops > 0, f'bwd flop count of {func.__name__} is {bwd_flops}'
|
|
|
|
|
|
if __name__ == '__main__':
    # Run one representative case of each test directly (without pytest).
    # BUG FIX: test_flop_count_module takes a single model-factory argument —
    # it creates its own input tensor internally — so passing an extra tensor
    # here raised TypeError before.
    test_flop_count_module(tm.resnet18)
    test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})