diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index 0f1f294e4..d2efc3c45 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -1,19 +1,26 @@ -from colossalai.fx.tracer import meta_patch -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.fx.tracer.meta_patch.patched_function import python_ops +import pytest import torch -from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.models import deepfm, dlrm -import colossalai.fx as fx -import pdb + +from colossalai.fx.tracer import meta_patch +from colossalai.fx.tracer.meta_patch.patched_function import python_ops +from colossalai.fx.tracer.tracer import ColoTracer + +try: + from torchrec.models import deepfm + from torchrec.modules.embedding_configs import EmbeddingBagConfig + from torchrec.modules.embedding_modules import EmbeddingBagCollection + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + NOT_TORCHREC = False +except ImportError: + NOT_TORCHREC = True + from torch.fx import GraphModule BATCH = 2 SHAPE = 10 +@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed') def test_torchrec_deepfm_models(): MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch] diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 5999a1abf..4050c7d3c 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -1,19 +1,24 @@ -from colossalai.fx.tracer import meta_patch -from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.fx.tracer.meta_patch.patched_function import python_ops import torch -from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.models import deepfm, dlrm -import colossalai.fx as fx -import pdb + +from colossalai.fx.tracer.tracer import ColoTracer + +try: + from torchrec.models import dlrm + from torchrec.modules.embedding_configs import EmbeddingBagConfig + from torchrec.modules.embedding_modules import EmbeddingBagCollection + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + NOT_TORCHREC = False +except ImportError: + NOT_TORCHREC = True + +import pytest from torch.fx import GraphModule BATCH = 2 SHAPE = 10 +@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed') def test_torchrec_dlrm_models(): MODEL_LIST = [ dlrm.DLRM,