Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
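Most of the changes below are mechanical: they come from running the updated hooks over the whole tree (e.g. `pre-commit run --all-files`), so the hunks mainly show black-style normalization, with single-quoted strings becoming double-quoted and call sites reflowed to the configured line length. A minimal sketch of that kind of rewrite using black's Python API; the 120-character line length is an assumption, and the repository's actual hook configuration is not shown in this excerpt:

# Sketch only: reproduces the style of rewrite seen in the hunks below.
# Assumes `black` is installed; line_length=120 is a guess at the project setting.
import black

SOURCE = (
    "sub_model_zoo = model_zoo.get_sub_registry('torchaudio')\n"
    "trace_and_compare(model,\n"
    "                  data_gen_fn,\n"
    "                  output_transform_fn,\n"
    "                  need_meta=(attribute is not None and attribute.has_control_flow))\n"
)

# black normalizes quotes to double quotes and reflows the call to fit the line length.
print(black.format_str(SOURCE, mode=black.Mode(line_length=120)))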
@@ -1,6 +1,5 @@
 import pytest
 import torch
 from packaging import version
 from torchaudio_utils import trace_and_compare
 
 from colossalai.testing import clear_cache_before_run
@@ -14,11 +13,10 @@ from tests.kit.model_zoo import model_zoo
 def test_torchaudio_models():
     torch.backends.cudnn.deterministic = True
 
-    sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
+    sub_model_zoo = model_zoo.get_sub_registry("torchaudio")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
         model = model_fn()
-        trace_and_compare(model,
-                          data_gen_fn,
-                          output_transform_fn,
-                          need_meta=(attribute is not None and attribute.has_control_flow))
+        trace_and_compare(
+            model, data_gen_fn, output_transform_fn, need_meta=(attribute is not None and attribute.has_control_flow)
+        )
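The loop in the hunk above is driven by the tests.kit.model_zoo registry: each entry bundles a model constructor, a data generator, an output transform, and an attribute object whose has_control_flow flag decides whether meta-tracing is needed; two unused slots are discarded as `_`. A hypothetical entry with the same six-field shape, purely for illustration (the namedtuple and names below are not the real model_zoo API):

# Illustrative only: mimics the 6-tuple unpacked by the test loop above.
from collections import namedtuple

import torch
import torch.nn as nn

ModelAttribute = namedtuple("ModelAttribute", ["has_control_flow"])  # stand-in, not the real class

sub_model_zoo = {
    "toy_audio_model": (
        lambda: nn.Linear(16, 4),                # model_fn: builds the model
        lambda: {"input": torch.randn(2, 16)},   # data_gen_fn: kwargs for forward()
        lambda output: {"output": output},       # output_transform_fn: map outputs to a dict of tensors
        None,                                    # unused slot (unpacked as _)
        None,                                    # unused slot (unpacked as _)
        ModelAttribute(has_control_flow=False),  # attribute consulted for need_meta
    ),
}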
@@ -6,7 +6,7 @@ from colossalai._analyzer.fx import symbolic_trace
 def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
     data = data_gen()
     concrete_args = data if need_concrete else {}
-    meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
+    meta_args = {k: v.to("meta") for k, v in data.items()} if need_meta else {}
 
     model.eval()
 
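A note on the meta_args line changed above: moving a tensor to the "meta" device keeps its shape and dtype but allocates no storage, which is presumably why the helper builds meta_args when need_meta is set, so the tracer can work from input metadata alone. A plain-PyTorch illustration, independent of ColossalAI's tracer:

# Plain PyTorch: "meta" tensors carry shape/dtype metadata but no data.
import torch

x = torch.randn(2, 8)
m = x.to("meta")

print(m.shape, m.dtype)  # torch.Size([2, 8]) torch.float32
print(m.device)          # meta
print(m.is_meta)         # True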
@@ -24,5 +24,6 @@ def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, nee
 
     for key, fx_output_val in transformed_fx_out.items():
         non_fx_output_val = transformed_non_fx_out[key]
-        assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
-            f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+        assert torch.allclose(
+            fx_output_val, non_fx_output_val, atol=1e-5
+        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}"