mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -14,35 +14,41 @@ except:
|
||||
aten = torch.ops.aten
|
||||
|
||||
registered_meta = {
|
||||
('aten.convolution.default', True): [ # (aten ops, requires_backward)
|
||||
("aten.convolution.default", True): [ # (aten ops, requires_backward)
|
||||
(nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
|
||||
(nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)),
|
||||
(nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)),
|
||||
(nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
|
||||
(nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
|
||||
dilation=2), torch.rand(2, 3, 4, 4)),
|
||||
(nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
|
||||
dilation=2), torch.rand(2, 3, 4, 4, 4)),
|
||||
(
|
||||
nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2),
|
||||
torch.rand(2, 3, 4, 4),
|
||||
),
|
||||
(
|
||||
nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2),
|
||||
torch.rand(2, 3, 4, 4, 4),
|
||||
),
|
||||
],
|
||||
('aten.native_batch_norm.default', True): [
|
||||
("aten.native_batch_norm.default", True): [
|
||||
(nn.BatchNorm1d(4), torch.rand(2, 4)),
|
||||
(nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)),
|
||||
(nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)),
|
||||
],
|
||||
('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),],
|
||||
('aten.avg_pool1d.default', True): [
|
||||
("aten.native_layer_norm.default", True): [
|
||||
(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),
|
||||
],
|
||||
("aten.avg_pool1d.default", True): [
|
||||
(nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)),
|
||||
(nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)),
|
||||
(nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)),
|
||||
(nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)),
|
||||
],
|
||||
('aten.avg_pool2d.default', True): [
|
||||
("aten.avg_pool2d.default", True): [
|
||||
(nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
|
||||
(nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
|
||||
(nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
|
||||
(nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
|
||||
],
|
||||
('aten.relu.default', True): [
|
||||
("aten.relu.default", True): [
|
||||
(nn.ReLU(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.LeakyReLU(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.SiLU(), torch.rand(4, 3, 1, 2)),
|
||||
@@ -51,15 +57,20 @@ registered_meta = {
|
||||
(nn.Sigmoid(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.Tanh(), torch.rand(4, 3, 1, 2)),
|
||||
(nn.Hardswish(), torch.rand(4, 3, 1, 2)),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
|
||||
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
|
||||
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
|
||||
assert tensor.stride() == meta_tensor.stride(
|
||||
), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'
|
||||
assert (
|
||||
tensor.shape == meta_tensor.shape
|
||||
), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match."
|
||||
assert (
|
||||
tensor.dtype == meta_tensor.dtype
|
||||
), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match."
|
||||
assert (
|
||||
tensor.stride() == meta_tensor.stride()
|
||||
), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match."
|
||||
|
||||
|
||||
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
|
||||
@@ -73,7 +84,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
|
||||
compare_all(x.grad, meta_x.grad)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12")
|
||||
@clear_cache_before_run()
|
||||
def test_meta_aten():
|
||||
for (aten_op, requires_backward), v in registered_meta.items():
|
||||
@@ -81,5 +92,5 @@ def test_meta_aten():
|
||||
run_and_compare(f, x, requires_backward)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_meta_aten()
|
||||
|
@@ -4,7 +4,6 @@ import torch.nn.functional as F
|
||||
import torchvision.models as tm
|
||||
from packaging import version
|
||||
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||
|
||||
try:
|
||||
@@ -13,40 +12,44 @@ except:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@pytest.mark.parametrize("m", tm_models + tmm_models)
|
||||
def test_flop_count_module(m):
|
||||
x = torch.rand(2, 3, 224, 224)
|
||||
with MetaTensorMode(): # save time for testing
|
||||
with MetaTensorMode(): # save time for testing
|
||||
module = m()
|
||||
rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
|
||||
assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}'
|
||||
assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}'
|
||||
assert rs_fwd > 0, f"fwd flop count of {m.__name__} is {rs_fwd}"
|
||||
assert rs_bwd > 0, f"bwd flop count of {m.__name__} is {rs_bwd}"
|
||||
|
||||
|
||||
odd_cases = [
|
||||
(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
|
||||
'inplace': True
|
||||
}),
|
||||
(F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
|
||||
'kernel_size': 3,
|
||||
'stride': 2,
|
||||
'padding': 1,
|
||||
'dilation': 2
|
||||
}),
|
||||
(torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True),
|
||||
torch.rand(2, 3, 224, 224, requires_grad=True)), {}),
|
||||
(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}),
|
||||
(
|
||||
F.max_pool2d,
|
||||
(torch.rand(2, 3, 224, 224, requires_grad=True),),
|
||||
{"kernel_size": 3, "stride": 2, "padding": 1, "dilation": 2},
|
||||
),
|
||||
(
|
||||
torch.where,
|
||||
(
|
||||
torch.rand(2, 3, 224, 224) > 0.5,
|
||||
torch.rand(2, 3, 224, 224, requires_grad=True),
|
||||
torch.rand(2, 3, 224, 224, requires_grad=True),
|
||||
),
|
||||
{},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@pytest.mark.parametrize("func, args, kwargs", odd_cases)
|
||||
def test_flop_count_function(func, args, kwargs):
|
||||
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
|
||||
assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
|
||||
assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}'
|
||||
assert rs_fwd > 0, f"fwd flop count of {func.__name__} is {rs_fwd}"
|
||||
assert rs_bwd > 0, f"bwd flop count of {func.__name__} is {rs_bwd}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_flop_count_module(tm.resnet18)
|
||||
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})
|
||||
test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True})
|
||||
|
@@ -6,17 +6,22 @@ from packaging import version
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
|
||||
from colossalai._analyzer._subclasses import MetaTensorMode
|
||||
except:
|
||||
pass
|
||||
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||
|
||||
|
||||
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
|
||||
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
|
||||
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
|
||||
assert tensor.stride() == meta_tensor.stride(
|
||||
), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'
|
||||
assert (
|
||||
tensor.shape == meta_tensor.shape
|
||||
), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match."
|
||||
assert (
|
||||
tensor.dtype == meta_tensor.dtype
|
||||
), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match."
|
||||
assert (
|
||||
tensor.stride() == meta_tensor.stride()
|
||||
), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match."
|
||||
|
||||
|
||||
def run_and_compare(model):
|
||||
@@ -31,12 +36,12 @@ def run_and_compare(model):
|
||||
compare_all(x.grad, meta_x.grad)
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@clear_cache_before_run()
|
||||
@parameterize('m', tm_models + tmm_models)
|
||||
@parameterize("m", tm_models + tmm_models)
|
||||
def test_meta_mode_shape(m):
|
||||
run_and_compare(m())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_meta_mode_shape(tm.resnet18)
|
||||
|
Reference in New Issue
Block a user