Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] Add overlap support for gpt2 (#4535)
* add overlap support for gpt2
* remove unused code
* remove unused code
@@ -291,12 +291,13 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
         ctx.save_for_backward(input_, weight)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
+        ctx.overlap = overlap

         input_parallel = _gather(input_, dim, process_group)
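For reference, the forward path reconstructs the full input with _gather before the matmul. That helper is not shown in this diff; a minimal sketch of such a gather step with torch.distributed, assuming it simply all-gathers the local shard and concatenates along dim (the same all-gather/concat pattern the overlap branch below uses), could look like:

import torch
import torch.distributed as dist


def gather_along_dim(input_: torch.Tensor, dim: int, process_group) -> torch.Tensor:
    # Hypothetical stand-in for _gather: collect the local shard from every rank
    # in the group and concatenate the pieces along `dim`.
    world_size = dist.get_world_size(process_group)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_.contiguous(), group=process_group)
    return torch.cat(tensor_list, dim=dim).contiguous()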
@@ -312,37 +313,70 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         use_bias = ctx.use_bias
         dim = ctx.dim
         process_group = ctx.process_group
+        overlap = ctx.overlap

-        # TODO: overlap SP input with gradient computation
-        input_parallel = _gather(input_, dim, process_group)
+        if not overlap:
+            input_parallel = _gather(input_, dim, process_group)

-        total_input = input_parallel
-        grad_input = grad_output.matmul(weight.T)
-        grad_output = grad_output.contiguous()
-        # Convert the tensor shapes to 2D for execution compatibility
-        if len(grad_output.shape) > 2:
-            grad_output = grad_output.view(-1, grad_output.shape[-1])
-            total_input = total_input.view(-1, total_input.shape[-1])
+            total_input = input_parallel
+            grad_input = grad_output.matmul(weight.T)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+                total_input = total_input.view(-1, total_input.shape[-1])

-        # TODO: overlap SP input with gradient computation
-        if ctx.async_grad_reduce_scatter:
-            # Asynchronous reduce-scatter
-            input_list = [
-                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-            ]
-            output = torch.empty(input_.shape, dtype=input_parallel.dtype,
-                                 device=input_parallel.device).contiguous()
-            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # reduce-scatter scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
+            if ctx.async_grad_reduce_scatter:
+                # Asynchronous reduce-scatter
+                input_list = [
+                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+                ]
+                output = torch.empty(input_.shape, dtype=input_parallel.dtype,
+                                     device=input_parallel.device).contiguous()
+                handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+                # Delay the start of weight gradient computation shortly (3us) to have
+                # reduce-scatter scheduled first and have GPU resources allocated
+                _ = torch.empty(1, device=grad_output.device) + 1

-        grad_weight = total_input.t().matmul(grad_output)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
+            grad_weight = total_input.t().matmul(grad_output)
+            grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.async_grad_reduce_scatter:
-            handle.wait()
+            if ctx.async_grad_reduce_scatter:
+                handle.wait()

-        return output, grad_weight, grad_bias, None, None, None
+        else:
+            world_size = dist.get_world_size(process_group)
+            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+
+            # do all gather in is async way
+            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+            # calculate gradient and prepare data asynchronously with all-gather
+            # calculate
+            grad_input = grad_output.matmul(weight.T)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+            grad_bias = grad_output.sum(dim=0) if use_bias else None
+            # prepare data
+            input_list = [
+                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+            # wait until all-gather finished
+            gather_handle.wait()
+
+            # do reduce-scatter in async way
+            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+            # calculate gradient
+            if len(input_parallel.shape) > 2:
+                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+            grad_weight = input_parallel.t().matmul(grad_output)
+            # wait until reduce-scatter finished
+            reducescatter_handle.wait()
+
+        return output, grad_weight, grad_bias, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
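The new overlap branch hides communication behind computation in two stages: the asynchronous all-gather of the saved input shard runs while grad_input and grad_bias are computed (they need only grad_output and weight), and the asynchronous reduce-scatter of grad_input runs while grad_weight is computed from the gathered input. A stripped-down sketch of the same pattern outside the autograd.Function machinery, with illustrative tensor names and assuming weight has shape (in_features, out_features) as in the matmul above:

import torch
import torch.distributed as dist


def overlapped_backward(input_shard, weight, grad_output, process_group, dim):
    # Sketch of the overlapped backward shown in the hunk above (names are illustrative).
    world_size = dist.get_world_size(process_group)

    # 1) Start gathering the sequence-sharded input asynchronously.
    tensor_list = [torch.empty_like(input_shard) for _ in range(world_size)]
    gather_handle = dist.all_gather(tensor_list, input_shard.contiguous(), group=process_group, async_op=True)

    # 2) Overlap: compute grad_input and grad_bias, which need only grad_output and weight.
    grad_input = grad_output.matmul(weight.T)
    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    grad_bias = grad_output_2d.sum(dim=0)

    # 3) Chunk grad_input back into per-rank shards and reduce-scatter it asynchronously.
    input_list = [t.contiguous() for t in torch.chunk(grad_input, world_size, dim=dim)]
    grad_input_shard = torch.empty_like(input_shard)
    gather_handle.wait()  # the gathered input is needed for grad_weight next
    rs_handle = dist.reduce_scatter(grad_input_shard, input_list, group=process_group, async_op=True)

    # 4) Overlap: compute grad_weight from the gathered input while reduce-scatter runs.
    input_full = torch.cat(tensor_list, dim=dim).reshape(-1, weight.shape[0])
    grad_weight = input_full.t().matmul(grad_output_2d)

    rs_handle.wait()
    return grad_input_shard, grad_weight, grad_bias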
@@ -510,9 +544,10 @@ def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
     return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


-def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
+def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
+                                                 overlap):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
-                                                               async_grad_reduce_scatter, dim)
+                                                               async_grad_reduce_scatter, dim, overlap)


 def gather_forward_split_backward(input_, dim, process_group):
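A hypothetical call site for the updated wrapper (the real call sites live in the shardformer GPT-2 modules touched elsewhere in this PR; dim=1 assumes a (batch, seq, hidden) activation layout):

# Illustrative invocation only; argument values are assumptions.
output = matmul_gather_forward_reducescatter_backward(
    input_,                          # sequence-sharded activation
    weight,                          # (in_features, out_features)
    bias,                            # or None
    process_group,                   # sequence-parallel process group
    async_grad_reduce_scatter=True,
    dim=1,                           # sequence dimension to gather/scatter along
    overlap=True,                    # enable the overlapped backward path
)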