Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460)
* support gpt2 seq parallel with pp/dp/tp
* fix a bug when waiting for stream done
* delete unused gpt2_seq file
@@ -239,6 +239,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
 
+        torch.cuda.current_stream().wait_stream(calculate_stream)
         gather_handle.wait()
 
         reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
         with torch.cuda.stream(calculate_stream):
@@ -249,6 +250,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             grad_weight = grad_output.t().matmul(input_parallel)
 
+        torch.cuda.current_stream().wait_stream(calculate_stream)
         reducescatter_handle.wait()
 
         return output, grad_weight, grad_bias, None, None, None, None
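The hunks above follow a common overlap pattern: launch an async reduce-scatter with async_op=True, run an independent matmul on a side CUDA stream, and then make the default stream wait for that stream before waiting on the communication handle, so that partially computed results are never consumed. The sketch below is a minimal, self-contained illustration of that pattern under stated assumptions; it is not the ColossalAI implementation, and names such as overlapped_reduce_scatter, input_parallel, and the tensor shapes are hypothetical.

# Minimal sketch of the stream-overlap pattern shown in the diff above;
# assumes torch.distributed (e.g. NCCL) is already initialized and all
# tensors live on the current CUDA device.
import torch
import torch.distributed as dist


def overlapped_reduce_scatter(input_list, grad_output, input_parallel, process_group):
    # Buffer for this rank's shard of the reduce-scattered result
    # (hypothetical assumption: every tensor in input_list has the same shape).
    output = torch.empty_like(input_list[0]).contiguous()

    # Launch the collective asynchronously so compute can overlap with it.
    reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)

    # Run the independent matmul on a separate stream; it proceeds while the
    # reduce-scatter issued on the default stream is still in flight.
    calculate_stream = torch.cuda.Stream()
    # The inputs were produced on the default stream, so the compute stream
    # must not start reading them too early.
    calculate_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(calculate_stream):
        grad_weight = grad_output.t().matmul(input_parallel)

    # The point of the added lines in the diff: the default stream waits for
    # the compute stream to finish before its results are treated as ready,
    # then the communication handle is waited on.
    torch.cuda.current_stream().wait_stream(calculate_stream)
    reducescatter_handle.wait()
    return output, grad_weight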