From 3fc8a204dc54032508606c82e3fe9f12a6f9afe0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Mon, 11 Apr 2022 10:17:55 +0800 Subject: [PATCH] []Corrected 3d vocab parallel embedding (#707) --- colossalai/nn/layer/parallel_3d/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 2be3b45d0..f437c44e0 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -525,7 +525,7 @@ class VocabParallelClassifier3D(ParallelLayer): def _set_tensor_parallel_attributes(self) -> None: if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2) + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.depth) @@ -1048,7 +1048,7 @@ class VocabParallelEmbedding3D(torch.nn.Module): env.vocab_parallel = True def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2) + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR):