mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-30 21:39:05 +00:00 
			
		
		
		
	skip torchrec unittests if not installed (#1790)
This commit is contained in:
		| @@ -1,19 +1,26 @@ | |||||||
| from colossalai.fx.tracer import meta_patch | import pytest | ||||||
| from colossalai.fx.tracer.tracer import ColoTracer |  | ||||||
| from colossalai.fx.tracer.meta_patch.patched_function import python_ops |  | ||||||
| import torch | import torch | ||||||
| from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor |  | ||||||
| from torchrec.modules.embedding_modules import EmbeddingBagCollection | from colossalai.fx.tracer import meta_patch | ||||||
|  | from colossalai.fx.tracer.meta_patch.patched_function import python_ops | ||||||
|  | from colossalai.fx.tracer.tracer import ColoTracer | ||||||
|  |  | ||||||
|  | try: | ||||||
|  |     from torchrec.models import deepfm | ||||||
|     from torchrec.modules.embedding_configs import EmbeddingBagConfig |     from torchrec.modules.embedding_configs import EmbeddingBagConfig | ||||||
| from torchrec.models import deepfm, dlrm |     from torchrec.modules.embedding_modules import EmbeddingBagCollection | ||||||
| import colossalai.fx as fx |     from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor | ||||||
| import pdb |     NOT_TORCHREC = False | ||||||
|  | except ImportError: | ||||||
|  |     NOT_TORCHREC = True | ||||||
|  |  | ||||||
| from torch.fx import GraphModule | from torch.fx import GraphModule | ||||||
|  |  | ||||||
| BATCH = 2 | BATCH = 2 | ||||||
| SHAPE = 10 | SHAPE = 10 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed') | ||||||
| def test_torchrec_deepfm_models(): | def test_torchrec_deepfm_models(): | ||||||
|     MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch] |     MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch] | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,19 +1,24 @@ | |||||||
| from colossalai.fx.tracer import meta_patch |  | ||||||
| from colossalai.fx.tracer.tracer import ColoTracer |  | ||||||
| from colossalai.fx.tracer.meta_patch.patched_function import python_ops |  | ||||||
| import torch | import torch | ||||||
| from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor |  | ||||||
| from torchrec.modules.embedding_modules import EmbeddingBagCollection | from colossalai.fx.tracer.tracer import ColoTracer | ||||||
|  |  | ||||||
|  | try: | ||||||
|  |     from torchrec.models import dlrm | ||||||
|     from torchrec.modules.embedding_configs import EmbeddingBagConfig |     from torchrec.modules.embedding_configs import EmbeddingBagConfig | ||||||
| from torchrec.models import deepfm, dlrm |     from torchrec.modules.embedding_modules import EmbeddingBagCollection | ||||||
| import colossalai.fx as fx |     from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor | ||||||
| import pdb |     NOT_TORCHREC = False | ||||||
|  | except ImportError: | ||||||
|  |     NOT_TORCHREC = True | ||||||
|  |  | ||||||
|  | import pytest | ||||||
| from torch.fx import GraphModule | from torch.fx import GraphModule | ||||||
|  |  | ||||||
| BATCH = 2 | BATCH = 2 | ||||||
| SHAPE = 10 | SHAPE = 10 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed') | ||||||
| def test_torchrec_dlrm_models(): | def test_torchrec_dlrm_models(): | ||||||
|     MODEL_LIST = [ |     MODEL_LIST = [ | ||||||
|         dlrm.DLRM, |         dlrm.DLRM, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user