diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index a7ed61252..867de839e 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None:
+def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -167,7 +167,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["fp8_scale"]
 
 
-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None:
+def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index c1a04357e..604a154d0 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -170,7 +170,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         if ctx.async_grad_allreduce:
             handle.wait()
 
-        return grad_input, grad_weight, grad_bias, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None
 
 
 def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
@@ -261,7 +261,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
 
             dist.reduce_scatter(output, grad_list, group=process_group)
 
-            return output, None, None, None
+            return output, None, None
 
 
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -729,7 +729,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
             grad_output = grad_output * ctx.grad_scale
 
         # to_cast.append(grad_output.cpu().detach().numpy())
-        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, "e4m3"), None, None, None, None
+        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
 
 
 class _ReduceForward(torch.autograd.Function):
@@ -786,7 +786,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.grad_scale = grad_scale
 
-        return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
+        return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -806,67 +806,26 @@ class _AllToAll(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication):
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
         ctx.process_group = process_group
         ctx.scatter_dim = scatter_dim
         ctx.gather_dim = gather_dim
-        ctx.fp8_communication = fp8_communication
         world_size = dist.get_world_size(process_group)
         bsz, _, _ = input_.shape
 
         # using all_to_all_single when batch size is 1
         if bsz == 1:
-            return _all_to_all_single(
-                input_,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
+            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
         else:
-            return _all_to_all(
-                input_,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
+            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         process_group = ctx.process_group
         scatter_dim = ctx.gather_dim
         gather_dim = ctx.scatter_dim
-        fp8_communication = ctx.fp8_communication
-        world_size = dist.get_world_size(process_group)
-        bsz, _, _ = grad_output.shape
-
-        if bsz == 1:
-            return_grad = _all_to_all_single(
-                grad_output,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
-        else:
-            return_grad = _all_to_all(
-                grad_output,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
-
-        return (return_grad, None, None, None, None)
+        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
+        return (return_grad, None, None, None)
 
 
 class HookParameter(torch.autograd.Function):
@@ -924,41 +883,20 @@ def _split(input_, dim=-1, process_group=None):
     return output
 
 
-def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"):
+def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
     # skip if only one rank involved
     world_size = dist.get_world_size(process_group)
     if world_size == 1:
         return input_
 
-    # all gather
-    import torch.distributed as dista
-
-    from colossalai.zero.low_level._utils import has_inf_or_nan
-
     if fp8_communication:
-        # if False:
-        if has_inf_or_nan(input_):
-            print("input has nan")
-            exit(0)
         input_type = input_.dtype
-        ret, scale = cast_to_fp8(input_, fp8_format="e5m2")
-        if has_inf_or_nan(ret):
-            import pdb
-
-            pdb.set_trace()
-            print("cast has nan")
-            # exit(0)
-        dista.barrier()
+        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
         fp8_type = ret.dtype
         input_ = ret.view(torch.uint8)
         input_ = input_.contiguous()
         tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
         scale = torch.tensor(scale, dtype=torch.float32).to(input_.device)
-        # import torch.distributed as dista
-        # if dista.get_rank()==0:
-        #     import pdb
-        #     pdb.set_trace()
-        # dista.barrier()
         scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)]
 
         scale = torch.tensor(scale).to(input_.device)
@@ -969,24 +907,10 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
         for output, scale in zip(tensor_list, scale_list):
             output = output.view(fp8_type)
             output = cast_from_fp8(output, scale, input_type)
-            if has_inf_or_nan(output) and dista.get_rank() == 0:
-                print("casted_output has nan")
-                import pdb
-
-                pdb.set_trace()
-            dista.barrier()
-
             cast_tensor_list.append(output)
 
         output = torch.cat(cast_tensor_list, dim=dim).contiguous()
 
-        if has_inf_or_nan(output):
-            print("output has nan")
-            exit(0)
-        # import pdb
-        # pdb.set_trace()
-        dista.barrier()
-
     else:
         input_ = input_.contiguous()
         tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -1020,33 +944,14 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output
 
 
-def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
-    if fp8_communication:
-        input_type = input_.dtype
-        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
-        fp8_type = ret.dtype
-        input_ = ret.view(torch.uint8)
-        input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
-        output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-        dist.all_to_all(output_list, input_list, group=group)
-        dist.all_gather(scale_list, scale, group=group)
-        cast_tensor_list = []
-        for output, scale in zip(output_list, scale_list):
-            output = output.view(fp8_type)
-            output = cast_from_fp8(output, scale, input_type)
-            cast_tensor_list.append(output)
-        output_list = cast_tensor_list
-    else:
-        input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
-        output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-        dist.all_to_all(output_list, input_list, group=group)
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
+    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+    dist.all_to_all(output_list, input_list, group=group)
     return torch.cat(output_list, dim=gather_dim).contiguous()
 
 
-def _all_to_all_single(
-    input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
-):
+def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
     inp_shape = list(input_.shape)
     inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
     if scatter_dim < 2:
@@ -1058,24 +963,8 @@ def _all_to_all_single(
         .contiguous()
     )
 
-    if fp8_communication:
-        input_type = input_t.dtype
-        ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format)
-        fp8_type = ret.dtype
-        input_t = ret.view(torch.uint8)
-        output = torch.empty_like(input_t)
-        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)]
-        dist.all_to_all_single(output, input_t, group=group)
-        dist.all_gather(scale_list, scale, group=group)
-        cast_tensor_list = []
-        for output_part, scale in zip(output, scale_list):
-            output_part = output_part.view(fp8_type)
-            output_part = cast_from_fp8(output_part, scale, input_type)
-            cast_tensor_list.append(output_part)
-        output = torch.stack(cast_tensor_list, dim=0)
-    else:
-        output = torch.empty_like(input_t)
-        dist.all_to_all_single(output, input_t, group=group)
+    output = torch.empty_like(input_t)
+    dist.all_to_all_single(output, input_t, group=group)
 
     if scatter_dim < 2:
         output = output.transpose(0, 1).contiguous()
@@ -1143,5 +1032,5 @@ def reduce_backward(input_, process_group, fp8_communication=False):
     return _ReduceBackward.apply(input_, process_group, fp8_communication)
 
 
-def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
-    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index af25c398b..37c754241 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -84,7 +84,6 @@ class Linear1D_Col(ParallelModule):
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-        fp8_communication: bool = False,
         **kwargs,
     ):
         super().__init__(weight=weight, bias_=bias_, **kwargs)
@@ -99,7 +98,6 @@ class Linear1D_Col(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
-        self.fp8_communication = fp8_communication
 
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -203,12 +201,10 @@ class Linear1D_Col(ParallelModule):
         bias = self.bias if not self.skip_bias_add else None
 
         if self.seq_parallel_mode is None:
-            output_parallel = linear_with_async_comm(
-                input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
-            )
+            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
         elif self.seq_parallel_mode == "split_gather":
             input_parallel = gather_forward_reducescatter_backward(
-                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
+                input_parallel, self.process_group, self.seq_parallel_dim
             )
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
         elif self.seq_parallel_mode == "ring":
@@ -268,7 +264,6 @@ class Linear1D_Row(ParallelModule):
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
         stream_chunk_num: int = 1,
-        fp8_communication: bool = False,
     ):
         super().__init__()
 
@@ -283,7 +278,6 @@ class Linear1D_Row(ParallelModule):
         self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
         self.num_partitions = dist.get_world_size(self.process_group)
-        self.fp8_communication = fp8_communication
 
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -404,9 +398,7 @@ class Linear1D_Row(ParallelModule):
         ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
             input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
         )
-        input_ = split_forward_gather_backward(
-            input_, dim=-1, process_group=self.process_group, fp8_comm=self.fp8_communication
-        )
+        input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
 
         if self.stream_chunk_num > 1:
             if self.training:
@@ -426,11 +418,11 @@ class Linear1D_Row(ParallelModule):
         else:
             if self.seq_parallel_mode is None:
                 output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
+                output = reduce_forward(output_parallel, self.process_group)
             elif self.seq_parallel_mode == "split_gather":
                 output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
                 output = reducescatter_forward_gather_backward(
-                    output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
+                    output_parallel, self.process_group, self.seq_parallel_dim
                 )
             elif self.seq_parallel_mode == "ring":
                 output = linear_reducescatter_forward_gather_backward(
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 26d1b6e4c..bf5ce45a8 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -460,7 +460,7 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}
 
 
-def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
+def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -592,7 +592,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
     return forward
 
 
-def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
+def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
 
     def forward(
@@ -659,18 +659,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
             attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(
-                inputs_embeds,
-                1,
-                sp_group,
-            )
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(
-                inputs_embeds,
-                1,
-                sp_group,
-                1 / sp_size,
-            )
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -715,18 +706,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(
-                hidden_states,
-                1,
-                sp_group,
-            )
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(
-                hidden_states,
-                1,
-                sp_group,
-                grad_scale=sp_size,
-            )
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py
index a4b48fa05..9b3a10160 100644
--- a/examples/language/gpt/hybridparallelism/finetune.py
+++ b/examples/language/gpt/hybridparallelism/finetune.py
@@ -218,11 +218,8 @@ def main():
     elif args.plugin == "hybrid_parallel":
         # modify the param accordingly for finetuning test cases
         plugin = HybridParallelPlugin(
-            tp_size=2,
-            pp_size=1,
-            sp_size=1,
-            # sequence_parallelism_mode="split_gather",
-            # enable_sequence_parallelism=True,
+            tp_size=1,
+            pp_size=2,
             num_microbatches=None,
             microbatch_size=1,
             enable_all_optimization=True,