mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -20,7 +20,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||
# 1. ConViT
|
||||
# 2. NormFreeNet
|
||||
# as they are not supported, let's skip them
|
||||
if model.__class__.__name__ in ['ConViT', 'NormFreeNet']:
|
||||
if model.__class__.__name__ in ["ConViT", "NormFreeNet"]:
|
||||
return
|
||||
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
@@ -39,8 +39,9 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||
for key in transformed_fx_out.keys():
|
||||
fx_output_val = transformed_fx_out[key]
|
||||
non_fx_output_val = transformed_non_fx_out[key]
|
||||
assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
|
||||
f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
|
||||
assert torch.allclose(
|
||||
fx_output_val, non_fx_output_val, atol=1e-5
|
||||
), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}"
|
||||
|
||||
|
||||
# FIXME(ver217): timm/models/convit.py:71: in forward
|
||||
@@ -49,22 +50,22 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||
# return self.tracer.to_bool(self)
|
||||
# torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
|
||||
@pytest.mark.skip("convit is not supported yet")
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12")
|
||||
@clear_cache_before_run()
|
||||
def test_timm_models():
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('timm')
|
||||
sub_model_zoo = model_zoo.get_sub_registry("timm")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
if attribute is not None and attribute.has_control_flow:
|
||||
meta_args = {k: v.to('meta') for k, v in data.items()}
|
||||
meta_args = {k: v.to("meta") for k, v in data.items()}
|
||||
else:
|
||||
meta_args = None
|
||||
|
||||
trace_and_compare(model_fn, data, output_transform_fn, meta_args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_timm_models()
|
||||
|
Reference in New Issue
Block a user