[shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when useing sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megtron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for  sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several model

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
This commit is contained in:
Zhongkai Zhao
2024-04-03 17:15:47 +08:00
committed by GitHub
parent 7e0ec5a85c
commit 8e412a548e
33 changed files with 1630 additions and 256 deletions

View File

@@ -167,6 +167,97 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
# currently only support one single tensor as output
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)
# output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
# initialization of ring communication
recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
rank_map = list(dist.get_process_group_ranks(process_group))
recv_rank = rank_map[recv_rank]
send_rank = rank_map[send_rank]
recv_tensors = {}
send_tensors = {}
for k, v in input_to_gather.items():
recv_tensors[k] = torch.empty_like(v)
send_tensors[k] = v.clone()
def communicate_step():
comm_ops = []
for k in recv_tensors:
comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
return dist.batch_isend_irecv(comm_ops)
def switch_step():
for k in recv_tensors:
send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
output_tensors = []
handles = communicate_step()
# first round: special case, retrive from local tensor
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
handle.wait()
switch_step()
handles = communicate_step()
# actual computation
output_tensors.append(func(**send_tensors, **input_local))
# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
output_tensors.append(func(**recv_tensors, **input_local))
return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
class _GatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
# do reduce-scatter
new_shape = list(grad_output.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
grad_list = [
item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
dist.reduce_scatter(output, grad_list, group=process_group)
return output, None, None
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
@@ -178,7 +269,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
@@ -186,12 +277,25 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
if ring is True:
input_to_gather = {"input": input_}
input_local = {"weight": weight}
if bias is not None:
output = F.linear(input_parallel, weight, bias)
output = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)
if bias is not None:
output += bias
else:
output = F.linear(input_parallel, weight)
input_parallel = _gather(input_, dim, process_group)
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)
return output
@@ -294,11 +398,146 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None
def _ring_as_reducescatter(
func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1
):
# currently only support one single tensor as output
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)
# initialization of ring communication
recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
rank_map = list(dist.get_process_group_ranks(process_group))
recv_rank = rank_map[recv_rank]
send_rank = rank_map[send_rank]
input_tensors = []
for _ in range(group_size):
input_tensors.append({})
for k, v in input_to_reducescatter.items():
input_shape = v.shape
assert input_shape[reducescatter_dim] % group_size == 0
_input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
for i in range(group_size):
input_tensors[i][k] = _input_tensors[i]
input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
input_tensors.reverse()
output_tensor = func(**input_tensors[0], **input_local)
recv_tensor = torch.empty_like(output_tensor)
send_tensor = output_tensor.clone()
def communicate_step():
recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
return dist.batch_isend_irecv([recv_op, send_op])
handles = communicate_step()
# first round: special case, retrive from local tensor
for i in range(group_size - 2):
# actual computation
output_tensor = func(**input_tensors[i + 1], **input_local)
for handle in handles:
handle.wait()
output_tensor += recv_tensor
tmp_tensor = send_tensor
send_tensor = output_tensor
output_tensor = tmp_tensor
handles = communicate_step()
# final round: special case, no need to send/recv again
output_tensor = func(**input_tensors[-1], **input_local)
for handle in handles:
handle.wait()
output_tensor += recv_tensor
return output_tensor
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
"""Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, dim, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.dim = dim
if ring is True:
input_to_reducescatter = {"input": input_}
input_local = {"weight": weight}
if bias is not None:
input_to_reducescatter["bias"] = bias
output = _ring_as_reducescatter(
F.linear,
input_to_reducescatter=input_to_reducescatter,
input_local=input_local,
process_group=process_group,
)
else:
if bias is not None:
partial_output = F.linear(input_, weight, bias)
else:
partial_output = F.linear(input_, weight)
output_shape = list(partial_output.shape)
assert (
output_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
output_list = [
item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
dist.reduce_scatter(output, output_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
grad_output = _gather(grad_output, dim, process_group)
# TODO Need to fully optimize
total_input = input_
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None
class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Reduce-scatter input from sequence parallel in forward and gather gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
@@ -343,7 +582,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
@@ -351,9 +590,24 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather["input"] = input_
input_local["other"] = weight
output = torch.matmul(input_parallel, weight)
output = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim,
)
else:
input_parallel = _gather(input_, dim, process_group)
output = torch.matmul(input_parallel, weight)
if bias is not None:
output = output + bias
@@ -433,7 +687,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -448,14 +702,17 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, dim, process_group):
def forward(ctx, input_, dim, process_group, grad_scale=None):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _split(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.dim, ctx.process_group), None, None
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
class _ReduceForward(torch.autograd.Function):
@@ -505,14 +762,50 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, dim, process_group):
def forward(ctx, input_, dim, process_group, grad_scale=None):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
class _AllToAll(torch.autograd.Function):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
world_size = dist.get_world_size(process_group)
bsz, _, _ = input_.shape
# using all_to_all_single when batch size is 1
if bsz == 1:
return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
else:
return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
@staticmethod
def backward(ctx, *grad_output):
process_group = ctx.process_group
scatter_dim = ctx.gather_dim
gather_dim = ctx.scatter_dim
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
return (return_grad, None, None, None)
class HookParameter(torch.autograd.Function):
@@ -608,6 +901,40 @@ def _reduce_scatter(input_, dim=1, process_group=None):
return output
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
inp_shape = list(input_.shape)
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
if scatter_dim < 2:
input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
else:
input_t = (
input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
.transpose(0, 1)
.contiguous()
)
output = torch.empty_like(input_t)
dist.all_to_all_single(output, input_t, group=group)
if scatter_dim < 2:
output = output.transpose(0, 1).contiguous()
return output.reshape(
inp_shape[:gather_dim]
+ [
inp_shape[gather_dim] * seq_world_size,
]
+ inp_shape[gather_dim + 1 :]
).contiguous()
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
@@ -617,31 +944,39 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)
def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def gather_forward_reducescatter_backward(input_, process_group, dim):
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
def reducescatter_forward_gather_backward(input_, process_group, dim):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)
def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
def split_forward_gather_backward(input_, dim, process_group):
return _SplitForwardGatherBackward.apply(input_, dim, process_group)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
def reduce_forward(input_, process_group):
@@ -650,3 +985,7 @@ def reduce_forward(input_, process_group):
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)