[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -56,7 +56,7 @@ def test_timm_models():
sub_model_zoo = model_zoo.get_sub_registry('timm')
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
if attribute is not None and attribute.has_control_flow:
meta_args = {k: v.to('meta') for k, v in data.items()}