mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)
* [tp] hotfix linear row * [tp] support uneven split for fused linear * [tp] support sp for fused linear * [tp] fix gpt2 mlp policy * [tp] fix gather fused and add fused linear row
This commit is contained in:
@@ -840,7 +840,7 @@ class _AllToAll(torch.autograd.Function):
|
||||
ctx.gather_dim = gather_dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = input_.shape
|
||||
bsz = input_.shape[0]
|
||||
|
||||
# using all_to_all_single when batch size is 1
|
||||
if bsz == 1:
|
||||
@@ -871,7 +871,7 @@ class _AllToAll(torch.autograd.Function):
|
||||
gather_dim = ctx.scatter_dim
|
||||
fp8_communication = ctx.fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = grad_output.shape
|
||||
bsz = grad_output.shape[0]
|
||||
|
||||
if bsz == 1:
|
||||
return_grad = _all_to_all_single(
|
||||
|
Reference in New Issue
Block a user