From 457a0de79fd2d3602eba0ac78e606acb6401fc60 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Mon, 8 Jul 2024 07:04:48 +0000 Subject: [PATCH 1/3] shardformer fp8 --- .../booster/plugin/hybrid_parallel_plugin.py | 4 +- colossalai/params.py | 1 + colossalai/quantization/fp8.py | 42 ++- colossalai/shardformer/layer/_operation.py | 290 +++++++++++++----- colossalai/shardformer/layer/linear.py | 18 +- .../shardformer/layer/qkv_fused_linear.py | 43 ++- colossalai/shardformer/modeling/gpt2.py | 2 + colossalai/shardformer/modeling/llama.py | 26 +- colossalai/shardformer/policies/gpt2.py | 10 +- colossalai/shardformer/policies/llama.py | 14 +- colossalai/shardformer/shard/shard_config.py | 2 + examples/language/bert/finetune.py | 5 +- examples/language/bert/test_ci.sh | 2 +- .../gpt/hybridparallelism/finetune.py | 11 +- .../test_model/test_shard_gpt2.py | 78 +++-- .../test_model/test_shard_llama.py | 206 ++++++++----- 16 files changed, 520 insertions(+), 234 deletions(-) create mode 100644 colossalai/params.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b818209a6..c210ca91e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -945,7 +945,8 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. - overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism + overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. + fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism """ def __init__( @@ -1119,6 +1120,7 @@ class HybridParallelPlugin(PipelinePluginBase): parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, + fp8_communication=fp8_communication, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/params.py b/colossalai/params.py new file mode 100644 index 000000000..4058ce430 --- /dev/null +++ b/colossalai/params.py @@ -0,0 +1 @@ +to_cast = [] diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index e514f435e..66bcd0295 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -12,7 +12,6 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. fp8_format: e4m3 or e5m2 - Returns: Tuples: A tuple (fp8_tensor, scale) """ @@ -39,12 +38,10 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor: r""" - Args: inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2]. scale: scaling factor returned by cast_to_fp8 function. ret_type: the datatype of the returned tensor. 
- Returns: torch.Tensor """ @@ -62,11 +59,9 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: r""" This is an in-place operation for compressed all_reduce using fp8. It works like dist.all_reduce but during communication the data is cast to fp8 format. - Args: tensor: torch.Tensor in fp32, fp16, bf16 datatype. fp8_format: e4m3 or e5m2 - Returns: None """ @@ -170,3 +165,40 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: if del_metadata: del inp["fp8_scale"] + + +def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None: + r""" + This is an in-place operation for compressed all_reduce using fp8. + It works like dist.all_reduce but during communication the data is cast to fp8 format. + + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. + fp8_format: e4m3 or e5m2 + + Returns: + None + """ + + input_type = output.dtype + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + scale_list = [] + cast_input_list = [] + output_chunks = [] + output_scale_list = [] + for input in input_list: + ret, scale = cast_to_fp8(input, fp8_format=fp8_format) + scale_list.append(scale) + ret = ret.view(torch.uint8) + cast_input_list.append(ret) + output_chunks.append(torch.empty_like(ret)) + output_scale_list.append(torch.empty_like(scale)) + dist.all_to_all(output_chunks, cast_input_list, group=group) + dist.all_to_all(output_scale_list, scale_list, group=group) + + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + for scale, out in zip(output_scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + output.data = summed_out diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 82d37bb4c..12ed1d409 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -14,6 +14,8 @@ try: except ImportError: _grad_accum_fusion_available = False +from colossalai.quantization.fp8 import all_reduce_fp8, cast_from_fp8, cast_to_fp8, reduce_scatter_fp8 + class FusedLayerNormAffineFunction1D(torch.autograd.Function): r"""Layernorm @@ -59,11 +61,12 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce + ctx.fp8_communication = fp8_communication output = torch.matmul(input_, weight) @@ -76,6 +79,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. 
weight = weight.view(weight.shape) @@ -90,7 +94,9 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) - if ctx.async_grad_allreduce: + if fp8_communication and ctx.async_grad_allreduce: + _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication) + elif ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have @@ -99,10 +105,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and not fp8_communication: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None class LinearWithAsyncCommunication(torch.autograd.Function): @@ -111,11 +117,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce + ctx.fp8_communication = fp8_communication if bias is not None: output = F.linear(input_, weight, bias) else: @@ -127,6 +134,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. 
if use_bias: @@ -142,7 +150,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce - handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + if fp8_communication: + all_reduce_fp8(grad_input, group=ctx.process_group) + else: + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py @@ -161,10 +172,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function): grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and not fp8_communication: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): @@ -232,17 +243,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, dim): + def forward(ctx, input_, process_group, dim, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim + ctx.fp8_communication = fp8_communication - return _gather(input_, dim, process_group) + return _gather(input_, dim, process_group, fp8_communication) @staticmethod def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group - + fp8_communication = ctx.fp8_communication # do reduce-scatter new_shape = list(grad_output.shape) assert ( @@ -253,9 +265,13 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) ] output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) - dist.reduce_scatter(output, grad_list, group=process_group) - return output, None, None + if fp8_communication: + reduce_scatter_fp8(output, grad_list, group=process_group) + else: + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None, None class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -546,9 +562,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, dim): + def forward(ctx, input_, process_group, dim, fp8_communication=False): ctx.dim = dim ctx.process_group = process_group + ctx.fp8_communication = fp8_communication # do reduce-scatter new_shape = list(input_.shape) @@ -558,7 +575,11 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) - dist.reduce_scatter(output, input_list, group=process_group) + if fp8_communication: + # if False: + reduce_scatter_fp8(output, input_list, group=process_group) + else: + dist.reduce_scatter(output, input_list, group=process_group) return output @@ -566,8 +587,9 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group + fp8_communication = ctx.fp8_communication - return _gather(grad_output, dim, process_group), None, 
None + return _gather(grad_output, dim, process_group, fp8_communication), None, None, None class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -582,13 +604,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring): + def forward( + ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication + ): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim ctx.overlap = overlap + ctx.fp8_communication = fp8_communication if ring is True: input_to_gather = {} @@ -605,7 +630,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): ) else: - input_parallel = _gather(input_, dim, process_group) + input_parallel = _gather(input_, dim, process_group, fp8_communication) output = torch.matmul(input_parallel, weight) @@ -620,6 +645,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm weight = weight.view(weight.shape) @@ -627,7 +653,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): bias = bias.view(bias.shape) if not overlap: - input_parallel = _gather(input_, dim, process_group) + input_parallel = _gather(input_, dim, process_group, fp8_communication) total_input = input_parallel grad_input = grad_output.matmul(weight.T) @@ -687,7 +713,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -702,17 +728,20 @@ class _SplitForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group, grad_scale=None): + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim ctx.grad_scale = grad_scale + ctx.fp8_communication = fp8_communication return _split(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: grad_output = grad_output * ctx.grad_scale - return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None + + # to_cast.append(grad_output.cpu().detach().numpy()) + return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, "e4m3"), None, None, None, None class _ReduceForward(torch.autograd.Function): @@ -725,12 +754,12 @@ class _ReduceForward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group): - return _reduce(input_, process_group) + def forward(ctx, input_, process_group, fp8_communication=False): + return _reduce(input_, process_group, fp8_communication) @staticmethod def backward(ctx, grad_output): - return grad_output, None + return grad_output, None, None class _ReduceBackward(torch.autograd.Function): @@ -743,13 +772,15 @@ class _ReduceBackward(torch.autograd.Function): """ @staticmethod - def 
forward(ctx, input_, process_group): + def forward(ctx, input_, process_group, fp8_communication=False): ctx.process_group = process_group + ctx.fp8_communication = fp8_communication return input_ @staticmethod def backward(ctx, grad_output): - return _reduce(grad_output, ctx.process_group), None + fp8_communication = ctx.fp8_communication + return _reduce(grad_output, ctx.process_group, fp8_communication), None, None class _GatherForwardSplitBackward(torch.autograd.Function): @@ -762,17 +793,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group, grad_scale=None): + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_comm=False): ctx.process_group = process_group ctx.dim = dim ctx.grad_scale = grad_scale - return _gather(input_, dim, process_group) + + return _gather(input_, dim, process_group, fp8_comm=fp8_comm, fp8_format="e4m3") @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: grad_output = grad_output * ctx.grad_scale - return _split(grad_output, ctx.dim, ctx.process_group), None, None, None + return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None class _AllToAll(torch.autograd.Function): @@ -786,26 +818,43 @@ class _AllToAll(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, scatter_dim, gather_dim): + def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication): ctx.process_group = process_group ctx.scatter_dim = scatter_dim ctx.gather_dim = gather_dim + ctx.fp8_communication = fp8_communication world_size = dist.get_world_size(process_group) bsz, _, _ = input_.shape # using all_to_all_single when batch size is 1 if bsz == 1: - return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim) + return _all_to_all_single( + input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2" + ) else: - return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) + return _all_to_all( + input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2" + ) @staticmethod - def backward(ctx, *grad_output): + def backward(ctx, grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim - return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) - return (return_grad, None, None, None) + ctx.fp8_communication + world_size = dist.get_world_size(process_group) + bsz, _, _ = grad_output.shape + + if bsz == 1: + return_grad = _all_to_all_single( + grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2" + ) + else: + return_grad = _all_to_all( + grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2" + ) + + return (return_grad, None, None, None, None) class HookParameter(torch.autograd.Function): @@ -831,12 +880,15 @@ def hook_parameter_in_backward(input, weight=None, bias=None): return HookParameter.apply(input, weight, bias) -def _reduce(input_, process_group): +def _reduce(input_, process_group, fp8_communication=False): # skip if only one rank involved if dist.get_world_size(process_group) == 1: return input_ else: - dist.all_reduce(input_, group=process_group) + if fp8_communication: + all_reduce_fp8(input_, group=process_group) + else: + dist.all_reduce(input_, group=process_group) return input_ @@ -860,19 +912,78 @@ def _split(input_, 
dim=-1, process_group=None): return output -def _gather(input_, dim=-1, process_group=None): +from colossalai.params import to_cast + + +def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3"): # skip if only one rank involved world_size = dist.get_world_size(process_group) if world_size == 1: return input_ # all gather - input_ = input_.contiguous() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - torch.distributed.all_gather(tensor_list, input_, group=process_group) + import torch.distributed as dista - # concat - output = torch.cat(tensor_list, dim=dim).contiguous() + from colossalai.zero.low_level._utils import has_inf_or_nan + + if fp8_comm: + # if False: + if has_inf_or_nan(input_): + print("input has nan") + exit(0) + input_type = input_.dtype + to_cast.append(input_) + ret, scale = cast_to_fp8(input_, fp8_format="e5m2") + if has_inf_or_nan(ret): + import pdb + + pdb.set_trace() + print("cast has nan") + # exit(0) + dista.barrier() + fp8_type = ret.dtype + input_ = ret.view(torch.uint8) + input_ = input_.contiguous() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + scale = torch.tensor(scale, dtype=torch.float32).to(input_.device) + # import torch.distributed as dista + # if dista.get_rank()==0: + # import pdb + # pdb.set_trace() + # dista.barrier() + scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)] + + scale = torch.tensor(scale).to(input_.device) + torch.distributed.all_gather(tensor_list, input_, group=process_group) + torch.distributed.all_gather(scale_list, scale, group=process_group) + + cast_tensor_list = [] + for output, scale in zip(tensor_list, scale_list): + output = output.view(fp8_type) + output = cast_from_fp8(output, scale, input_type) + if has_inf_or_nan(output) and dista.get_rank() == 0: + print("casted_output has nan") + import pdb + + pdb.set_trace() + dista.barrier() + + cast_tensor_list.append(output) + + output = torch.cat(cast_tensor_list, dim=dim).contiguous() + + if has_inf_or_nan(output): + print("output has nan") + exit(0) + # import pdb + # pdb.set_trace() + dista.barrier() + + else: + input_ = input_.contiguous() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=process_group) + output = torch.cat(tensor_list, dim=dim).contiguous() return output @@ -901,14 +1012,31 @@ def _reduce_scatter(input_, dim=1, process_group=None): return output -def _all_to_all(input_, world_size, group, scatter_dim, gather_dim): - input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] - dist.all_to_all(output_list, input_list, group=group) +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"): + if fp8_comm: + input_type = input_.dtype + ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) + fp8_type = ret.dtype + input_ = ret.view(torch.uint8) + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + dist.all_gather(scale_list, scale, group=group) + cast_tensor_list = [] + for output, scale in zip(output_list, scale_list): + output = 
output.view(fp8_type) + output = cast_from_fp8(output, scale, input_type) + cast_tensor_list.append(output) + output_list = cast_tensor_list + else: + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) return torch.cat(output_list, dim=gather_dim).contiguous() -def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): +def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"): inp_shape = list(input_.shape) inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size if scatter_dim < 2: @@ -920,8 +1048,24 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): .contiguous() ) - output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) + if fp8_comm: + input_type = input_t.dtype + ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format) + fp8_type = ret.dtype + input_t = ret.view(torch.uint8) + output = torch.empty_like(input_t) + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)] + dist.all_to_all_single(output, input_t, group=group) + dist.all_gather(scale_list, scale, group=group) + cast_tensor_list = [] + for output_part, scale in zip(output, scale_list): + output_part = output_part.view(fp8_type) + output_part = cast_from_fp8(output_part, scale, input_type) + cast_tensor_list.append(output_part) + output = torch.stack(cast_tensor_list, dim=0) + else: + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) if scatter_dim < 2: output = output.transpose(0, 1).contiguous() @@ -935,12 +1079,16 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): ).contiguous() -def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): - return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + return MatmulWithAsyncCommunication.apply( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + ) -def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): - return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + return LinearWithAsyncCommunication.apply( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + ) def linear_gather_forward_reducescatter_backward( @@ -951,12 +1099,12 @@ def linear_gather_forward_reducescatter_backward( ) -def gather_forward_reducescatter_backward(input_, process_group, dim): - return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) +def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication) -def reducescatter_forward_gather_backward(input_, process_group, dim): - return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) +def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False): + return 
_ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication) def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): @@ -964,28 +1112,28 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication ) -def gather_forward_split_backward(input_, dim, process_group, grad_scale=None): - return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale) +def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False): + return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) -def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): - return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) +def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False): + return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) -def reduce_forward(input_, process_group): - return _ReduceForward.apply(input_, process_group) +def reduce_forward(input_, process_group, fp8_communication=False): + return _ReduceForward.apply(input_, process_group, fp8_communication) -def reduce_backward(input_, process_group): - return _ReduceBackward.apply(input_, process_group) +def reduce_backward(input_, process_group, fp8_communication=False): + return _ReduceBackward.apply(input_, process_group, fp8_communication=fp8_communication) -def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): - return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_comm=False): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_comm) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 37c754241..af25c398b 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -84,6 +84,7 @@ class Linear1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, **kwargs, ): super().__init__(weight=weight, bias_=bias_, **kwargs) @@ -98,6 +99,7 @@ class Linear1D_Col(ParallelModule): self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -201,10 +203,12 @@ class Linear1D_Col(ParallelModule): bias = self.bias if not self.skip_bias_add else None if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, 
self.process_group, True, fp8_communication=self.fp8_communication + ) elif self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( - input_parallel, self.process_group, self.seq_parallel_dim + input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication ) output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) elif self.seq_parallel_mode == "ring": @@ -264,6 +268,7 @@ class Linear1D_Row(ParallelModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, + fp8_communication: bool = False, ): super().__init__() @@ -278,6 +283,7 @@ class Linear1D_Row(ParallelModule): self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -398,7 +404,9 @@ class Linear1D_Row(ParallelModule): ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions ) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + input_ = split_forward_gather_backward( + input_, dim=-1, process_group=self.process_group, fp8_comm=self.fp8_communication + ) if self.stream_chunk_num > 1: if self.training: @@ -418,11 +426,11 @@ class Linear1D_Row(ParallelModule): else: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) - output = reduce_forward(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) elif self.seq_parallel_mode == "split_gather": output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim + output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication ) elif self.seq_parallel_mode == "ring": output = linear_reducescatter_forward_gather_backward( diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 0f6595a7c..d8425b58d 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -183,6 +183,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, ): super().__init__() @@ -197,6 +198,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.n_fused = n_fused self.process_group = process_group self.async_communication = async_communication + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -314,14 +316,26 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.seq_parallel_mode is None: # Set up backprop all-reduce. 
- input_parallel = reduce_backward(input_, self.process_group) + input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication) output_parallel = matmul_with_async_comm( - input_parallel, self.weight, bias, self.process_group, self.async_communication + input_parallel, + self.weight, + bias, + self.process_group, + self.async_communication, + fp8_communication=self.fp8_communication, ) elif self.seq_parallel_mode == "split_gather": input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + input_parallel, + self.weight, + bias, + self.process_group, + True, + 1, + self.overlap, + fp8_communication=self.fp8_communication, ) elif self.seq_parallel_mode == "ring": input_parallel = input_ @@ -331,7 +345,9 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.gather_output: # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + output = gather_forward_split_backward( + output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) else: output = output_parallel @@ -379,6 +395,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, + fp8_communication: bool = False, ): super().__init__() @@ -392,6 +409,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): self.process_group = process_group self.seq_parallel_mode = seq_parallel_mode self.num_partitions = dist.get_world_size(self.process_group) + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -514,7 +532,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions ) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + input_ = split_forward_gather_backward( + input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) if self.stream_chunk_num > 1: if self.training: @@ -535,13 +555,20 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): else: if self.seq_parallel_mode is None: output_parallel = torch.matmul(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group, self.fp8_communication) elif self.seq_parallel_mode == "split_gather": output_parallel = torch.matmul(input_, self.weight) - output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = reducescatter_forward_gather_backward( + output_parallel, + self.process_group, + 1, + self.fp8_communication, + ) elif self.seq_parallel_mode == "ring": output_parallel = torch.matmul(input_, self.weight) - output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = reducescatter_forward_gather_backward( + output_parallel, self.process_group, 1, self.fp8_communication + ) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index aa75bab11..beaa47952 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -1137,6 +1137,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, dim=1, process_group=shard_config.sequence_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -1204,6 +1205,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, dim=1, process_group=shard_config.sequence_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index bf5ce45a8..05c9fe3bd 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -460,7 +460,7 @@ class LlamaPipelineForwards: return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, @@ -510,9 +510,9 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_comm=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_comm=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_comm=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -592,7 +592,7 @@ def 
get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, return forward -def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) def forward( @@ -659,9 +659,13 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, fp8_comm=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, 1 / sp_size, fp8_comm=shard_config.fp8_communication + ) hidden_states = inputs_embeds # decoder layers @@ -706,9 +710,13 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size= hidden_states = self.norm(hidden_states) if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, fp8_comm=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, grad_scale=sp_size, fp8_comm=shard_config.fp8_communication + ) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index cfe20000a..bb6269737 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -110,14 +110,13 @@ class GPT2Policy(Policy): "n_fused": 3, "seq_parallel_mode": sp_mode, "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="mlp.c_fc", @@ -127,14 +126,13 @@ class GPT2Policy(Policy): "seq_parallel_mode": sp_mode, "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 85ec6717d..be21b33e1 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -134,37 +134,37 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, 
fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), ], ) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index b64300366..7372e06c2 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -29,6 +29,7 @@ class ShardConfig: enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. + fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False. 
""" tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -47,6 +48,7 @@ class ShardConfig: gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) ep_group: Optional[ProcessGroup] = None + fp8_communication: bool = False # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 8a59ab683..d9246ff49 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -224,7 +224,10 @@ def main(): # modify the param accordingly for finetuning test cases plugin = HybridParallelPlugin( tp_size=1, - pp_size=2, + pp_size=1, + sp_size=2, + enable_sequence_parallelism=True, + sequence_parallelism_mode="all_to_all", num_microbatches=None, pp_style="interleaved", num_model_chunks=2, diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh index fc4eacf6f..de35e0fa5 100755 --- a/examples/language/bert/test_ci.sh +++ b/examples/language/bert/test_ci.sh @@ -5,7 +5,7 @@ pip install -r requirements.txt FAIL_LIMIT=3 -for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do +for plugin in "hybrid_parallel"; do for i in $(seq 1 $FAIL_LIMIT); do torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" && break echo "Failed $i times" diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 9b3a10160..ae80ddad5 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -218,8 +218,11 @@ def main(): elif args.plugin == "hybrid_parallel": # modify the param accordingly for finetuning test cases plugin = HybridParallelPlugin( - tp_size=1, - pp_size=2, + tp_size=2, + pp_size=1, + sp_size=2, + sequence_parallelism_mode="split_gather", + enable_sequence_parallelism=True, num_microbatches=None, microbatch_size=1, enable_all_optimization=True, @@ -318,3 +321,7 @@ def main(): if __name__ == "__main__": main() + if dist.get_rank() == 0: + import pdb + + pdb.set_trace() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index f9e368c0e..8fa0997d5 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: - atol, rtol = 5e-3, 5e-3 + atol, rtol = 5e-2, 5e-2 col_layer_grads = get_grad_tensors_for_check( gpt2, sharded_gpt2, @@ -97,7 +97,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: - atol, rtol = 5e-3, 5e-3 + atol, rtol = 5e-2, 5e-2 if org_model.__class__.__name__ == "GPT2Model": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) @@ -131,17 +131,47 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp32", - 
"initial_scale": 1, - }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring", + # "enable_flash_attention": False, + # "use_lazy_init": True, + # "precision": "fp32", + # "initial_scale": 1, + # }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "split_gather", + # "enable_flash_attention": False, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 4, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, { "tp_size": 4, "pp_size": 1, @@ -152,25 +182,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 4, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, + "fp8_communication": True, }, ], ) @@ -272,4 +284,4 @@ def test_gpt2_3d(): if __name__ == "__main__": test_gpt2() - test_gpt2_3d() + # test_gpt2_3d() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 8fe18f69b..d32b0684e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -34,7 +34,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() sharded_model.unwrap().gradient_checkpointing_enable() - org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster ) @@ -71,7 +70,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-2, rtol=5e-2, check_dtype=False) # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
grads_to_check = {} @@ -109,7 +108,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: - atol, rtol = 5e-3, 5e-3 + atol, rtol = 5e-2, 5e-2 if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) @@ -121,7 +120,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: - atol, rtol = 5e-3, 5e-3 + atol, rtol = 5e-2, 5e-2 try: check_weight( llama_model, @@ -146,104 +145,141 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - { # Test ring + Flash attention + # { # Test ring + Flash attention + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { # Ulysess + Flash attention + # "tp_size": 1, + # "pp_size": 2, + # "sp_size": 2, + # "num_microbatches": 2, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "split_gather", + # "enable_flash_attention": False, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 4, + # "use_lazy_init": False, + # "precision": "fp32", + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, + { "tp_size": 2, "pp_size": 1, "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { # Ulysess + Flash attention - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, - 
"use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "use_lazy_init": False, - "precision": "fp32", - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, + "fp8_communication": True, + }, + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": False, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + "fp8_communication": True, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + "fp8_communication": True, }, ], ) def run_llama_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_sequence_classification") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: - print(f"Failed config: {test_config}") + print(f"Failed config out: {test_config}") raise e clear_layout_converter() @@ -291,7 +327,7 @@ def run_llama_test(test_config): ], ) def run_llama_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_sequence_classification") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): try: @@ -333,4 +369,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - test_llama_3d() + # test_llama_3d() From 5a310b9ee177eab8335c044055eb2e6b55a62dac Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 17 Jul 2024 02:56:07 +0000 Subject: [PATCH 2/3] fix rebase --- colossalai/params.py | 1 - colossalai/quantization/fp8.py | 16 +- colossalai/shardformer/layer/_operation.py | 100 +++++----- colossalai/shardformer/modeling/llama.py | 24 ++- colossalai/shardformer/policies/llama.py | 14 +- examples/language/bert/finetune.py | 5 +- examples/language/bert/test_ci.sh | 2 +- 
.../gpt/hybridparallelism/finetune.py | 10 +- .../test_model/test_shard_gpt2.py | 78 ++++---- .../test_model/test_shard_llama.py | 176 +++++++----------- 10 files changed, 194 insertions(+), 232 deletions(-) delete mode 100644 colossalai/params.py diff --git a/colossalai/params.py b/colossalai/params.py deleted file mode 100644 index 4058ce430..000000000 --- a/colossalai/params.py +++ /dev/null @@ -1 +0,0 @@ -to_cast = [] diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 66bcd0295..a7ed61252 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt return ret.to(ret_type) -def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: +def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None: r""" This is an in-place operation for compressed all_reduce using fp8. It works like dist.all_reduce but during communication the data is cast to fp8 format. @@ -66,7 +66,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: None """ - world_size = dist.get_world_size() + world_size = dist.get_world_size(group=group) input_type = tensor.dtype input_shape = tensor.shape input_device = tensor.device @@ -83,19 +83,19 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] else: output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] - dist.all_to_all(output_chunks, input_chunks) + dist.all_to_all(output_chunks, input_chunks, group=group) scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] - dist.all_gather(scale_list, scale) + dist.all_gather(scale_list, scale, group=group) summed_out = torch.zeros_like(output_chunks[0]).to(input_type) for scale, out in zip(scale_list, output_chunks): out = out.view(fp8_type) summed_out += cast_from_fp8(out, scale, input_type) summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) - dist.all_gather(scale_list, scale) + dist.all_gather(scale_list, scale, group=group) tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0)) - dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8)) + dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group) for i in range(world_size): tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] tensor_out = torch.cat(tensor_list, dim=0) @@ -169,8 +169,8 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None: r""" - This is an in-place operation for compressed all_reduce using fp8. - It works like dist.all_reduce but during communication the data is cast to fp8 format. + This is an in-place operation for compressed reduce_scatter using fp8. + It works like dist.reduce_scatter but during communication the data is cast to fp8 format. Args: tensor: torch.Tensor in fp32, fp16, bf16 datatype. 
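The fp8 collectives above share one recipe: quantize each shard to fp8 together with a scaling factor, move the raw bytes and the scale, then dequantize and accumulate in the original precision on the receiving side. The sketch below reproduces only that numerics on a single process, with no torch.distributed and no ColossalAI imports. The helper names cast_to_fp8_demo and cast_from_fp8_demo, the per-tensor scaling, and the requirement of a PyTorch build that provides torch.float8_e4m3fn (2.1 or newer) are assumptions made for illustration; they are not the patch's actual API.

import torch

FP8 = torch.float8_e4m3fn  # e5m2 is the other format the patch can select


def cast_to_fp8_demo(t: torch.Tensor, fp8_dtype=FP8):
    # Per-tensor scale so the largest magnitude lands near the fp8 max value.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = t.abs().max().clamp(min=1e-12) / fp8_max
    return (t / scale).to(fp8_dtype), scale


def cast_from_fp8_demo(t_fp8: torch.Tensor, scale: torch.Tensor, out_dtype: torch.dtype):
    # Dequantize back to the working precision.
    return t_fp8.to(out_dtype) * scale


if __name__ == "__main__":
    torch.manual_seed(0)
    world_size = 4
    # Each simulated "rank" contributes one chunk destined for the same output shard.
    chunks = [torch.randn(8, dtype=torch.float16) for _ in range(world_size)]

    # What each rank would put on the wire: the fp8 tensor (the patch ships it as a
    # uint8 view) together with its scale.
    payload = [cast_to_fp8_demo(c.float()) for c in chunks]

    # What the receiving rank does: dequantize every chunk and sum, mirroring the
    # per-chunk accumulation in reduce_scatter_fp8 / all_reduce_fp8.
    fp8_sum = sum(cast_from_fp8_demo(q, s, torch.float32) for q, s in payload)
    exact_sum = sum(c.float() for c in chunks)
    print("max abs error from fp8 compression:", (fp8_sum - exact_sum).abs().max().item())

Running this shows the compression error stays small for well-scaled values but is well above fp16 round-off, which is presumably why the first patch relaxed the fp16 test tolerances to 5e-2 while fp8 communication was enabled in the test configs.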
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 12ed1d409..c1a04357e 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -94,7 +94,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) - if fp8_communication and ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and fp8_communication: _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication) elif ctx.async_grad_allreduce: # Asynchronous all-reduce @@ -117,12 +117,11 @@ class LinearWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce - ctx.fp8_communication = fp8_communication if bias is not None: output = F.linear(input_, weight, bias) else: @@ -134,7 +133,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function): def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias - fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. if use_bias: @@ -150,10 +148,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce - if fp8_communication: - all_reduce_fp8(grad_input, group=ctx.process_group) - else: - handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py @@ -172,7 +167,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.async_grad_allreduce and not fp8_communication: + if ctx.async_grad_allreduce: handle.wait() return grad_input, grad_weight, grad_bias, None, None, None, None @@ -243,18 +238,16 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, dim, fp8_communication=False): + def forward(ctx, input_, process_group, dim): ctx.process_group = process_group ctx.dim = dim - ctx.fp8_communication = fp8_communication - return _gather(input_, dim, process_group, fp8_communication) + return _gather(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group - fp8_communication = ctx.fp8_communication # do reduce-scatter new_shape = list(grad_output.shape) assert ( @@ -266,10 +259,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): ] output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) - if fp8_communication: - reduce_scatter_fp8(output, grad_list, group=process_group) - else: - dist.reduce_scatter(output, grad_list, group=process_group) + dist.reduce_scatter(output, grad_list, group=process_group) return output, None, None, None @@ -576,7 +566,6 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): 
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) if fp8_communication: - # if False: reduce_scatter_fp8(output, input_list, group=process_group) else: dist.reduce_scatter(output, input_list, group=process_group) @@ -588,8 +577,7 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): dim = ctx.dim process_group = ctx.process_group fp8_communication = ctx.fp8_communication - - return _gather(grad_output, dim, process_group, fp8_communication), None, None, None + return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -793,12 +781,12 @@ class _GatherForwardSplitBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_comm=False): + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim ctx.grad_scale = grad_scale - return _gather(input_, dim, process_group, fp8_comm=fp8_comm, fp8_format="e4m3") + return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3") @staticmethod def backward(ctx, grad_output): @@ -829,11 +817,23 @@ class _AllToAll(torch.autograd.Function): # using all_to_all_single when batch size is 1 if bsz == 1: return _all_to_all_single( - input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2" + input_, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", ) else: return _all_to_all( - input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2" + input_, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", ) @staticmethod @@ -841,17 +841,29 @@ class _AllToAll(torch.autograd.Function): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim - ctx.fp8_communication + fp8_communication = ctx.fp8_communication world_size = dist.get_world_size(process_group) bsz, _, _ = grad_output.shape if bsz == 1: return_grad = _all_to_all_single( - grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2" + grad_output, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", ) else: return_grad = _all_to_all( - grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2" + grad_output, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", ) return (return_grad, None, None, None, None) @@ -912,10 +924,7 @@ def _split(input_, dim=-1, process_group=None): return output -from colossalai.params import to_cast - - -def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3"): +def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"): # skip if only one rank involved world_size = dist.get_world_size(process_group) if world_size == 1: @@ -926,13 +935,12 @@ def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3 from colossalai.zero.low_level._utils import has_inf_or_nan - if fp8_comm: + if fp8_communication: # if 
False: if has_inf_or_nan(input_): print("input has nan") exit(0) input_type = input_.dtype - to_cast.append(input_) ret, scale = cast_to_fp8(input_, fp8_format="e5m2") if has_inf_or_nan(ret): import pdb @@ -1012,8 +1020,8 @@ def _reduce_scatter(input_, dim=1, process_group=None): return output -def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"): - if fp8_comm: +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"): + if fp8_communication: input_type = input_.dtype ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) fp8_type = ret.dtype @@ -1036,7 +1044,9 @@ def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=Fal return torch.cat(output_list, dim=gather_dim).contiguous() -def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"): +def _all_to_all_single( + input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2" +): inp_shape = list(input_.shape) inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size if scatter_dim < 2: @@ -1048,7 +1058,7 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, f .contiguous() ) - if fp8_comm: + if fp8_communication: input_type = input_t.dtype ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format) fp8_type = ret.dtype @@ -1085,10 +1095,8 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre ) -def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): - return LinearWithAsyncCommunication.apply( - input_, weight, bias, process_group, async_grad_allreduce, fp8_communication - ) +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) def linear_gather_forward_reducescatter_backward( @@ -1099,8 +1107,8 @@ def linear_gather_forward_reducescatter_backward( ) -def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False): - return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication) +def gather_forward_reducescatter_backward(input_, process_group, dim): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False): @@ -1132,8 +1140,8 @@ def reduce_forward(input_, process_group, fp8_communication=False): def reduce_backward(input_, process_group, fp8_communication=False): - return _ReduceBackward.apply(input_, process_group, fp8_communication=fp8_communication) + return _ReduceBackward.apply(input_, process_group, fp8_communication) -def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_comm=False): - return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_comm) +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 05c9fe3bd..26d1b6e4c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -510,9 +510,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, 
sp_mode=None, s # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group, fp8_comm=shard_config.fp8_communication) - key_states = all_to_all_comm(key_states, sp_group, fp8_comm=shard_config.fp8_communication) - value_states = all_to_all_comm(value_states, sp_group, fp8_comm=shard_config.fp8_communication) + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -660,11 +660,16 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward( - inputs_embeds, 1, sp_group, fp8_comm=shard_config.fp8_communication + inputs_embeds, + 1, + sp_group, ) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward( - inputs_embeds, 1, sp_group, 1 / sp_size, fp8_comm=shard_config.fp8_communication + inputs_embeds, + 1, + sp_group, + 1 / sp_size, ) hidden_states = inputs_embeds @@ -711,11 +716,16 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N if sp_mode == "ring" or sp_mode == "split_gather": hidden_states = gather_forward_split_backward( - hidden_states, 1, sp_group, fp8_comm=shard_config.fp8_communication + hidden_states, + 1, + sp_group, ) elif sp_mode == "all_to_all": hidden_states = gather_forward_split_backward( - hidden_states, 1, sp_group, grad_scale=sp_size, fp8_comm=shard_config.fp8_communication + hidden_states, + 1, + sp_group, + grad_scale=sp_size, ) # add hidden states from the last decoder layer diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index be21b33e1..85ec6717d 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -134,37 +134,37 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - 
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict(seq_parallel_mode=sp_mode), ), ], ) diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index d9246ff49..8a59ab683 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -224,10 +224,7 @@ def main(): # modify the param accordingly for finetuning test cases plugin = HybridParallelPlugin( tp_size=1, - pp_size=1, - sp_size=2, - enable_sequence_parallelism=True, - sequence_parallelism_mode="all_to_all", + pp_size=2, num_microbatches=None, pp_style="interleaved", num_model_chunks=2, diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh index de35e0fa5..fc4eacf6f 100755 --- a/examples/language/bert/test_ci.sh +++ b/examples/language/bert/test_ci.sh @@ -5,7 +5,7 @@ pip install -r requirements.txt FAIL_LIMIT=3 -for plugin in "hybrid_parallel"; do +for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do for i in $(seq 1 $FAIL_LIMIT); do torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" && break echo "Failed $i times" diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index ae80ddad5..a4b48fa05 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -220,9 +220,9 @@ def main(): plugin = HybridParallelPlugin( tp_size=2, pp_size=1, - sp_size=2, - sequence_parallelism_mode="split_gather", - enable_sequence_parallelism=True, + sp_size=1, + # sequence_parallelism_mode="split_gather", + # enable_sequence_parallelism=True, num_microbatches=None, microbatch_size=1, enable_all_optimization=True, @@ -321,7 +321,3 @@ def main(): if __name__ == "__main__": main() - if dist.get_rank() == 0: - import pdb - - pdb.set_trace() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 8fa0997d5..f9e368c0e 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: - atol, rtol = 5e-2, 5e-2 + atol, rtol = 5e-3, 5e-3 col_layer_grads = get_grad_tensors_for_check( gpt2, sharded_gpt2, @@ -97,7 +97,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: - atol, rtol = 5e-2, 5e-2 + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "GPT2Model": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) @@ -131,47 +131,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # { - # "tp_size": 4, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring", - # "enable_flash_attention": False, - # "use_lazy_init": True, - # "precision": "fp32", - # "initial_scale": 1, - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "split_gather", - # "enable_flash_attention": False, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 
1, - # }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 4, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, { "tp_size": 4, "pp_size": 1, @@ -182,7 +152,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, - "fp8_communication": True, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, }, ], ) @@ -284,4 +272,4 @@ def test_gpt2_3d(): if __name__ == "__main__": test_gpt2() - # test_gpt2_3d() + test_gpt2_3d() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index d32b0684e..8fe18f69b 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -34,6 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() sharded_model.unwrap().gradient_checkpointing_enable() + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster ) @@ -70,7 +71,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) grad = grads[grad_index] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] - assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-2, rtol=5e-2, check_dtype=False) + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
grads_to_check = {} @@ -108,7 +109,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: - atol, rtol = 5e-2, 5e-2 + atol, rtol = 5e-3, 5e-3 if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) @@ -120,7 +121,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: - atol, rtol = 5e-2, 5e-2 + atol, rtol = 5e-3, 5e-3 try: check_weight( llama_model, @@ -145,117 +146,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # { # Test ring + Flash attention - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { # Ulysess + Flash attention - # "tp_size": 1, - # "pp_size": 2, - # "sp_size": 2, - # "num_microbatches": 2, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "split_gather", - # "enable_flash_attention": False, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 4, - # "use_lazy_init": False, - # "precision": "fp32", - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, - { + { # Test ring + Flash attention "tp_size": 2, "pp_size": 1, "sp_size": 2, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, "use_lazy_init": True, - "zero_stage": 1, + "zero_stage": 2, "precision": "fp16", "initial_scale": 1, - "fp8_communication": True, }, - { - "tp_size": 2, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": False, + { # Ulysess + Flash attention + "tp_size": 1, + 
"pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, - "fp8_communication": True, }, { "tp_size": 1, @@ -268,18 +183,67 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "zero_stage": 1, "precision": "fp16", "initial_scale": 1, - "fp8_communication": True, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, }, ], ) def run_llama_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_sequence_classification") + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: - print(f"Failed config out: {test_config}") + print(f"Failed config: {test_config}") raise e clear_layout_converter() @@ -327,7 +291,7 @@ def run_llama_test(test_config): ], ) def run_llama_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_sequence_classification") + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): try: @@ -369,4 +333,4 @@ def test_llama_3d(): if __name__ == "__main__": test_llama() - # test_llama_3d() + test_llama_3d() From 6a20f07b8090b2fe801eb0bc0b33b397beb6d1fd Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 17 Jul 2024 05:33:38 +0000 Subject: [PATCH 3/3] remove all to all --- colossalai/quantization/fp8.py | 4 +- colossalai/shardformer/layer/_operation.py | 153 +++--------------- colossalai/shardformer/layer/linear.py | 18 +-- colossalai/shardformer/modeling/llama.py | 30 +--- .../gpt/hybridparallelism/finetune.py | 7 +- 5 files changed, 36 insertions(+), 176 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index a7ed61252..867de839e 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt return ret.to(ret_type) -def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None: +def 
all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None: r""" This is an in-place operation for compressed all_reduce using fp8. It works like dist.all_reduce but during communication the data is cast to fp8 format. @@ -167,7 +167,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: del inp["fp8_scale"] -def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None: +def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None: r""" This is an in-place operation for compressed reduce_scatter using fp8. It works like dist.reduce_scatter but during communication the data is cast to fp8 format. diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index c1a04357e..604a154d0 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -170,7 +170,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): @@ -261,7 +261,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): dist.reduce_scatter(output, grad_list, group=process_group) - return output, None, None, None + return output, None, None class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -729,7 +729,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function): grad_output = grad_output * ctx.grad_scale # to_cast.append(grad_output.cpu().detach().numpy()) - return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, "e4m3"), None, None, None, None + return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None class _ReduceForward(torch.autograd.Function): @@ -786,7 +786,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function): ctx.dim = dim ctx.grad_scale = grad_scale - return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3") + return _gather(input_, dim, process_group, fp8_communication=fp8_communication) @staticmethod def backward(ctx, grad_output): @@ -806,67 +806,26 @@ class _AllToAll(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication): + def forward(ctx, input_, process_group, scatter_dim, gather_dim): ctx.process_group = process_group ctx.scatter_dim = scatter_dim ctx.gather_dim = gather_dim - ctx.fp8_communication = fp8_communication world_size = dist.get_world_size(process_group) bsz, _, _ = input_.shape # using all_to_all_single when batch size is 1 if bsz == 1: - return _all_to_all_single( - input_, - world_size, - process_group, - scatter_dim, - gather_dim, - fp8_communication=fp8_communication, - fp8_format="e5m2", - ) + return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim) else: - return _all_to_all( - input_, - world_size, - process_group, - scatter_dim, - gather_dim, - fp8_communication=fp8_communication, - fp8_format="e5m2", - ) + return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) @staticmethod - def backward(ctx, grad_output): + def backward(ctx, *grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim 
- fp8_communication = ctx.fp8_communication - world_size = dist.get_world_size(process_group) - bsz, _, _ = grad_output.shape - - if bsz == 1: - return_grad = _all_to_all_single( - grad_output, - world_size, - process_group, - scatter_dim, - gather_dim, - fp8_communication=fp8_communication, - fp8_format="e5m2", - ) - else: - return_grad = _all_to_all( - grad_output, - world_size, - process_group, - scatter_dim, - gather_dim, - fp8_communication=fp8_communication, - fp8_format="e5m2", - ) - - return (return_grad, None, None, None, None) + return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) + return (return_grad, None, None, None) class HookParameter(torch.autograd.Function): @@ -924,41 +883,20 @@ def _split(input_, dim=-1, process_group=None): return output -def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"): +def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"): # skip if only one rank involved world_size = dist.get_world_size(process_group) if world_size == 1: return input_ - # all gather - import torch.distributed as dista - - from colossalai.zero.low_level._utils import has_inf_or_nan - if fp8_communication: - # if False: - if has_inf_or_nan(input_): - print("input has nan") - exit(0) input_type = input_.dtype - ret, scale = cast_to_fp8(input_, fp8_format="e5m2") - if has_inf_or_nan(ret): - import pdb - - pdb.set_trace() - print("cast has nan") - # exit(0) - dista.barrier() + ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) fp8_type = ret.dtype input_ = ret.view(torch.uint8) input_ = input_.contiguous() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] scale = torch.tensor(scale, dtype=torch.float32).to(input_.device) - # import torch.distributed as dista - # if dista.get_rank()==0: - # import pdb - # pdb.set_trace() - # dista.barrier() scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)] scale = torch.tensor(scale).to(input_.device) @@ -969,24 +907,10 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for for output, scale in zip(tensor_list, scale_list): output = output.view(fp8_type) output = cast_from_fp8(output, scale, input_type) - if has_inf_or_nan(output) and dista.get_rank() == 0: - print("casted_output has nan") - import pdb - - pdb.set_trace() - dista.barrier() - cast_tensor_list.append(output) output = torch.cat(cast_tensor_list, dim=dim).contiguous() - if has_inf_or_nan(output): - print("output has nan") - exit(0) - # import pdb - # pdb.set_trace() - dista.barrier() - else: input_ = input_.contiguous() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] @@ -1020,33 +944,14 @@ def _reduce_scatter(input_, dim=1, process_group=None): return output -def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"): - if fp8_communication: - input_type = input_.dtype - ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) - fp8_type = ret.dtype - input_ = ret.view(torch.uint8) - input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] - scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)] - dist.all_to_all(output_list, input_list, group=group) - dist.all_gather(scale_list, scale, group=group) - cast_tensor_list = [] - for output, scale in zip(output_list, 
scale_list): - output = output.view(fp8_type) - output = cast_from_fp8(output, scale, input_type) - cast_tensor_list.append(output) - output_list = cast_tensor_list - else: - input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] - dist.all_to_all(output_list, input_list, group=group) +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) return torch.cat(output_list, dim=gather_dim).contiguous() -def _all_to_all_single( - input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2" -): +def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): inp_shape = list(input_.shape) inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size if scatter_dim < 2: @@ -1058,24 +963,8 @@ def _all_to_all_single( .contiguous() ) - if fp8_communication: - input_type = input_t.dtype - ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format) - fp8_type = ret.dtype - input_t = ret.view(torch.uint8) - output = torch.empty_like(input_t) - scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)] - dist.all_to_all_single(output, input_t, group=group) - dist.all_gather(scale_list, scale, group=group) - cast_tensor_list = [] - for output_part, scale in zip(output, scale_list): - output_part = output_part.view(fp8_type) - output_part = cast_from_fp8(output_part, scale, input_type) - cast_tensor_list.append(output_part) - output = torch.stack(cast_tensor_list, dim=0) - else: - output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) if scatter_dim < 2: output = output.transpose(0, 1).contiguous() @@ -1143,5 +1032,5 @@ def reduce_backward(input_, process_group, fp8_communication=False): return _ReduceBackward.apply(input_, process_group, fp8_communication) -def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False): - return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication) +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index af25c398b..37c754241 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -84,7 +84,6 @@ class Linear1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - fp8_communication: bool = False, **kwargs, ): super().__init__(weight=weight, bias_=bias_, **kwargs) @@ -99,7 +98,6 @@ class Linear1D_Col(ParallelModule): self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group - self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -203,12 +201,10 @@ class Linear1D_Col(ParallelModule): bias = self.bias if not self.skip_bias_add else None if 
self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication - ) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) elif self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( - input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication + input_parallel, self.process_group, self.seq_parallel_dim ) output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) elif self.seq_parallel_mode == "ring": @@ -268,7 +264,6 @@ class Linear1D_Row(ParallelModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, - fp8_communication: bool = False, ): super().__init__() @@ -283,7 +278,6 @@ class Linear1D_Row(ParallelModule): self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) - self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -404,9 +398,7 @@ class Linear1D_Row(ParallelModule): ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions ) - input_ = split_forward_gather_backward( - input_, dim=-1, process_group=self.process_group, fp8_comm=self.fp8_communication - ) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) if self.stream_chunk_num > 1: if self.training: @@ -426,11 +418,11 @@ class Linear1D_Row(ParallelModule): else: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) - output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) + output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "split_gather": output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication + output_parallel, self.process_group, self.seq_parallel_dim ) elif self.seq_parallel_mode == "ring": output = linear_reducescatter_forward_gather_backward( diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 26d1b6e4c..bf5ce45a8 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -460,7 +460,7 @@ class LlamaPipelineForwards: return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, @@ -592,7 +592,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s return forward -def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): logger = 
logging.get_logger(__name__) def forward( @@ -659,18 +659,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward( - inputs_embeds, - 1, - sp_group, - ) + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward( - inputs_embeds, - 1, - sp_group, - 1 / sp_size, - ) + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) hidden_states = inputs_embeds # decoder layers @@ -715,18 +706,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N hidden_states = self.norm(hidden_states) if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward( - hidden_states, - 1, - sp_group, - ) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward( - hidden_states, - 1, - sp_group, - grad_scale=sp_size, - ) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index a4b48fa05..9b3a10160 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -218,11 +218,8 @@ def main(): elif args.plugin == "hybrid_parallel": # modify the param accordingly for finetuning test cases plugin = HybridParallelPlugin( - tp_size=2, - pp_size=1, - sp_size=1, - # sequence_parallelism_mode="split_gather", - # enable_sequence_parallelism=True, + tp_size=1, + pp_size=2, num_microbatches=None, microbatch_size=1, enable_all_optimization=True,