diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 2be3b45d0..f437c44e0 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -525,7 +525,7 @@ class VocabParallelClassifier3D(ParallelLayer): def _set_tensor_parallel_attributes(self) -> None: if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2) + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.depth) @@ -1048,7 +1048,7 @@ class VocabParallelEmbedding3D(torch.nn.Module): env.vocab_parallel = True def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2) + set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3) def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR):