mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -19,14 +19,21 @@ torch.manual_seed(MANUAL_SEED)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
@pytest.mark.skip('balance split v2 is not ready')
|
||||
@pytest.mark.skip("balance split v2 is not ready")
|
||||
def test_torchvision_models():
|
||||
MODEL_LIST = [
|
||||
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
|
||||
tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
|
||||
tm.vgg11,
|
||||
tm.resnet18,
|
||||
tm.densenet121,
|
||||
tm.mobilenet_v3_small,
|
||||
tm.resnext50_32x4d,
|
||||
tm.wide_resnet50_2,
|
||||
tm.regnet_x_16gf,
|
||||
tm.efficientnet_b0,
|
||||
tm.mnasnet0_5,
|
||||
]
|
||||
|
||||
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
|
||||
if version.parse(torchvision.__version__) >= version.parse("0.12.0"):
|
||||
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
|
||||
|
||||
tracer = ColoTracer()
|
||||
@@ -57,10 +64,10 @@ def test_torchvision_models():
|
||||
output_part1 = model_part1(output_part0)
|
||||
else:
|
||||
if len(output_part0) > len(sig.parameters):
|
||||
output_part0 = output_part0[:len(sig.parameters)]
|
||||
output_part0 = output_part0[: len(sig.parameters)]
|
||||
output_part1 = model_part1(*output_part0)
|
||||
assert output.equal(output_part1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_torchvision_models()
|
||||
|
Reference in New Issue
Block a user