mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-12-21 19:43:11 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -4,7 +4,6 @@ import torch.nn.functional as F
|
||||
import torchvision.models as tm
|
||||
from packaging import version
|
||||
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||
|
||||
try:
|
||||
@@ -13,40 +12,44 @@ except:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@pytest.mark.parametrize("m", tm_models + tmm_models)
|
||||
def test_flop_count_module(m):
|
||||
x = torch.rand(2, 3, 224, 224)
|
||||
with MetaTensorMode(): # save time for testing
|
||||
with MetaTensorMode(): # save time for testing
|
||||
module = m()
|
||||
rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
|
||||
assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}'
|
||||
assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}'
|
||||
assert rs_fwd > 0, f"fwd flop count of {m.__name__} is {rs_fwd}"
|
||||
assert rs_bwd > 0, f"bwd flop count of {m.__name__} is {rs_bwd}"
|
||||
|
||||
|
||||
odd_cases = [
|
||||
(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
|
||||
'inplace': True
|
||||
}),
|
||||
(F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
|
||||
'kernel_size': 3,
|
||||
'stride': 2,
|
||||
'padding': 1,
|
||||
'dilation': 2
|
||||
}),
|
||||
(torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True),
|
||||
torch.rand(2, 3, 224, 224, requires_grad=True)), {}),
|
||||
(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}),
|
||||
(
|
||||
F.max_pool2d,
|
||||
(torch.rand(2, 3, 224, 224, requires_grad=True),),
|
||||
{"kernel_size": 3, "stride": 2, "padding": 1, "dilation": 2},
|
||||
),
|
||||
(
|
||||
torch.where,
|
||||
(
|
||||
torch.rand(2, 3, 224, 224) > 0.5,
|
||||
torch.rand(2, 3, 224, 224, requires_grad=True),
|
||||
torch.rand(2, 3, 224, 224, requires_grad=True),
|
||||
),
|
||||
{},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@pytest.mark.parametrize("func, args, kwargs", odd_cases)
|
||||
def test_flop_count_function(func, args, kwargs):
|
||||
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
|
||||
assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
|
||||
assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}'
|
||||
assert rs_fwd > 0, f"fwd flop count of {func.__name__} is {rs_fwd}"
|
||||
assert rs_bwd > 0, f"bwd flop count of {func.__name__} is {rs_bwd}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_flop_count_module(tm.resnet18)
|
||||
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})
|
||||
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True})
|
||||
|
||||
Reference in New Issue
Block a user