mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
@@ -7,7 +6,7 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
from ..registry import model_zoo
|
||||
|
||||
BATCH = 2
|
||||
SHAPE = 10
|
||||
@@ -20,9 +19,9 @@ def gen_kt():
|
||||
|
||||
# KeyedJaggedTensor
|
||||
def gen_kjt():
|
||||
KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
|
||||
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
offsets=torch.tensor([0, 2, 4, 6, 8]))
|
||||
KJT = KeyedJaggedTensor.from_offsets_sync(
|
||||
keys=["f1", "f2"], values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), offsets=torch.tensor([0, 2, 4, 6, 8])
|
||||
)
|
||||
return KJT
|
||||
|
||||
|
||||
@@ -68,7 +67,7 @@ def get_ebc():
|
||||
# EmbeddingBagCollection
|
||||
eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
|
||||
eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
|
||||
return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu'))
|
||||
return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device("cpu"))
|
||||
|
||||
|
||||
def sparse_arch_model_fn():
|
||||
@@ -91,52 +90,69 @@ def dlrm_sparsearch_model_fn():
|
||||
return dlrm.SparseArch(ebc)
|
||||
|
||||
|
||||
model_zoo.register(name='deepfm_densearch',
|
||||
model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="deepfm_densearch",
|
||||
model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
||||
model_zoo.register(name='deepfm_interactionarch',
|
||||
model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
|
||||
data_gen_fn=interaction_arch_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="deepfm_interactionarch",
|
||||
model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
|
||||
data_gen_fn=interaction_arch_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
||||
model_zoo.register(name='deepfm_overarch',
|
||||
model_fn=partial(deepfm.OverArch, SHAPE),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="deepfm_overarch",
|
||||
model_fn=partial(deepfm.OverArch, SHAPE),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
||||
model_zoo.register(name='deepfm_simpledeepfmnn',
|
||||
model_fn=simple_deep_fmnn_model_fn,
|
||||
data_gen_fn=simple_dfm_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="deepfm_simpledeepfmnn",
|
||||
model_fn=simple_deep_fmnn_model_fn,
|
||||
data_gen_fn=simple_dfm_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
||||
model_zoo.register(name='deepfm_sparsearch',
|
||||
model_fn=sparse_arch_model_fn,
|
||||
data_gen_fn=sparse_arch_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="deepfm_sparsearch",
|
||||
model_fn=sparse_arch_model_fn,
|
||||
data_gen_fn=sparse_arch_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
||||
model_zoo.register(name='dlrm',
|
||||
model_fn=dlrm_model_fn,
|
||||
data_gen_fn=simple_dfm_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="dlrm", model_fn=dlrm_model_fn, data_gen_fn=simple_dfm_data_gen_fn, output_transform_fn=output_transform_fn
|
||||
)
|
||||
|
||||
model_zoo.register(name='dlrm_densearch',
|
||||
model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="dlrm_densearch",
|
||||
model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
||||
model_zoo.register(name='dlrm_interactionarch',
|
||||
model_fn=partial(dlrm.InteractionArch, 2),
|
||||
data_gen_fn=interaction_arch_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="dlrm_interactionarch",
|
||||
model_fn=partial(dlrm.InteractionArch, 2),
|
||||
data_gen_fn=interaction_arch_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
||||
model_zoo.register(name='dlrm_overarch',
|
||||
model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="dlrm_overarch",
|
||||
model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
||||
model_zoo.register(name='dlrm_sparsearch',
|
||||
model_fn=dlrm_sparsearch_model_fn,
|
||||
data_gen_fn=sparse_arch_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn)
|
||||
model_zoo.register(
|
||||
name="dlrm_sparsearch",
|
||||
model_fn=dlrm_sparsearch_model_fn,
|
||||
data_gen_fn=sparse_arch_data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
)
|
||||
|
Reference in New Issue
Block a user