[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)

* [tp] hotfix linear row

* [tp] support uneven split for fused linear

* [tp] support sp for fused linear

* [tp] fix gpt2 mlp policy

* [tp] fix gather fused and add fused linear row
This commit is contained in:
Hongxin Liu
2024-10-10 14:34:45 +08:00
committed by GitHub
parent f4daf04270
commit 646b3c5a90
10 changed files with 399 additions and 157 deletions

View File

@@ -840,7 +840,7 @@ class _AllToAll(torch.autograd.Function):
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = input_.shape
bsz = input_.shape[0]
# using all_to_all_single when batch size is 1
if bsz == 1:
@@ -871,7 +871,7 @@ class _AllToAll(torch.autograd.Function):
gather_dim = ctx.scatter_dim
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = grad_output.shape
bsz = grad_output.shape[0]
if bsz == 1:
return_grad = _all_to_all_single(