diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/nn/_ops/__init__.py
index c91da3ad1..945505b74 100644
--- a/colossalai/nn/_ops/__init__.py
+++ b/colossalai/nn/_ops/__init__.py
@@ -5,4 +5,4 @@ from .loss import colo_cross_entropy
 from .embedding import colo_embedding
 from .addmm import colo_addmm
 from .embedding_bag import colo_embedding_bag
-from .view import colo_view
+from .view import colo_view
\ No newline at end of file
diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py
index 6d205828d..0ebadac6c 100644
--- a/colossalai/nn/parallel/layers/__init__.py
+++ b/colossalai/nn/parallel/layers/__init__.py
@@ -3,13 +3,10 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
 
+from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
+
 __all__ = [
-    'ColoModule',
-    'register_colo_module',
-    'is_colo_module',
-    'get_colo_module',
-    'init_colo_module',
-    'check_colo_module',
-    'ColoLinear',
-    'ColoEmbedding',
+    'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
+    'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
+    'LimitBuffIndexCopyer'
 ]
diff --git a/colossalai/nn/_ops/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
similarity index 100%
rename from colossalai/nn/_ops/cache_embedding/__init__.py
rename to colossalai/nn/parallel/layers/cache_embedding/__init__.py
diff --git a/colossalai/nn/_ops/cache_embedding/base_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py
similarity index 100%
rename from colossalai/nn/_ops/cache_embedding/base_embedding.py
rename to colossalai/nn/parallel/layers/cache_embedding/base_embedding.py
diff --git a/colossalai/nn/_ops/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
similarity index 100%
rename from colossalai/nn/_ops/cache_embedding/cache_mgr.py
rename to colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
diff --git a/colossalai/nn/_ops/cache_embedding/copyer.py b/colossalai/nn/parallel/layers/cache_embedding/copyer.py
similarity index 100%
rename from colossalai/nn/_ops/cache_embedding/copyer.py
rename to colossalai/nn/parallel/layers/cache_embedding/copyer.py
diff --git a/colossalai/nn/_ops/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
similarity index 100%
rename from colossalai/nn/_ops/cache_embedding/freq_aware_embedding.py
rename to colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
diff --git a/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
similarity index 97%
rename from colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
rename to colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
index 5fb70fc2e..ee751435a 100644
--- a/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -5,9 +5,9 @@ from typing import List, Optional, Iterator, Tuple
 from .base_embedding import BaseEmbeddingBag
 from .cache_mgr import CachedParamMgr
 from torch.nn.parameter import Parameter
-from .._utils import dual_all_to_all
+from colossalai.nn._ops._utils import dual_all_to_all
 
-from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec
 
 
 def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
diff --git a/tests/test_ops/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
similarity index 98%
rename from tests/test_ops/test_cache_embedding.py
rename to tests/test_layers/test_cache_embedding.py
index 8471975df..d7f6e7ee7 100644
--- a/tests/test_ops/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -1,15 +1,17 @@
 import pytest
 from functools import partial
-import torch
-import torch.multiprocessing as mp
+
 import numpy as np
 import random
 
+import torch
+import torch.multiprocessing as mp
+
 import colossalai
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
-from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
 
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
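Note: after this move the cache embedding classes are re-exported from colossalai.nn.parallel.layers, so downstream code switches from the old colossalai.nn._ops.cache_embedding path to the new one, as the updated test file above does. A minimal sketch of the new import path follows; the FreqAwareEmbeddingBag constructor call is hypothetical, since its signature is not shown in this diff and the arguments are assumed to mirror torch.nn.EmbeddingBag (matching NUM_EMBED, EMBED_DIM = 10, 8 in the test).

    # New public import path introduced by this change; the old
    # colossalai.nn._ops.cache_embedding path no longer exists.
    from colossalai.nn.parallel.layers import (
        CachedParamMgr,
        FreqAwareEmbeddingBag,
        ParallelFreqAwareEmbeddingBag,
    )

    # Hypothetical usage: positional (num_embeddings, embedding_dim) arguments
    # are an assumption based on the test constants, not part of this diff.
    bag = FreqAwareEmbeddingBag(10, 8)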