diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b818209a6..c210ca91e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -945,7 +945,8 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. - overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism + overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. + fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism """ def __init__( @@ -1119,6 +1120,7 @@ class HybridParallelPlugin(PipelinePluginBase): parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, + fp8_communication=fp8_communication, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index e514f435e..867de839e 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -12,7 +12,6 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. fp8_format: e4m3 or e5m2 - Returns: Tuples: A tuple (fp8_tensor, scale) """ @@ -39,12 +38,10 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor: r""" - Args: inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2]. scale: scaling factor returned by cast_to_fp8 function. ret_type: the datatype of the returned tensor. - Returns: torch.Tensor """ @@ -58,20 +55,18 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt return ret.to(ret_type) -def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: +def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None: r""" This is an in-place operation for compressed all_reduce using fp8. It works like dist.all_reduce but during communication the data is cast to fp8 format. - Args: tensor: torch.Tensor in fp32, fp16, bf16 datatype. fp8_format: e4m3 or e5m2 - Returns: None """ - world_size = dist.get_world_size() + world_size = dist.get_world_size(group=group) input_type = tensor.dtype input_shape = tensor.shape input_device = tensor.device @@ -88,19 +83,19 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] else: output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] - dist.all_to_all(output_chunks, input_chunks) + dist.all_to_all(output_chunks, input_chunks, group=group) scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] - dist.all_gather(scale_list, scale) + dist.all_gather(scale_list, scale, group=group) summed_out = torch.zeros_like(output_chunks[0]).to(input_type) for scale, out in zip(scale_list, output_chunks): out = out.view(fp8_type) summed_out += cast_from_fp8(out, scale, input_type) summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) - dist.all_gather(scale_list, scale) + dist.all_gather(scale_list, scale, group=group) tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0)) - dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8)) + dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group) for i in range(world_size): tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] tensor_out = torch.cat(tensor_list, dim=0) @@ -170,3 +165,40 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: if del_metadata: del inp["fp8_scale"] + + +def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None: + r""" + This is an in-place operation for compressed reduce_scatter using fp8. + It works like dist.reduce_scatter but during communication the data is cast to fp8 format. + + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. + fp8_format: e4m3 or e5m2 + + Returns: + None + """ + + input_type = output.dtype + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + scale_list = [] + cast_input_list = [] + output_chunks = [] + output_scale_list = [] + for input in input_list: + ret, scale = cast_to_fp8(input, fp8_format=fp8_format) + scale_list.append(scale) + ret = ret.view(torch.uint8) + cast_input_list.append(ret) + output_chunks.append(torch.empty_like(ret)) + output_scale_list.append(torch.empty_like(scale)) + dist.all_to_all(output_chunks, cast_input_list, group=group) + dist.all_to_all(output_scale_list, scale_list, group=group) + + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + for scale, out in zip(output_scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + output.data = summed_out diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 82d37bb4c..604a154d0 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -14,6 +14,8 @@ try: except ImportError: _grad_accum_fusion_available = False +from colossalai.quantization.fp8 import all_reduce_fp8, cast_from_fp8, cast_to_fp8, reduce_scatter_fp8 + class FusedLayerNormAffineFunction1D(torch.autograd.Function): r"""Layernorm @@ -59,11 +61,12 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce + ctx.fp8_communication = fp8_communication output = torch.matmul(input_, weight) @@ -76,6 +79,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. weight = weight.view(weight.shape) @@ -90,7 +94,9 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) - if ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and fp8_communication: + _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication) + elif ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have @@ -99,10 +105,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and not fp8_communication: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None class LinearWithAsyncCommunication(torch.autograd.Function): @@ -242,7 +248,6 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group - # do reduce-scatter new_shape = list(grad_output.shape) assert ( @@ -253,6 +258,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) ] output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) + dist.reduce_scatter(output, grad_list, group=process_group) return output, None, None @@ -546,9 +552,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, dim): + def forward(ctx, input_, process_group, dim, fp8_communication=False): ctx.dim = dim ctx.process_group = process_group + ctx.fp8_communication = fp8_communication # do reduce-scatter new_shape = list(input_.shape) @@ -558,7 +565,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) - dist.reduce_scatter(output, input_list, group=process_group) + if fp8_communication: + reduce_scatter_fp8(output, input_list, group=process_group) + else: + dist.reduce_scatter(output, input_list, group=process_group) return output @@ -566,8 +576,8 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group - - return _gather(grad_output, dim, process_group), None, None + fp8_communication = ctx.fp8_communication + return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -582,13 +592,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring): + def forward( + ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication + ): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim ctx.overlap = overlap + ctx.fp8_communication = fp8_communication if ring is True: input_to_gather = {} @@ -605,7 +618,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): ) else: - input_parallel = _gather(input_, dim, process_group) + input_parallel = _gather(input_, dim, process_group, fp8_communication) output = torch.matmul(input_parallel, weight) @@ -620,6 +633,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm weight = weight.view(weight.shape) @@ -627,7 +641,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): bias = bias.view(bias.shape) if not overlap: - input_parallel = _gather(input_, dim, process_group) + input_parallel = _gather(input_, dim, process_group, fp8_communication) total_input = input_parallel grad_input = grad_output.matmul(weight.T) @@ -687,7 +701,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -702,17 +716,20 @@ class _SplitForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group, grad_scale=None): + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim ctx.grad_scale = grad_scale + ctx.fp8_communication = fp8_communication return _split(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: grad_output = grad_output * ctx.grad_scale - return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None + + # to_cast.append(grad_output.cpu().detach().numpy()) + return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None class _ReduceForward(torch.autograd.Function): @@ -725,12 +742,12 @@ class _ReduceForward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group): - return _reduce(input_, process_group) + def forward(ctx, input_, process_group, fp8_communication=False): + return _reduce(input_, process_group, fp8_communication) @staticmethod def backward(ctx, grad_output): - return grad_output, None + return grad_output, None, None class _ReduceBackward(torch.autograd.Function): @@ -743,13 +760,15 @@ class _ReduceBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group): + def forward(ctx, input_, process_group, fp8_communication=False): ctx.process_group = process_group + ctx.fp8_communication = fp8_communication return input_ @staticmethod def backward(ctx, grad_output): - return _reduce(grad_output, ctx.process_group), None + fp8_communication = ctx.fp8_communication + return _reduce(grad_output, ctx.process_group, fp8_communication), None, None class _GatherForwardSplitBackward(torch.autograd.Function): @@ -762,17 +781,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group, grad_scale=None): + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim ctx.grad_scale = grad_scale - return _gather(input_, dim, process_group) + + return _gather(input_, dim, process_group, fp8_communication=fp8_communication) @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: grad_output = grad_output * ctx.grad_scale - return _split(grad_output, ctx.dim, ctx.process_group), None, None, None + return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None class _AllToAll(torch.autograd.Function): @@ -831,12 +851,15 @@ def hook_parameter_in_backward(input, weight=None, bias=None): return HookParameter.apply(input, weight, bias) -def _reduce(input_, process_group): +def _reduce(input_, process_group, fp8_communication=False): # skip if only one rank involved if dist.get_world_size(process_group) == 1: return input_ else: - dist.all_reduce(input_, group=process_group) + if fp8_communication: + all_reduce_fp8(input_, group=process_group) + else: + dist.all_reduce(input_, group=process_group) return input_ @@ -860,19 +883,39 @@ def _split(input_, dim=-1, process_group=None): return output -def _gather(input_, dim=-1, process_group=None): +def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"): # skip if only one rank involved world_size = dist.get_world_size(process_group) if world_size == 1: return input_ - # all gather - input_ = input_.contiguous() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - torch.distributed.all_gather(tensor_list, input_, group=process_group) + if fp8_communication: + input_type = input_.dtype + ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) + fp8_type = ret.dtype + input_ = ret.view(torch.uint8) + input_ = input_.contiguous() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + scale = torch.tensor(scale, dtype=torch.float32).to(input_.device) + scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)] - # concat - output = torch.cat(tensor_list, dim=dim).contiguous() + scale = torch.tensor(scale).to(input_.device) + torch.distributed.all_gather(tensor_list, input_, group=process_group) + torch.distributed.all_gather(scale_list, scale, group=process_group) + + cast_tensor_list = [] + for output, scale in zip(tensor_list, scale_list): + output = output.view(fp8_type) + output = cast_from_fp8(output, scale, input_type) + cast_tensor_list.append(output) + + output = torch.cat(cast_tensor_list, dim=dim).contiguous() + + else: + input_ = input_.contiguous() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=process_group) + output = torch.cat(tensor_list, dim=dim).contiguous() return output @@ -935,8 +978,10 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): ).contiguous() -def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): - return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + return MatmulWithAsyncCommunication.apply( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + ) def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): @@ -955,8 +1000,8 @@ def gather_forward_reducescatter_backward(input_, process_group, dim): return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) -def reducescatter_forward_gather_backward(input_, process_group, dim): - return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) +def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False): + return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication) def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): @@ -964,27 +1009,27 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication ) -def gather_forward_split_backward(input_, dim, process_group, grad_scale=None): - return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale) +def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False): + return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) -def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): - return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) +def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False): + return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) -def reduce_forward(input_, process_group): - return _ReduceForward.apply(input_, process_group) +def reduce_forward(input_, process_group, fp8_communication=False): + return _ReduceForward.apply(input_, process_group, fp8_communication) -def reduce_backward(input_, process_group): - return _ReduceBackward.apply(input_, process_group) +def reduce_backward(input_, process_group, fp8_communication=False): + return _ReduceBackward.apply(input_, process_group, fp8_communication) def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 0f6595a7c..d8425b58d 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -183,6 +183,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, ): super().__init__() @@ -197,6 +198,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.n_fused = n_fused self.process_group = process_group self.async_communication = async_communication + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -314,14 +316,26 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.seq_parallel_mode is None: # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group) + input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication) output_parallel = matmul_with_async_comm( - input_parallel, self.weight, bias, self.process_group, self.async_communication + input_parallel, + self.weight, + bias, + self.process_group, + self.async_communication, + fp8_communication=self.fp8_communication, ) elif self.seq_parallel_mode == "split_gather": input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + input_parallel, + self.weight, + bias, + self.process_group, + True, + 1, + self.overlap, + fp8_communication=self.fp8_communication, ) elif self.seq_parallel_mode == "ring": input_parallel = input_ @@ -331,7 +345,9 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.gather_output: # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + output = gather_forward_split_backward( + output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) else: output = output_parallel @@ -379,6 +395,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, + fp8_communication: bool = False, ): super().__init__() @@ -392,6 +409,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): self.process_group = process_group self.seq_parallel_mode = seq_parallel_mode self.num_partitions = dist.get_world_size(self.process_group) + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -514,7 +532,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions ) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + input_ = split_forward_gather_backward( + input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) if self.stream_chunk_num > 1: if self.training: @@ -535,13 +555,20 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): else: if self.seq_parallel_mode is None: output_parallel = torch.matmul(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group, self.fp8_communication) elif self.seq_parallel_mode == "split_gather": output_parallel = torch.matmul(input_, self.weight) - output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = reducescatter_forward_gather_backward( + output_parallel, + self.process_group, + 1, + self.fp8_communication, + ) elif self.seq_parallel_mode == "ring": output_parallel = torch.matmul(input_, self.weight) - output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = reducescatter_forward_gather_backward( + output_parallel, self.process_group, 1, self.fp8_communication + ) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index aa75bab11..beaa47952 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -1137,6 +1137,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, dim=1, process_group=shard_config.sequence_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -1204,6 +1205,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, dim=1, process_group=shard_config.sequence_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index cfe20000a..bb6269737 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -110,14 +110,13 @@ class GPT2Policy(Policy): "n_fused": 3, "seq_parallel_mode": sp_mode, "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="mlp.c_fc", @@ -127,14 +126,13 @@ class GPT2Policy(Policy): "seq_parallel_mode": sp_mode, "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index b64300366..7372e06c2 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -29,6 +29,7 @@ class ShardConfig: enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. + fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -47,6 +48,7 @@ class ShardConfig: gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) ep_group: Optional[ProcessGroup] = None + fp8_communication: bool = False # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']