Mirror of https://github.com/hpcaitech/ColossalAI.git
[formart] format fixed for kernel\cuda_native codes (#335)
commit eaac03ae1d (parent 00670c870e)
@@ -9,7 +9,7 @@ from torch.autograd import Function
 
 def check_config(config):
     if config.hidden_size % config.nhead != 0:
-        raise Exception(f"hidden_size % nhead != 0")
+        raise Exception("hidden_size % nhead != 0")
 
     factor = 8 if config.fp16 else 4
     upbound = factor * 1024 * 4
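The change above only drops a needless f-string prefix; the check itself exists because multi-head attention splits hidden_size evenly across nhead heads. A minimal runnable sketch of the same validation, with a hypothetical Config stand-in for the config object the kernel receives:

from dataclasses import dataclass

@dataclass
class Config:            # hypothetical stand-in, not the library's class
    hidden_size: int
    nhead: int
    fp16: bool = True

def check_config(config):
    # Each head gets hidden_size / nhead dimensions, so the division
    # must be exact.
    if config.hidden_size % config.nhead != 0:
        raise Exception("hidden_size % nhead != 0")

check_config(Config(hidden_size=1024, nhead=16))    # ok: 64 dims per head
# check_config(Config(hidden_size=1030, nhead=16))  # would raise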
@@ -215,15 +215,14 @@ class MultiHeadAttention(nn.Module):
 
             with torch.no_grad():
                 self.in_proj_weight.copy_(
-                    attn_qkvw_global.view(3, hs, hs)[:,
-                                                     int(hs * rank_in_pg /
-                                                         self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                           self.pg_size), :])
+                    attn_qkvw_global.view(3, hs, hs)[
+                        :, int(hs * rank_in_pg / self.pg_size):
+                        int(hs * (rank_in_pg + 1) / self.pg_size),
+                        :])
                 self.in_proj_bias.copy_(
-                    attn_qkvb_global.view(3, hs)[:,
-                                                 int(hs * rank_in_pg /
-                                                     self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                       self.pg_size)])
+                    attn_qkvb_global.view(3, hs)[
+                        :, int(hs * rank_in_pg / self.pg_size):
+                        int(hs * (rank_in_pg + 1) / self.pg_size)])
 
             attn_ow_global = torch.empty(hs, hs)
             nn.init.xavier_uniform_(attn_ow_global, 1.0)
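The reindented slicing is tensor-parallel weight sharding: each rank in the process group keeps hidden_size / pg_size rows of the output dimension of the fused QKV weight and bias. A minimal sketch of the same indexing, with small hypothetical values (hs, pg_size, and the random global tensors are illustrative, not the library's):

import torch

hs, pg_size = 8, 2                           # illustrative sizes
attn_qkvw_global = torch.randn(3 * hs, hs)   # fused Q/K/V weight
attn_qkvb_global = torch.randn(3 * hs)       # fused Q/K/V bias

for rank_in_pg in range(pg_size):
    start = int(hs * rank_in_pg / pg_size)
    end = int(hs * (rank_in_pg + 1) / pg_size)
    # Same indexing as the diff: view as (3, hs, hs) / (3, hs) and take
    # this rank's slice of the output dimension for Q, K and V at once.
    w_shard = attn_qkvw_global.view(3, hs, hs)[:, start:end, :]
    b_shard = attn_qkvb_global.view(3, hs)[:, start:end]
    assert w_shard.shape == (3, hs // pg_size, hs)
    assert b_shard.shape == (3, hs // pg_size)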
@@ -231,10 +230,9 @@ class MultiHeadAttention(nn.Module):
             torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
             attn_ow_global = attn_ow_global.cpu()
             with torch.no_grad():
-                self.out_proj_weight.copy_(attn_ow_global[:,
-                                                          int(hs * rank_in_pg /
-                                                              self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                                self.pg_size)])
+                self.out_proj_weight.copy_(attn_ow_global[
+                    :, int(hs * rank_in_pg / self.pg_size):
+                    int(hs * (rank_in_pg + 1) / self.pg_size)])
 
         else:
             attn_qkvw = self.in_proj_weight.view(-1, hs)
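The output projection is sliced along its columns (input features) instead, since it consumes the per-rank attention output; attn_ow_global is broadcast from rank 0 first, so every rank slices identical values. A matching sketch under the same hypothetical shapes:

import torch

hs, pg_size = 8, 2                       # illustrative sizes
attn_ow_global = torch.empty(hs, hs)
torch.nn.init.xavier_uniform_(attn_ow_global, 1.0)

for rank_in_pg in range(pg_size):
    start = int(hs * rank_in_pg / pg_size)
    end = int(hs * (rank_in_pg + 1) / pg_size)
    # Same indexing as the diff: keep all output rows, slice input columns.
    shard = attn_ow_global[:, start:end]
    assert shard.shape == (hs, hs // pg_size)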