skip torchrec unittests if not installed (#1790)

This commit is contained in:
Jiarui Fang 2022-11-02 14:44:32 +08:00 committed by GitHub
parent 0b8161fab8
commit 32c1b843a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 18 deletions

View File

@ -1,19 +1,26 @@
from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
import pytest
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.models import deepfm, dlrm
import colossalai.fx as fx
import pdb
from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
from colossalai.fx.tracer.tracer import ColoTracer
try:
from torchrec.models import deepfm
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
NOT_TORCHREC = False
except ImportError:
NOT_TORCHREC = True
from torch.fx import GraphModule
BATCH = 2
SHAPE = 10
@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
def test_torchrec_deepfm_models():
MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]

View File

@ -1,19 +1,24 @@
from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.models import deepfm, dlrm
import colossalai.fx as fx
import pdb
from colossalai.fx.tracer.tracer import ColoTracer
try:
from torchrec.models import dlrm
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
NOT_TORCHREC = False
except ImportError:
NOT_TORCHREC = True
import pytest
from torch.fx import GraphModule
BATCH = 2
SHAPE = 10
@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
def test_torchrec_dlrm_models():
MODEL_LIST = [
dlrm.DLRM,