diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py
index 654d5d07f..33f358241 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/nn/layer/parallel_3d/layers.py
@@ -53,8 +53,8 @@ class LayerNorm3D(ParallelLayer):
         self.weight = Parameter(
             torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
         if bias:
-            self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
-                                              device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
         self.variance_epsilon = eps
@@ -854,7 +854,7 @@ class PatchEmbedding3D(ParallelLayer):
         input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
-            output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
+            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC
 
         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)