[shardformer] optimize seq parallelism (#6086)

* [shardformer] optimize seq parallelism

* [shardformer] fix gpt2 fused linear col

* [plugin] update gemini plugin

* [plugin] update moe hybrid plugin

* [test] update gpt2 fused linear test

* [shardformer] fix gpt2 fused linear reduce
This commit is contained in:
Hongxin Liu
2024-10-11 13:44:40 +08:00
committed by GitHub
parent 6b2c506fc5
commit dc2cdaf3e8
13 changed files with 111 additions and 278 deletions

View File

@@ -102,7 +102,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce and fp8_communication:
if fp8_communication or not ctx.async_grad_allreduce:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
@@ -216,10 +216,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
for k in recv_tensors:
send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
input_tensors = []
output_tensors = []
handles = communicate_step()
# first round: special case, retrive from local tensor
input_tensors.append(input_to_gather)
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
@@ -230,14 +232,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
handles = communicate_step()
# actual computation
input_tensors.append(send_tensors)
output_tensors.append(func(**send_tensors, **input_local))
# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
input_tensors.append(send_tensors)
output_tensors.append(func(**recv_tensors, **input_local))
return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
gathered_input = {}
for k in input_to_gather:
input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
gathered_output = torch.cat(
output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
)
return gathered_output, gathered_input
class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -293,29 +306,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
if ring is True:
input_to_gather = {"input": input_}
input_local = {"weight": weight}
output = _ring_as_gather(
output, input_dict = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)
ctx.gathered_input = input_dict["input"]
if bias is not None:
output += bias
else:
input_parallel = _gather(input_, dim, process_group)
ctx.gathered_input = input_parallel
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
@@ -329,100 +343,50 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group)
input_parallel = ctx.gathered_input
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all gather in is async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input_parallel)
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input_parallel)
# grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
return output, grad_weight, grad_bias, None, None, None, None, None
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
def _ring_as_reducescatter(
@@ -553,7 +517,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
total_input = total_input.reshape(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -611,34 +575,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
ctx.fp8_communication = fp8_communication
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather["input"] = input_
input_local["other"] = weight
input_to_gather = {"input": input_}
input_local = {"other": weight}
output = _ring_as_gather(
output, input_dict = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim,
)
ctx.gathered_input = input_dict["input"]
else:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
ctx.gathered_input = input_parallel
output = torch.matmul(input_parallel, weight)
if bias is not None:
@@ -651,76 +611,39 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
input_parallel = ctx.gathered_input
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all gather in is async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return output, grad_weight, grad_bias, None, None, None, None, None, None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -1050,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
)
@@ -1070,10 +993,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
)