[shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardconfig (#4516)
* fix overlap bug and support bert, add overlap as an option in shardconfig
* support overlap for chatglm and bloom
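The overlap switch lives in the shardformer `ShardConfig`. A minimal sketch of turning it on is below, assuming the flag is named `enable_sequence_overlap` and sits next to `enable_sequence_parallelism` (names inferred from this PR's description; verify against your ColossalAI version):

```python
import torch.distributed as dist
from colossalai.shardformer import ShardConfig, ShardFormer

# tensor-parallel group that the sharded linear layers will communicate over
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,   # overlap targets sequence-parallel communication
    enable_sequence_overlap=True,       # the option added by this PR (assumed flag name)
)
shard_former = ShardFormer(shard_config=shard_config)
# model, shared_params = shard_former.optimize(model)  # e.g. a BERT / ChatGLM / BLOOM model
```

The `optimize` call is left commented out because it needs an initialized distributed environment and a real model; the point is only where the flag goes.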
@@ -211,43 +211,36 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             handle.wait()
 
         else:
-            # create new stream for calculate the gradient
-            calculate_stream = torch.cuda.Stream()
-
-            # do all gather in default stream
             input_ = input_.contiguous()
             world_size = dist.get_world_size(process_group)
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
 
             # do all gather in is async way
             gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
-
-            # calculate gradient in calculate_stream
-            with torch.cuda.stream(calculate_stream):
-                # calculate
-                grad_input = grad_output.matmul(weight)
-                grad_output = grad_output.contiguous()
-                # Convert the tensor shapes to 2D for execution compatibility
-                if len(grad_output.shape) > 2:
-                    grad_output = grad_output.view(-1, grad_output.shape[-1])
-                grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-                # prepare data
-                input_list = [
-                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-                ]
-                output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
-
-            torch.cuda.current_stream().wait_stream(calculate_stream)
+            # calculate gradient and prepare data asynchronously with all-gather
+            # calculate
+            grad_input = grad_output.matmul(weight)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+            grad_bias = grad_output.sum(dim=0) if use_bias else None
+            # prepare data
+            input_list = [
+                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+            # wait until all-gather finished
+            gather_handle.wait()
+
             # do reduce-scatter in async way
             reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            with torch.cuda.stream(calculate_stream):
-                input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
-                if len(input_parallel.shape) > 2:
-                    input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-                print(grad_output.shape, input_parallel.shape)
-                grad_weight = grad_output.t().matmul(input_parallel)
-
-            torch.cuda.current_stream().wait_stream(calculate_stream)
+            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+            # calculate gradient
+            if len(input_parallel.shape) > 2:
+                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+            grad_weight = grad_output.t().matmul(input_parallel)
+            # wait until reduce-scatter finished
+            reducescatter_handle.wait()
 
             return output, grad_weight, grad_bias, None, None, None, None
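Stripped of the shardformer plumbing, the new path works like this: launch the all-gather with `async_op=True`, compute `grad_input` (which does not need the gathered tensor) while it runs, wait on the handle, then launch the reduce-scatter and overlap it with the `grad_weight` matmul. The sketch below is a standalone illustration of that ordering, not ColossalAI code; the function name, the `dim` argument, and the shape assumptions are illustrative only.

```python
import torch
import torch.distributed as dist


def overlapped_linear_backward(grad_output, weight, input_, process_group=None, dim=0):
    """Illustrative backward that overlaps comm and compute via async handles."""
    world_size = dist.get_world_size(process_group)

    # launch all-gather of the sharded input asynchronously
    input_ = input_.contiguous()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

    # grad_input depends only on grad_output and weight, so it overlaps the all-gather
    grad_output = grad_output.contiguous()
    grad_input = grad_output.matmul(weight)
    input_list = [c.contiguous() for c in torch.chunk(grad_input, world_size, dim=dim)]
    output = torch.empty_like(input_)

    # the gathered activations are needed next, so wait here
    gather_handle.wait()

    # launch reduce-scatter of grad_input asynchronously ...
    rs_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)

    # ... and overlap it with the grad_weight matmul, which uses the gathered input
    input_parallel = torch.cat(tensor_list, dim=dim)
    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
        input_parallel.reshape(-1, input_parallel.shape[-1])
    )

    rs_handle.wait()
    return output, grad_weight
```

The previous version tried to get the same effect with a separate `calculate_stream`; the rewrite keeps everything on the default stream, relies purely on the async handles for overlap, and drops the leftover debug `print`.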