mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-08 00:18:25 +00:00
skip torchrec unittests if not installed (#1790)
This commit is contained in:
parent
0b8161fab8
commit
32c1b843a9
@ -1,19 +1,26 @@
|
||||
from colossalai.fx.tracer import meta_patch
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||
import pytest
|
||||
import torch
|
||||
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.models import deepfm, dlrm
|
||||
import colossalai.fx as fx
|
||||
import pdb
|
||||
|
||||
from colossalai.fx.tracer import meta_patch
|
||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
|
||||
try:
|
||||
from torchrec.models import deepfm
|
||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||
NOT_TORCHREC = False
|
||||
except ImportError:
|
||||
NOT_TORCHREC = True
|
||||
|
||||
from torch.fx import GraphModule
|
||||
|
||||
BATCH = 2
|
||||
SHAPE = 10
|
||||
|
||||
|
||||
@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
|
||||
def test_torchrec_deepfm_models():
|
||||
MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
|
||||
|
||||
|
@ -1,19 +1,24 @@
|
||||
from colossalai.fx.tracer import meta_patch
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||
import torch
|
||||
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.models import deepfm, dlrm
|
||||
import colossalai.fx as fx
|
||||
import pdb
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
|
||||
try:
|
||||
from torchrec.models import dlrm
|
||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||
NOT_TORCHREC = False
|
||||
except ImportError:
|
||||
NOT_TORCHREC = True
|
||||
|
||||
import pytest
|
||||
from torch.fx import GraphModule
|
||||
|
||||
BATCH = 2
|
||||
SHAPE = 10
|
||||
|
||||
|
||||
@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
|
||||
def test_torchrec_dlrm_models():
|
||||
MODEL_LIST = [
|
||||
dlrm.DLRM,
|
||||
|
Loading…
Reference in New Issue
Block a user