mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 20:23:41 +00:00)
[test] add torchrec models to test model zoo (#3139)
parent 14a115000b
commit ecd643f1e4
@@ -1,4 +1,5 @@
-from . import diffusers, timm, torchaudio, torchvision, transformers
+from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
 from .registry import model_zoo
 
 __all__ = ['model_zoo']
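Note: this one-line import change is what activates everything below: importing the torchrec sub-package runs its registration code, so consumers only need to import the model zoo. A minimal sketch of looking up the new entries, assuming only the dict-style registry API (get_sub_registry, items) that the updated tests below rely on, and that torchrec is installed (otherwise the sub-registry is empty and the loop is a no-op):

from tests.kit.model_zoo import model_zoo

# importing tests.kit.model_zoo pulls in the torchrec sub-package,
# which registers its models as an import side effect
deepfm_entries = model_zoo.get_sub_registry('deepfm')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_entries.items():
    model = model_fn()       # build the registered model
    data = data_gen_fn()     # build matching dummy inputs
    output = model(**data)   # run an eager forward pass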
tests/kit/model_zoo/torchrec/torchrec.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+from collections import namedtuple
+from functools import partial
+
+import torch
+
+try:
+    from torchrec.models import deepfm, dlrm
+    from torchrec.modules.embedding_configs import EmbeddingBagConfig
+    from torchrec.modules.embedding_modules import EmbeddingBagCollection
+    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+    NO_TORCHREC = False
+except ImportError:
+    NO_TORCHREC = True
+
+from ..registry import ModelAttribute, model_zoo
+
+
+def register_torchrec_models():
+    BATCH = 2
+    SHAPE = 10
+    # KeyedTensor
+    KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
+
+    # KeyedJaggedTensor
+    KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
+                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
+                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
+
+    data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE)))
+
+    interaction_arch_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT)
+
+    simple_dfm_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT)
+
+    sparse_arch_data_gen_fn = lambda: dict(features=KJT)
+
+    output_transform_fn = lambda x: dict(output=x)
+
+    def get_ebc():
+        # EmbeddingBagCollection
+        eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
+        eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
+        return EmbeddingBagCollection(tables=[eb1_config, eb2_config])
+
+    model_zoo.register(name='deepfm_densearch',
+                       model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
+                       data_gen_fn=data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='deepfm_interactionarch',
+                       model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
+                       data_gen_fn=interaction_arch_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='deepfm_overarch',
+                       model_fn=partial(deepfm.OverArch, SHAPE),
+                       data_gen_fn=data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='deepfm_simpledeepfmnn',
+                       model_fn=partial(deepfm.SimpleDeepFMNN, SHAPE, get_ebc(), SHAPE, SHAPE),
+                       data_gen_fn=simple_dfm_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='deepfm_sparsearch',
+                       model_fn=partial(deepfm.SparseArch, get_ebc()),
+                       data_gen_fn=sparse_arch_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm',
+                       model_fn=partial(dlrm.DLRM, get_ebc(), SHAPE, [SHAPE, SHAPE], [5, 1]),
+                       data_gen_fn=simple_dfm_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm_densearch',
+                       model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]),
+                       data_gen_fn=data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm_interactionarch',
+                       model_fn=partial(dlrm.InteractionArch, 2),
+                       data_gen_fn=interaction_arch_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm_overarch',
+                       model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]),
+                       data_gen_fn=data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+    model_zoo.register(name='dlrm_sparsearch',
+                       model_fn=partial(dlrm.SparseArch, get_ebc()),
+                       data_gen_fn=sparse_arch_data_gen_fn,
+                       output_transform_fn=output_transform_fn)
+
+
+if not NO_TORCHREC:
+    register_torchrec_models()
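Note: the KeyedJaggedTensor above packs both feature keys into one flat values tensor, and the offsets are key-major: offsets [0, 2, 4, 6, 8] split the 8 values into four bags of two, first samples 0 and 1 of "f1", then samples 0 and 1 of "f2". A small sketch of reading it back, assuming torchrec's standard JaggedTensor accessors (to_dict, values, lengths); the printed values are illustrative:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
                                          values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
                                          offsets=torch.tensor([0, 2, 4, 6, 8]))
jt = kjt.to_dict()          # {"f1": JaggedTensor, "f2": JaggedTensor}
print(jt["f1"].values())    # tensor([1, 2, 3, 4]) -> ids for both samples of "f1"
print(jt["f1"].lengths())   # tensor([2, 2])       -> two ids per sample
print(jt["f2"].values())    # tensor([5, 6, 7, 8])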
@@ -2,77 +2,38 @@ import pytest
 import torch
 
 from colossalai.fx import symbolic_trace
-
-try:
-    from torchrec.models import deepfm
-    from torchrec.modules.embedding_configs import EmbeddingBagConfig
-    from torchrec.modules.embedding_modules import EmbeddingBagCollection
-    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
-    NOT_TORCHREC = False
-except ImportError:
-    NOT_TORCHREC = True
+from tests.kit.model_zoo import model_zoo
 
 BATCH = 2
 SHAPE = 10
 
-
-@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
-def test_torchrec_deepfm_models():
-    MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
-
-    # Data Preparation
-    # EmbeddingBagCollection
-    eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
-    eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
-
-    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
-    keys = ["f1", "f2"]
-
-    # KeyedTensor
-    KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
-
-    # KeyedJaggedTensor
-    KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
-                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
-                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
-
-    # Dense Features
-    features = torch.rand((BATCH, SHAPE))
-
-    for model_cls in MODEL_LIST:
-        # Initializing model
-        if model_cls == deepfm.DenseArch:
-            model = model_cls(SHAPE, SHAPE, SHAPE)
-        elif model_cls == deepfm.FMInteractionArch:
-            model = model_cls(SHAPE * 3, keys, SHAPE)
-        elif model_cls == deepfm.OverArch:
-            model = model_cls(SHAPE)
-        elif model_cls == deepfm.SimpleDeepFMNN:
-            model = model_cls(SHAPE, ebc, SHAPE, SHAPE)
-        elif model_cls == deepfm.SparseArch:
-            model = model_cls(ebc)
-
-        # Setup GraphModule
-        gm = symbolic_trace(model)
-        model.eval()
-        gm.eval()
-
-        # Aligned Test
-        with torch.no_grad():
-            if model_cls == deepfm.DenseArch or model_cls == deepfm.OverArch:
-                fx_out = gm(features)
-                non_fx_out = model(features)
-            elif model_cls == deepfm.FMInteractionArch:
-                fx_out = gm(features, KT)
-                non_fx_out = model(features, KT)
-            elif model_cls == deepfm.SimpleDeepFMNN:
-                fx_out = gm(features, KJT)
-                non_fx_out = model(features, KJT)
-            elif model_cls == deepfm.SparseArch:
-                fx_out = gm(KJT)
-                non_fx_out = model(KJT)
-
-        if torch.is_tensor(fx_out):
-            assert torch.allclose(
-                fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+deepfm_models = model_zoo.get_sub_registry('deepfm')
+NOT_DFM = False
+if not deepfm_models:
+    NOT_DFM = True
+
+
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+    # trace
+    model = model_cls()
+
+    # convert to eval for inference
+    # it is important to set it to eval mode before tracing
+    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+    model.eval()
+
+    gm = symbolic_trace(model, meta_args=meta_args)
+    gm.eval()
+
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(**data)
+        non_fx_out = model(**data)
+
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
@@ -80,7 +41,30 @@ def test_torchrec_deepfm_models():
-        else:
-            assert torch.allclose(
-                fx_out.values(),
-                non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    for key in transformed_fx_out.keys():
+        fx_output_val = transformed_fx_out[key]
+        non_fx_output_val = transformed_non_fx_out[key]
+        if torch.is_tensor(fx_output_val):
+            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+        else:
+            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
+                                  ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+@pytest.mark.skipif(NOT_DFM, reason='torchrec is not installed')
+def test_torchrec_deepfm_models(deepfm_models):
+    torch.backends.cudnn.deterministic = True
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
+        data = data_gen_fn()
+        if attribute is not None and attribute.has_control_flow:
+            meta_args = {k: v.to('meta') for k, v in data.items()}
+        else:
+            meta_args = None
+
+        trace_and_compare(model_fn, data, output_transform_fn, meta_args)
 
 
 if __name__ == "__main__":
-    test_torchrec_deepfm_models()
+    test_torchrec_deepfm_models(deepfm_models)
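Note: trace_and_compare only builds meta_args when a registered entry's ModelAttribute sets has_control_flow, and none of the torchrec registrations above attach an attribute, so meta_args stays None for these tests. A hypothetical registration that would exercise the meta-tracing branch might look like the sketch below; the model_attribute keyword is an assumption inferred from the ModelAttribute import and the 4-tuple unpacked in the test loop, not something this diff shows:

# hypothetical registration (sketch) -- name and kwarg are illustrative
model_zoo.register(name='deepfm_densearch_with_control_flow',
                   model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))  # kwarg name assumed

# the test loop would then trace against meta tensors instead of real ones:
# meta_args = {k: v.to('meta') for k, v in data.items()}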
@@ -1,104 +1,39 @@
+import pytest
 import torch
 
 from colossalai.fx import symbolic_trace
-
-try:
-    from torchrec.models import dlrm
-    from torchrec.modules.embedding_configs import EmbeddingBagConfig
-    from torchrec.modules.embedding_modules import EmbeddingBagCollection
-    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
-    NOT_TORCHREC = False
-except ImportError:
-    NOT_TORCHREC = True
-
-import pytest
+from tests.kit.model_zoo import model_zoo
 
 BATCH = 2
 SHAPE = 10
 
-
-@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
-def test_torchrec_dlrm_models():
-    MODEL_LIST = [
-        dlrm.DLRM,
-        dlrm.DenseArch,
-        dlrm.InteractionArch,
-        dlrm.InteractionV2Arch,
-        dlrm.OverArch,
-        dlrm.SparseArch,
-        # dlrm.DLRMV2
-    ]
-
-    # Data Preparation
-    # EmbeddingBagCollection
-    eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
-    eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
-
-    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
-    keys = ["f1", "f2"]
-
-    # KeyedTensor
-    KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
-
-    # KeyedJaggedTensor
-    KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
-                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
-                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
-
-    # Dense Features
-    dense_features = torch.rand((BATCH, SHAPE))
-
-    # Sparse Features
-    sparse_features = torch.rand((BATCH, len(keys), SHAPE))
-
-    for model_cls in MODEL_LIST:
-        # Initializing model
-        if model_cls == dlrm.DLRM:
-            model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
-        elif model_cls == dlrm.DenseArch:
-            model = model_cls(SHAPE, [SHAPE, SHAPE])
-        elif model_cls == dlrm.InteractionArch:
-            model = model_cls(len(keys))
-        elif model_cls == dlrm.InteractionV2Arch:
-            I1 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
-            I2 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
-            model = model_cls(len(keys), I1, I2)
-        elif model_cls == dlrm.OverArch:
-            model = model_cls(SHAPE, [5, 1])
-        elif model_cls == dlrm.SparseArch:
-            model = model_cls(ebc)
-        elif model_cls == dlrm.DLRMV2:
-            # Currently DLRMV2 cannot be traced
-            model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE])
-
-        # Setup GraphModule
-        if model_cls == dlrm.InteractionV2Arch:
-            concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
-            gm = symbolic_trace(model, concrete_args=concrete_args)
-        else:
-            gm = symbolic_trace(model)
-        model.eval()
-        gm.eval()
-
-        # Aligned Test
-        with torch.no_grad():
-            if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2:
-                fx_out = gm(dense_features, KJT)
-                non_fx_out = model(dense_features, KJT)
-            elif model_cls == dlrm.DenseArch:
-                fx_out = gm(dense_features)
-                non_fx_out = model(dense_features)
-            elif model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch:
-                fx_out = gm(dense_features, sparse_features)
-                non_fx_out = model(dense_features, sparse_features)
-            elif model_cls == dlrm.OverArch:
-                fx_out = gm(dense_features)
-                non_fx_out = model(dense_features)
-            elif model_cls == dlrm.SparseArch:
-                fx_out = gm(KJT)
-                non_fx_out = model(KJT)
-
-        if torch.is_tensor(fx_out):
-            assert torch.allclose(
-                fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+dlrm_models = model_zoo.get_sub_registry('dlrm')
+NOT_DLRM = False
+if not dlrm_models:
+    NOT_DLRM = True
+
+
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+    # trace
+    model = model_cls()
+
+    # convert to eval for inference
+    # it is important to set it to eval mode before tracing
+    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+    model.eval()
+
+    gm = symbolic_trace(model, meta_args=meta_args)
+    gm.eval()
+
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(**data)
+        non_fx_out = model(**data)
+
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
@@ -106,7 +41,30 @@ def test_torchrec_dlrm_models():
-        else:
-            assert torch.allclose(
-                fx_out.values(),
-                non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    for key in transformed_fx_out.keys():
+        fx_output_val = transformed_fx_out[key]
+        non_fx_output_val = transformed_non_fx_out[key]
+        if torch.is_tensor(fx_output_val):
+            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+        else:
+            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
+                                  ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+@pytest.mark.skipif(NOT_DLRM, reason='torchrec is not installed')
+def test_torchrec_dlrm_models(dlrm_models):
+    torch.backends.cudnn.deterministic = True
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
+        data = data_gen_fn()
+        if attribute is not None and attribute.has_control_flow:
+            meta_args = {k: v.to('meta') for k, v in data.items()}
+        else:
+            meta_args = None
+
+        trace_and_compare(model_fn, data, output_transform_fn, meta_args)
 
 
 if __name__ == "__main__":
-    test_torchrec_dlrm_models()
+    test_torchrec_dlrm_models(dlrm_models)
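Note: outside pytest, a registered entry can be exercised eagerly as a quick sanity check. A sketch for the 'dlrm' entry, assuming torchrec is installed and that a sub-registry indexes like a dict (as the .items() loops above suggest); with over_arch_layer_sizes [5, 1] the logits should come out with shape (BATCH, 1):

import torch
from tests.kit.model_zoo import model_zoo

model_fn, data_gen_fn, output_transform_fn, _ = model_zoo.get_sub_registry('dlrm')['dlrm']
model = model_fn().eval()                      # dlrm.DLRM built from the registered partial
with torch.no_grad():
    out = output_transform_fn(model(**data_gen_fn()))
print({k: v.shape for k, v in out.items()})    # expected: {'output': torch.Size([2, 1])}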