From c9427a323f8c4a67f1a528a5de6bc6e9551dce9a Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Thu, 11 Aug 2022 13:14:25 +0800
Subject: [PATCH] hotfix #1434 (#1437)

---
 .../_ops/cache_embedding/parallel_freq_aware_embedding.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
index 083076532..5fb70fc2e 100644
--- a/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
@@ -7,7 +7,7 @@ from .cache_mgr import CachedParamMgr
 from torch.nn.parameter import Parameter
 from .._utils import dual_all_to_all
 
-from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup
+from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup, ColoTensorSpec
 
 
 def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
@@ -57,13 +57,15 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
         self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
 
         if _weight is None:
-            self._weight.process_group = ProcessGroup(tp_degree=self.world_size)
+            colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
+                                              dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
+                                              compute_attr=ComputePattern.TP1D)
             self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
                                                                        self.embedding_dim_per_partition,
                                                                        device='cpu',
                                                                        dtype=dtype),
                                                            requires_grad=True,
-                                                           spec=ShardSpec(dims=[-1], num_partitions=[self.world_size]))
+                                                           spec=colo_tensor_spec)
             self.init_parameters()
         else:
             assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
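
Note on the change: instead of assigning `process_group` on a not-yet-created `self._weight` and passing a bare `ShardSpec`, the hotfix bundles the process group, shard layout, and compute pattern into a `ColoTensorSpec` that is passed to `ColoParameter.from_torch_tensor`. The snippet below is a minimal sketch that mirrors the patched code path outside the class; it assumes a ColossalAI distributed context has already been launched, and the names `num_embeddings`, `dim_per_rank`, and `world_size` are illustrative placeholders rather than values taken from the module.

import torch
from colossalai.tensor import (ColoParameter, ColoTensorSpec, ComputePattern,
                               ProcessGroup, ShardSpec)

# Illustrative sizes; in ParallelFreqAwareEmbeddingBag these come from the
# vocabulary size and the per-rank slice of the embedding dimension.
num_embeddings = 1000
dim_per_rank = 32
world_size = 2  # tensor-parallel degree; assumes colossalai.launch has initialized torch.distributed

# Bundle process group, shard spec, and compute pattern into one spec,
# as the patched constructor now does.
spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=world_size),
                      dist_attr=ShardSpec(dims=[-1], num_partitions=[world_size]),
                      compute_attr=ComputePattern.TP1D)

# Create the sharded weight directly from a CPU tensor with the full spec,
# rather than mutating `process_group` on an uninitialized parameter.
weight = ColoParameter.from_torch_tensor(torch.empty(num_embeddings, dim_per_rank,
                                                     device='cpu', dtype=torch.float32),
                                         requires_grad=True,
                                         spec=spec)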