mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-12 18:36:15 +00:00
skip torchrec unittests if not installed (#1790)
This commit is contained in:
parent
0b8161fab8
commit
32c1b843a9
@ -1,19 +1,26 @@
|
|||||||
from colossalai.fx.tracer import meta_patch
|
import pytest
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
|
||||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
|
||||||
import torch
|
import torch
|
||||||
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
|
|
||||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
from colossalai.fx.tracer import meta_patch
|
||||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||||
from torchrec.models import deepfm, dlrm
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
import colossalai.fx as fx
|
|
||||||
import pdb
|
try:
|
||||||
|
from torchrec.models import deepfm
|
||||||
|
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||||
|
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||||
|
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||||
|
NOT_TORCHREC = False
|
||||||
|
except ImportError:
|
||||||
|
NOT_TORCHREC = True
|
||||||
|
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
BATCH = 2
|
BATCH = 2
|
||||||
SHAPE = 10
|
SHAPE = 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
|
||||||
def test_torchrec_deepfm_models():
|
def test_torchrec_deepfm_models():
|
||||||
MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
|
MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
|
||||||
|
|
||||||
|
@ -1,19 +1,24 @@
|
|||||||
from colossalai.fx.tracer import meta_patch
|
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
|
||||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
|
||||||
import torch
|
import torch
|
||||||
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
|
|
||||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
|
||||||
from torchrec.models import deepfm, dlrm
|
try:
|
||||||
import colossalai.fx as fx
|
from torchrec.models import dlrm
|
||||||
import pdb
|
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||||
|
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||||
|
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||||
|
NOT_TORCHREC = False
|
||||||
|
except ImportError:
|
||||||
|
NOT_TORCHREC = True
|
||||||
|
|
||||||
|
import pytest
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
BATCH = 2
|
BATCH = 2
|
||||||
SHAPE = 10
|
SHAPE = 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
|
||||||
def test_torchrec_dlrm_models():
|
def test_torchrec_dlrm_models():
|
||||||
MODEL_LIST = [
|
MODEL_LIST = [
|
||||||
dlrm.DLRM,
|
dlrm.DLRM,
|
||||||
|
Loading…
Reference in New Issue
Block a user