From 521078ffc9f1a4bc8fc91e125ffd7dac683094cb Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 2 Sep 2022 15:48:35 +0800
Subject: [PATCH] [embedding] fix a bug in table wise sharding (#1538)

---
 .../cache_embedding/parallel_freq_aware_embedding_tablewise.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
index 464efe23d..35faa67b5 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
@@ -64,7 +64,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
         self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
         self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
         self.global_tables_num = len(embedding_bag_config_list)
-        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0)
+        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
         self.assigned_table_list: List[int] = []
         for i, rank in enumerate(self.rank_of_tables):
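
For context, the one-character-class fix above appends `.cuda()` because `torch.tensor(...)` allocates on the CPU by default, while the indices that are shifted against `global_tables_offsets` at runtime are CUDA tensors; PyTorch raises a `RuntimeError` when an arithmetic op mixes devices. The following is a minimal standalone sketch of that failure mode and the fix, with hypothetical table sizes and index values (it is not the library's actual `forward` code, and it assumes a CUDA-capable machine):

```python
import torch

# Per-table embedding counts, standing in for
# [config.num_embeddings for config in embedding_bag_config_list] (hypothetical values).
num_embeddings_per_table = [100, 200, 50]

# Before the fix: the cumulative offsets tensor lives on the CPU by default.
offsets_cpu = torch.cumsum(torch.tensor([0] + num_embeddings_per_table), 0)

# During training, input indices arrive on the GPU.
indices = torch.tensor([5, 120, 310], device='cuda')

# Mixing devices fails:
# RuntimeError: Expected all tensors to be on the same device ...
# shifted = indices - offsets_cpu[1]

# After the fix: move the offsets to the GPU once at construction time,
# so per-table index shifting stays entirely on the CUDA device.
offsets_gpu = offsets_cpu.cuda()
shifted = indices - offsets_gpu[1]  # works: both operands are CUDA tensors
```

Doing the `.cuda()` transfer once in `__init__` (rather than per batch in `forward`) keeps the host-to-device copy off the training hot path.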