From eaac03ae1d642e97d4e3ae2fd8f4c32c7d0407ee Mon Sep 17 00:00:00 2001
From: ExtremeViscent
Date: Wed, 9 Mar 2022 01:44:20 +0000
Subject: [PATCH] [formart] format fixed for kernel\cuda_native codes (#335)

---
 colossalai/kernel/cuda_native/layer_norm.py     |  8 ++++----
 .../kernel/cuda_native/multihead_attention.py   | 24 +++++++++++-------------
 2 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py
index b2ecd9ff9..af66eb827 100644
--- a/colossalai/kernel/cuda_native/layer_norm.py
+++ b/colossalai/kernel/cuda_native/layer_norm.py
@@ -37,10 +37,10 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
-          = colossal_layer_norm_cuda.backward_affine(
-              grad_output.contiguous(), mean, invvar,
-              input_, ctx.normalized_shape,
-              weight_, bias_, ctx.eps)
+            = colossal_layer_norm_cuda.backward_affine(
+                grad_output.contiguous(), mean, invvar,
+                input_, ctx.normalized_shape,
+                weight_, bias_, ctx.eps)
 
         return grad_input, grad_weight, grad_bias, None, None
 
diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py
index 3e776b610..c93d1cf60 100644
--- a/colossalai/kernel/cuda_native/multihead_attention.py
+++ b/colossalai/kernel/cuda_native/multihead_attention.py
@@ -9,7 +9,7 @@ from torch.autograd import Function
 
 def check_config(config):
     if config.hidden_size % config.nhead != 0:
-        raise Exception(f"hidden_size % nhead != 0")
+        raise Exception("hidden_size % nhead != 0")
 
     factor = 8 if config.fp16 else 4
     upbound = factor * 1024 * 4
@@ -215,15 +215,14 @@ class MultiHeadAttention(nn.Module):
 
             with torch.no_grad():
                 self.in_proj_weight.copy_(
-                    attn_qkvw_global.view(3, hs, hs)[:,
-                                                     int(hs * rank_in_pg /
-                                                         self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                           self.pg_size), :])
+                    attn_qkvw_global.view(3, hs, hs)[
+                        :, int(hs * rank_in_pg / self.pg_size):
+                        int(hs * (rank_in_pg + 1) / self.pg_size),
+                        :])
                 self.in_proj_bias.copy_(
-                    attn_qkvb_global.view(3, hs)[:,
-                                                 int(hs * rank_in_pg /
-                                                     self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                       self.pg_size)])
+                    attn_qkvb_global.view(3, hs)[
+                        :, int(hs * rank_in_pg / self.pg_size):
+                        int(hs * (rank_in_pg + 1) / self.pg_size)])
 
             attn_ow_global = torch.empty(hs, hs)
             nn.init.xavier_uniform_(attn_ow_global, 1.0)
@@ -231,10 +230,9 @@ class MultiHeadAttention(nn.Module):
             torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
             attn_ow_global = attn_ow_global.cpu()
             with torch.no_grad():
-                self.out_proj_weight.copy_(attn_ow_global[:,
-                                                           int(hs * rank_in_pg /
-                                                               self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                                 self.pg_size)])
+                self.out_proj_weight.copy_(attn_ow_global[
+                    :, int(hs * rank_in_pg / self.pg_size):
+                    int(hs * (rank_in_pg + 1) / self.pg_size)])
 
         else:
             attn_qkvw = self.in_proj_weight.view(-1, hs)
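
The slicing expressions that this patch rewraps are easier to follow in isolation. Below is a minimal, standalone sketch (not part of the patch) of the per-rank partition that the copy_ calls perform on the globally initialized QKV weight; hs, pg_size and rank_in_pg are hypothetical example values standing in for the attributes MultiHeadAttention reads from its config and tensor-parallel process group.

    # Minimal sketch with hypothetical example sizes; not part of the patch.
    import torch

    hs = 8          # hypothetical hidden size
    pg_size = 2     # hypothetical tensor-parallel group size
    rank_in_pg = 1  # hypothetical rank inside that group

    # Globally initialized QKV weight, as it would exist after the broadcast.
    attn_qkvw_global = torch.randn(3 * hs, hs)

    # Same slicing as in the patch: view as (3, hs, hs) and take this rank's
    # contiguous block of the hidden dimension from each of Q, K and V.
    start = int(hs * rank_in_pg / pg_size)
    end = int(hs * (rank_in_pg + 1) / pg_size)
    local_qkvw = attn_qkvw_global.view(3, hs, hs)[:, start:end, :]

    print(local_qkvw.shape)  # torch.Size([3, 4, 8]) with the values above

The boundaries int(hs * rank_in_pg / pg_size) and int(hs * (rank_in_pg + 1) / pg_size) give each rank an equal, contiguous slice; the same expressions are used for the in_proj bias and the out_proj weight in the hunks above.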