mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 22:19:38 +00:00
[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384) * [sequence parallel] add sequence parallel linear col/row support (#4336) * add sequence parallel linear col/row support * add annotation * add annotation * add support for gpt2 fused qkv linear layer * support sequence parallel in GPT2 * add docstring and note * add requirments * remove unused flash-attb * modify flash attn test * modify flash attn setting * modify flash attn code * add assert before divide, rename forward function * [shardformer/test] fix gpt2 test with seq-parallel * [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401) * overlap gather input / grad computing during col backward * modify test for overlap * simplify code * fix code and modify cuda stream synchronize * [shardformer/sequence parallel] polish code
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
@@ -141,6 +143,215 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
|
||||
|
||||
Args:
|
||||
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
|
||||
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
|
||||
ctx.save_for_backward(input_, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
|
||||
ctx.dim = dim
|
||||
ctx.overlap = overlap
|
||||
|
||||
input_parallel = _gather(input_, dim, process_group)
|
||||
|
||||
if bias is not None:
|
||||
output = F.linear(input_parallel, weight, bias)
|
||||
else:
|
||||
output = F.linear(input_parallel, weight)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input_, weight = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
overlap = ctx.overlap
|
||||
|
||||
if not overlap:
|
||||
# TODO: overlap SP input with gradient computation
|
||||
input_parallel = _gather(input_, dim, process_group)
|
||||
|
||||
total_input = input_parallel
|
||||
grad_input = grad_output.matmul(weight)
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
if len(grad_output.shape) > 2:
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
total_input = total_input.view(-1, total_input.shape[-1])
|
||||
|
||||
# TODO: overlap SP input with gradient computation
|
||||
if ctx.async_grad_reduce_scatter:
|
||||
# Asynchronous reduce-scatter
|
||||
input_list = [
|
||||
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
|
||||
]
|
||||
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
|
||||
device=input_parallel.device).contiguous()
|
||||
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||
# Delay the start of weight gradient computation shortly (3us) to have
|
||||
# reduce-scatter scheduled first and have GPU resources allocated
|
||||
_ = torch.empty(1, device=grad_output.device) + 1
|
||||
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_reduce_scatter:
|
||||
handle.wait()
|
||||
|
||||
else:
|
||||
# create new stream for calculate the gradient
|
||||
calculate_stream = torch.cuda.Stream()
|
||||
|
||||
# do all gather in default stream
|
||||
input_ = input_.contiguous()
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
|
||||
|
||||
# calculate gradient in calculate_stream
|
||||
with torch.cuda.stream(calculate_stream):
|
||||
# calculate
|
||||
grad_input = grad_output.matmul(weight)
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
if len(grad_output.shape) > 2:
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
# prepare data
|
||||
input_list = [
|
||||
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
|
||||
]
|
||||
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
|
||||
|
||||
torch.cuda.current_stream().wait_stream(calculate_stream)
|
||||
|
||||
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||
with torch.cuda.stream(calculate_stream):
|
||||
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
if len(input_parallel.shape) > 2:
|
||||
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
|
||||
print(grad_output.shape, input_parallel.shape)
|
||||
grad_weight = grad_output.t().matmul(input_parallel)
|
||||
|
||||
torch.cuda.current_stream().wait_stream(calculate_stream)
|
||||
|
||||
return output, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
|
||||
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
|
||||
|
||||
Args:
|
||||
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim):
|
||||
ctx.dim = dim
|
||||
ctx.process_group = process_group
|
||||
|
||||
# do reduce-scatter
|
||||
new_shape = list(input_.shape)
|
||||
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
|
||||
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
|
||||
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
|
||||
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
|
||||
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
|
||||
dist.reduce_scatter(output, input_list, group=process_group)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
|
||||
return _gather(grad_output, dim, process_group), None, None
|
||||
|
||||
|
||||
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
"""
|
||||
This class is designed for matmul operation with gather forward and reduce-scatter backward.
|
||||
|
||||
Args:
|
||||
input_ (`torch.Tensor`): input matrix.
|
||||
dim (int): the dimension to perform split and gather
|
||||
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
|
||||
ctx.save_for_backward(input_, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
|
||||
ctx.dim = dim
|
||||
|
||||
input_parallel = _gather(input_, dim, process_group)
|
||||
|
||||
output = torch.matmul(input_parallel, weight)
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input_, weight = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
|
||||
# TODO: overlap SP input with gradient computation
|
||||
input_parallel = _gather(input_, dim, process_group)
|
||||
|
||||
total_input = input_parallel
|
||||
grad_input = grad_output.matmul(weight.T)
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
if len(grad_output.shape) > 2:
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
total_input = total_input.view(-1, total_input.shape[-1])
|
||||
|
||||
# TODO: overlap SP input with gradient computation
|
||||
if ctx.async_grad_reduce_scatter:
|
||||
# Asynchronous reduce-scatter
|
||||
input_list = [
|
||||
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
|
||||
]
|
||||
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
|
||||
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||
# Delay the start of weight gradient computation shortly (3us) to have
|
||||
# reduce-scatter scheduled first and have GPU resources allocated
|
||||
_ = torch.empty(1, device=grad_output.device) + 1
|
||||
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_reduce_scatter:
|
||||
handle.wait()
|
||||
|
||||
return output, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""
|
||||
Split the input and keep only the corresponding chuck to the rank.
|
||||
@@ -200,6 +411,26 @@ class _ReduceBackward(torch.autograd.Function):
|
||||
return _reduce(grad_output, ctx.process_group), None
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
"""Gather the input from model parallel region and concatenate.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
parallel_mode: parallel mode.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, process_group):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
return _gather(input_, dim, process_group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split(grad_output, ctx.dim, ctx.process_group), None, None
|
||||
|
||||
|
||||
def _reduce(input_, process_group):
|
||||
# skip if only one rank involved
|
||||
if dist.get_world_size(process_group) == 1:
|
||||
@@ -235,6 +466,7 @@ def _gather(input_, dim=-1, process_group=None):
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
input_ = input_.contiguous()
|
||||
rank = dist.get_rank(process_group)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
@@ -246,24 +478,27 @@ def _gather(input_, dim=-1, process_group=None):
|
||||
return output
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
"""Gather the input from model parallel region and concatenate.
|
||||
def _reduce_scatter(input_, dim=1, process_group=None):
|
||||
""" Do reduce-scatter operation.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
parallel_mode: parallel mode.
|
||||
dim: dimension
|
||||
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
|
||||
dim (int): The dimension to perform reduce-scatter.
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
|
||||
"""
|
||||
world_size = dist.get_world_size(process_group)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, process_group):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
return _gather(input_, dim, process_group)
|
||||
# reduce-scatter
|
||||
new_shape = list(input_.shape)
|
||||
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
|
||||
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
|
||||
new_shape[dim] = new_shape[dim] // world_size
|
||||
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
|
||||
dist.reduce_scatter(output, input_, group=process_group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split(grad_output, ctx.dim, ctx.process_group), None, None
|
||||
return output
|
||||
|
||||
|
||||
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
||||
@@ -274,6 +509,21 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||
|
||||
|
||||
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
|
||||
overlap):
|
||||
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
|
||||
async_grad_reduce_scatter, dim, overlap)
|
||||
|
||||
|
||||
def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
|
||||
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
|
||||
|
||||
|
||||
def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
|
||||
return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
|
||||
async_grad_reduce_scatter, dim)
|
||||
|
||||
|
||||
def gather_forward_split_backward(input_, dim, process_group):
|
||||
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
|
||||
|
||||
|
Reference in New Issue
Block a user