Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
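The hunks below are pure formatter output: string literals switch from single to double quotes and hand-wrapped assert calls are re-broken, both signatures of black (assumed here to be among the updated pre-commit hooks; the dropped blank line between imports likewise suggests an isort pass). A minimal runnable illustration of the rewrite, with made-up tensors standing in for the test values:

import torch

fx_out = torch.ones(2)
non_fx_out = torch.ones(2)

# Style on the removed lines: single quotes, closing paren glued to the
# last argument.
assert torch.allclose(
    fx_out, non_fx_out), f'outputs differ: {fx_out} vs {non_fx_out}'

# Style on the added lines: arguments on their own line, closing paren
# dedented to the statement's indent, double quotes throughout.
assert torch.allclose(
    fx_out, non_fx_out
), f"outputs differ: {fx_out} vs {non_fx_out}"

The first changed file, the deepfm tracer test, reads: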
@@ -1,4 +1,3 @@
 import pytest
 import torch
-
 from colossalai._analyzer.fx import symbolic_trace
@@ -32,31 +31,34 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
     assert len(transformed_fx_out) == len(transformed_non_fx_out)
     if torch.is_tensor(fx_out):
         assert torch.allclose(
-            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            fx_out, non_fx_out
+        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"
     else:
         assert torch.allclose(
-            fx_out.values(),
-            non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            fx_out.values(), non_fx_out.values()
+        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"
     for key in transformed_fx_out.keys():
         fx_output_val = transformed_fx_out[key]
         non_fx_output_val = transformed_non_fx_out[key]
         if torch.is_tensor(fx_output_val):
-            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
-                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+            assert torch.allclose(
+                fx_output_val, non_fx_output_val, atol=1e-5
+            ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}"
         else:
-            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
-                                  ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            assert torch.allclose(
+                fx_output_val.values(), non_fx_output_val.values()
+            ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"


 @clear_cache_before_run()
 def test_torchrec_deepfm_models():
-    deepfm_models = model_zoo.get_sub_registry('deepfm')
+    deepfm_models = model_zoo.get_sub_registry("deepfm")
     torch.backends.cudnn.deterministic = True

     for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items():
         data = data_gen_fn()
         if attribute is not None and attribute.has_control_flow:
-            meta_args = {k: v.to('meta') for k, v in data.items()}
+            meta_args = {k: v.to("meta") for k, v in data.items()}
         else:
             meta_args = None
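Both files share the helper trace_and_compare, which the hunks above reformat. Judging from the assertions, it traces a model, feeds the traced and eager modules the same data, and checks that the (transformed) outputs agree via torch.allclose. A self-contained sketch of that pattern, substituting stock torch.fx.symbolic_trace for colossalai._analyzer.fx.symbolic_trace (the ColossalAI tracer additionally accepts meta_args; TinyModel is an invented stand-in for the torchrec models):

import torch
import torch.fx


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x).relu()


def trace_and_compare(model: torch.nn.Module, data: torch.Tensor) -> None:
    model.eval()
    # Stand-in for the ColossalAI tracer used in the tests.
    gm = torch.fx.symbolic_trace(model)
    with torch.no_grad():
        fx_out = gm(data)
        non_fx_out = model(data)
    # The same check the tests apply, with the tolerance they use.
    assert torch.allclose(
        fx_out, non_fx_out, atol=1e-5
    ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"


trace_and_compare(TinyModel(), torch.randn(2, 8))

The second changed file, the dlrm tracer test, receives the same treatment plus two quote fixes of its own: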
@@ -1,4 +1,3 @@
 import pytest
 import torch
-
 from colossalai._analyzer.fx import symbolic_trace
@@ -32,37 +31,40 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
     assert len(transformed_fx_out) == len(transformed_non_fx_out)
     if torch.is_tensor(fx_out):
         assert torch.allclose(
-            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            fx_out, non_fx_out
+        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"
     else:
         assert torch.allclose(
-            fx_out.values(),
-            non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            fx_out.values(), non_fx_out.values()
+        ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"
     for key in transformed_fx_out.keys():
         fx_output_val = transformed_fx_out[key]
         non_fx_output_val = transformed_non_fx_out[key]
         if torch.is_tensor(fx_output_val):
-            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
-                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+            assert torch.allclose(
+                fx_output_val, non_fx_output_val, atol=1e-5
+            ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}"
         else:
-            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
-                                  ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            assert torch.allclose(
+                fx_output_val.values(), non_fx_output_val.values()
+            ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}"


 @clear_cache_before_run()
 def test_torchrec_dlrm_models():
     torch.backends.cudnn.deterministic = True
-    dlrm_models = model_zoo.get_sub_registry('dlrm')
+    dlrm_models = model_zoo.get_sub_registry("dlrm")

     for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items():
         data = data_gen_fn()

         # dlrm_interactionarch is not supported
         # TODO(FrankLeeeee): support this model
-        if name == 'dlrm_interactionarch':
+        if name == "dlrm_interactionarch":
             continue

         if attribute is not None and attribute.has_control_flow:
-            meta_args = {k: v.to('meta') for k, v in data.items()}
+            meta_args = {k: v.to("meta") for k, v in data.items()}
         else:
             meta_args = None
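One pattern both tests rely on is worth spelling out: for models whose forward pass contains control flow, the inputs are converted to meta tensors before tracing. A meta tensor records shape, dtype, and layout but allocates no storage, so the tracer can propagate shapes without executing real kernels. A short runnable sketch (the key names are invented for illustration):

import torch

# Hypothetical sample batch, shaped like what a data_gen_fn might return.
data = {
    "dense_features": torch.randn(4, 16),
    "sparse_ids": torch.randint(0, 100, (4, 2)),
}

# The exact comprehension used in both tests: move every input to the
# "meta" device, keeping shape and dtype but dropping the actual data.
meta_args = {k: v.to("meta") for k, v in data.items()}

for k, v in meta_args.items():
    print(k, tuple(v.shape), v.dtype, v.device)  # device prints as meta

Passing such meta_args lets the tracer resolve shape-dependent branches that a purely symbolic trace could not evaluate, which is why the tests build them only when attribute.has_control_flow is set.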