[shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460)

* support gpt2 seq parallel with pp/dp/tp

* fix a bug when waiting for the compute stream to finish (see the sketch below)

* delete unused gpt2_seq file
Author: Bin Jia
Date:   2023-08-18 11:21:53 +08:00 (committed by GitHub)
Parent: a78daf6180
Commit: 7c8be77081

6 changed files with 268 additions and 240 deletions
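The stream bug fixed here is a classic ordering hazard: work enqueued on a side CUDA stream runs asynchronously with respect to the default stream, so any default-stream operation that consumes the side stream's output must be explicitly ordered after it with wait_stream. A minimal sketch of the hazard and the fix, using illustrative names (calculate_stream, x, y) rather than the actual ColossalAI code:

import torch

calculate_stream = torch.cuda.Stream()

x = torch.randn(1024, 1024, device="cuda")
y = torch.empty_like(x)

# The side stream must first wait for the default stream so `x` is ready.
calculate_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(calculate_stream):
    y.copy_(x @ x)  # compute runs asynchronously on the side stream

# The fix: order the default stream after the side stream BEFORE enqueuing
# anything that reads `y` (e.g. an async collective). Without this line the
# default stream may consume `y` while the matmul is still running.
torch.cuda.current_stream().wait_stream(calculate_stream)

z = y.sum()  # now safely depends on the side stream's result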


@@ -239,6 +239,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+        torch.cuda.current_stream().wait_stream(calculate_stream)
         gather_handle.wait()
         reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
         with torch.cuda.stream(calculate_stream):
@@ -249,6 +250,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         grad_weight = grad_output.t().matmul(input_parallel)
+        torch.cuda.current_stream().wait_stream(calculate_stream)
         reducescatter_handle.wait()
         return output, grad_weight, grad_bias, None, None, None, None
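Read together, the two hunks implement communication/computation overlap in the backward pass: the reduce-scatter is launched asynchronously, the grad_weight matmul runs concurrently on calculate_stream, and the newly added wait_stream call orders the default stream after the compute stream before blocking on the communication handle, so grad_weight is guaranteed complete on return. A hedged sketch of that pattern follows; the function name and argument list are hypothetical (the real method is _LinearWithGatherForwardReduceScatterBackward.backward and carries more state), and it assumes an initialized torch.distributed process group:

import torch
import torch.distributed as dist

def reduce_scatter_overlapped(grad_output, input_parallel, input_list,
                              output, process_group, calculate_stream):
    # Launch the reduce-scatter asynchronously; the default stream is not
    # blocked, so the matmul below can be enqueued immediately.
    handle = dist.reduce_scatter(output, input_list,
                                 group=process_group, async_op=True)

    # Make the side stream wait for the default stream so its inputs
    # (grad_output, input_parallel) are ready before the matmul starts.
    calculate_stream.wait_stream(torch.cuda.current_stream())

    # Overlap: the weight gradient is computed on the side stream while
    # NCCL moves data.
    with torch.cuda.stream(calculate_stream):
        grad_weight = grad_output.t().matmul(input_parallel)

    # The added fix: order the default stream after the compute stream
    # before waiting on the collective, so grad_weight is complete when
    # the caller uses it.
    torch.cuda.current_stream().wait_stream(calculate_stream)
    handle.wait()
    return output, grad_weight

Launching the collective before the matmul is the key design choice: the NCCL transfer and the compute kernel occupy different hardware resources, so the communication cost is hidden behind the matmul instead of being serialized after it.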