mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[fx] Add colotracer compatibility test on torchrec (#1370)
This commit is contained in:
@@ -3,3 +3,4 @@ torchvision
|
|||||||
transformers
|
transformers
|
||||||
timm
|
timm
|
||||||
titans
|
titans
|
||||||
|
torchrec
|
||||||
|
@@ -0,0 +1,91 @@
|
|||||||
|
from curses import meta
|
||||||
|
from math import dist
|
||||||
|
from xml.dom import HierarchyRequestErr
|
||||||
|
from colossalai.fx.tracer import meta_patch
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||||
|
import torch
|
||||||
|
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
|
||||||
|
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||||
|
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||||
|
from torchrec.models import deepfm, dlrm
|
||||||
|
import colossalai.fx as fx
|
||||||
|
import pdb
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
|
BATCH = 2
|
||||||
|
SHAPE = 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_torchrec_deepfm_models():
|
||||||
|
MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
|
||||||
|
|
||||||
|
# Data Preparation
|
||||||
|
# EmbeddingBagCollection
|
||||||
|
eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
|
||||||
|
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
|
||||||
|
|
||||||
|
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||||
|
keys = ["f1", "f2"]
|
||||||
|
|
||||||
|
# KeyedTensor
|
||||||
|
KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
|
||||||
|
|
||||||
|
# KeyedJaggedTensor
|
||||||
|
KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
|
||||||
|
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
|
||||||
|
offsets=torch.tensor([0, 2, 4, 6, 8]))
|
||||||
|
|
||||||
|
# Dense Features
|
||||||
|
features = torch.rand((BATCH, SHAPE))
|
||||||
|
|
||||||
|
# Tracer
|
||||||
|
tracer = ColoTracer()
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
# Initializing model
|
||||||
|
if model_cls == deepfm.DenseArch:
|
||||||
|
model = model_cls(SHAPE, SHAPE, SHAPE)
|
||||||
|
elif model_cls == deepfm.FMInteractionArch:
|
||||||
|
model = model_cls(SHAPE * 3, keys, SHAPE)
|
||||||
|
elif model_cls == deepfm.OverArch:
|
||||||
|
model = model_cls(SHAPE)
|
||||||
|
elif model_cls == deepfm.SimpleDeepFMNN:
|
||||||
|
model = model_cls(SHAPE, ebc, SHAPE, SHAPE)
|
||||||
|
elif model_cls == deepfm.SparseArch:
|
||||||
|
model = model_cls(ebc)
|
||||||
|
|
||||||
|
# Setup GraphModule
|
||||||
|
graph = tracer.trace(model)
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
gm.eval()
|
||||||
|
|
||||||
|
# Aligned Test
|
||||||
|
with torch.no_grad():
|
||||||
|
if model_cls == deepfm.DenseArch or model_cls == deepfm.OverArch:
|
||||||
|
fx_out = gm(features)
|
||||||
|
non_fx_out = model(features)
|
||||||
|
elif model_cls == deepfm.FMInteractionArch:
|
||||||
|
fx_out = gm(features, KT)
|
||||||
|
non_fx_out = model(features, KT)
|
||||||
|
elif model_cls == deepfm.SimpleDeepFMNN:
|
||||||
|
fx_out = gm(features, KJT)
|
||||||
|
non_fx_out = model(features, KJT)
|
||||||
|
elif model_cls == deepfm.SparseArch:
|
||||||
|
fx_out = gm(KJT)
|
||||||
|
non_fx_out = model(KJT)
|
||||||
|
|
||||||
|
if torch.is_tensor(fx_out):
|
||||||
|
assert torch.allclose(
|
||||||
|
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||||
|
else:
|
||||||
|
assert torch.allclose(
|
||||||
|
fx_out.values(),
|
||||||
|
non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_torchrec_deepfm_models()
|
116
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
Normal file
116
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
from curses import meta
|
||||||
|
from math import dist
|
||||||
|
from xml.dom import HierarchyRequestErr
|
||||||
|
from colossalai.fx.tracer import meta_patch
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||||
|
import torch
|
||||||
|
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
|
||||||
|
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||||
|
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||||
|
from torchrec.models import deepfm, dlrm
|
||||||
|
import colossalai.fx as fx
|
||||||
|
import pdb
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
|
BATCH = 2
|
||||||
|
SHAPE = 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_torchrec_dlrm_models():
|
||||||
|
MODEL_LIST = [
|
||||||
|
dlrm.DLRM,
|
||||||
|
dlrm.DenseArch,
|
||||||
|
dlrm.InteractionArch,
|
||||||
|
dlrm.InteractionV2Arch,
|
||||||
|
dlrm.OverArch,
|
||||||
|
dlrm.SparseArch,
|
||||||
|
# dlrm.DLRMV2
|
||||||
|
]
|
||||||
|
|
||||||
|
# Data Preparation
|
||||||
|
# EmbeddingBagCollection
|
||||||
|
eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
|
||||||
|
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
|
||||||
|
|
||||||
|
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||||
|
keys = ["f1", "f2"]
|
||||||
|
|
||||||
|
# KeyedTensor
|
||||||
|
KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
|
||||||
|
|
||||||
|
# KeyedJaggedTensor
|
||||||
|
KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
|
||||||
|
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
|
||||||
|
offsets=torch.tensor([0, 2, 4, 6, 8]))
|
||||||
|
|
||||||
|
# Dense Features
|
||||||
|
dense_features = torch.rand((BATCH, SHAPE))
|
||||||
|
|
||||||
|
# Sparse Features
|
||||||
|
sparse_features = torch.rand((BATCH, len(keys), SHAPE))
|
||||||
|
# Tracer
|
||||||
|
tracer = ColoTracer()
|
||||||
|
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
# Initializing model
|
||||||
|
if model_cls == dlrm.DLRM:
|
||||||
|
model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
|
||||||
|
elif model_cls == dlrm.DenseArch:
|
||||||
|
model = model_cls(SHAPE, [SHAPE, SHAPE])
|
||||||
|
elif model_cls == dlrm.InteractionArch:
|
||||||
|
model = model_cls(len(keys))
|
||||||
|
elif model_cls == dlrm.InteractionV2Arch:
|
||||||
|
I1 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
|
||||||
|
I2 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
|
||||||
|
model = model_cls(len(keys), I1, I2)
|
||||||
|
elif model_cls == dlrm.OverArch:
|
||||||
|
model = model_cls(SHAPE, [5, 1])
|
||||||
|
elif model_cls == dlrm.SparseArch:
|
||||||
|
model = model_cls(ebc)
|
||||||
|
elif model_cls == dlrm.DLRMV2:
|
||||||
|
# Currently DLRMV2 cannot be traced
|
||||||
|
model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE])
|
||||||
|
|
||||||
|
# Setup GraphModule
|
||||||
|
if model_cls == dlrm.InteractionV2Arch:
|
||||||
|
concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
|
||||||
|
graph = tracer.trace(model, concrete_args=concrete_args)
|
||||||
|
else:
|
||||||
|
graph = tracer.trace(model)
|
||||||
|
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
gm.eval()
|
||||||
|
|
||||||
|
# Aligned Test
|
||||||
|
with torch.no_grad():
|
||||||
|
if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2:
|
||||||
|
fx_out = gm(dense_features, KJT)
|
||||||
|
non_fx_out = model(dense_features, KJT)
|
||||||
|
elif model_cls == dlrm.DenseArch:
|
||||||
|
fx_out = gm(dense_features)
|
||||||
|
non_fx_out = model(dense_features)
|
||||||
|
elif model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch:
|
||||||
|
fx_out = gm(dense_features, sparse_features)
|
||||||
|
non_fx_out = model(dense_features, sparse_features)
|
||||||
|
elif model_cls == dlrm.OverArch:
|
||||||
|
fx_out = gm(dense_features)
|
||||||
|
non_fx_out = model(dense_features)
|
||||||
|
elif model_cls == dlrm.SparseArch:
|
||||||
|
fx_out = gm(KJT)
|
||||||
|
non_fx_out = model(KJT)
|
||||||
|
|
||||||
|
if torch.is_tensor(fx_out):
|
||||||
|
assert torch.allclose(
|
||||||
|
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||||
|
else:
|
||||||
|
assert torch.allclose(
|
||||||
|
fx_out.values(),
|
||||||
|
non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_torchrec_dlrm_models()
|
Reference in New Issue
Block a user