diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 56fa006b1..7a906738c 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -9,6 +9,7 @@ on: paths: - "examples/**" - "!examples/**.md" + - ".github/workflows/example_check_on_pr.yml" jobs: # This is for changed example files detect and output a matrix containing all the corresponding directory name. @@ -107,7 +108,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v -e . - name: Store Colossal-AI Cache run: | diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 3754cfe60..6a5d0c161 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -366,7 +366,9 @@ class GeminiPlugin(DPPluginBase): enable_jit_fused: bool = False, enable_sequence_overlap: bool = False, enable_async_reduce: bool = True, + use_fp8: bool = False, verbose: bool = False, + fp8_communication: bool = False, ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" @@ -401,6 +403,8 @@ class GeminiPlugin(DPPluginBase): master_weights=master_weights, max_prefetch=max_prefetch, enable_async_reduce=enable_async_reduce, + fp8_communication=fp8_communication, + use_fp8=use_fp8, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b4b40020f..ae1fbc771 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.quantization.fp8_hook import FP8Hook from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp from colossalai.shardformer.policies.base_policy import Policy @@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): ddp_config: dict, custom_policy: Policy, overlap_allgather: bool = False, + use_fp8: bool = False, ) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.shard_config = shard_config @@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.use_ddp = use_ddp self.require_grad_sync = True self.overlap_allgather = overlap_allgather + self.use_fp8 = use_fp8 shardformer = ShardFormer(shard_config) if custom_policy is not None: @@ -112,6 +115,9 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): module = DDP(module, process_group=dp_group, **ddp_config) super().__init__(module) + self.op_hooks = [] + if use_fp8: + self.op_hooks.append(FP8Hook()) if overlap_allgather: self.op_hook = ZeroOpHook() for p in module.parameters(): @@ -223,7 +229,11 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): wait_all_gather_handle(p) def _wait_all_gather(self): - return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + return ( + ColoParamOpHookManager.use_hooks(*self.op_hooks) + if (self.overlap_allgather or 
self.use_fp8) + else nullcontext() + ) def get_param_info(optim: Optimizer): @@ -969,6 +979,7 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. + fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn". It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default. @@ -1020,6 +1031,8 @@ class HybridParallelPlugin(PipelinePluginBase): dp_outside: bool = True, overlap_p2p: bool = True, overlap_allgather: bool = False, + fp8_communication: bool = False, + use_fp8: bool = False, inner_ring_size: int = None, ) -> None: super().__init__() @@ -1069,8 +1082,10 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism + self.use_fp8 = use_fp8 if dp_outside: self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) if sequence_parallelism_mode == "ring_attn": # Swap tp and sp since 2D Ring has better inter-node latency self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) @@ -1117,6 +1132,7 @@ class HybridParallelPlugin(PipelinePluginBase): microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, overlap_p2p=overlap_p2p, + fp8_communication=fp8_communication, ) elif pp_style == "1f1b": self.schedule = OneForwardOneBackwardSchedule( @@ -1124,6 +1140,7 @@ class HybridParallelPlugin(PipelinePluginBase): num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, + fp8_communication=fp8_communication, ) else: raise NotImplementedError() @@ -1158,6 +1175,7 @@ class HybridParallelPlugin(PipelinePluginBase): parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, + fp8_communication=fp8_communication, inner_ring_size=inner_ring_size, ) self.amp_config = dict( @@ -1250,7 +1268,7 @@ class HybridParallelPlugin(PipelinePluginBase): use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( self.dp_size == 1 and self.pp_size == 1 ) - + # sync gradients across DP * SP ranks # Apply Hybrid ZeRO across DP * SP ranks if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) @@ -1268,6 +1286,7 @@ class HybridParallelPlugin(PipelinePluginBase): ddp_config=self.ddp_config, custom_policy=self.custom_policy, overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), + use_fp8=self.use_fp8, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if zero_stage == 0: diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py 
b/colossalai/booster/plugin/low_level_zero_plugin.py index 185d34f12..42bb49bc9 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -34,6 +34,7 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.quantization.fp8_hook import FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero import LowLevelZeroOptimizer @@ -62,7 +63,12 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): def __init__( - self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True + self, + module: nn.Module, + precision: str, + overlap_allgather: bool = False, + cast_inputs: bool = True, + use_fp8: bool = False, ) -> None: super().__init__(module) self.dtype = None @@ -75,11 +81,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None + self.use_fp8 = use_fp8 if self.dtype is not None and cast_inputs: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather + self.op_hooks = [] if overlap_allgather: - self.op_hook = ZeroOpHook() + self.op_hooks.append(ZeroOpHook()) + if use_fp8: + self.op_hooks.append(FP8Hook()) + if overlap_allgather or use_fp8: for p in module.parameters(): if p.requires_grad and type(p) is not ColoParameter: p.__class__ = ColoParameter @@ -337,6 +348,8 @@ class LowLevelZeroPlugin(DPPluginBase): master_weights: bool = True, verbose: bool = False, cast_inputs: bool = True, + fp8_communication: bool = False, + use_fp8: bool = False, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -360,12 +373,14 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload=cpu_offload, master_weights=master_weights, overlap_allgather=overlap_allgather, + fp8_communication=fp8_communication, ) self.lora_enabled = False self.verbose = verbose self.logger = get_dist_logger() self.cast_inputs = cast_inputs + self.use_fp8 = use_fp8 # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -484,6 +499,7 @@ class LowLevelZeroPlugin(DPPluginBase): self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], cast_inputs=self.cast_inputs, + use_fp8=self.use_fp8, ) # TODO: Support Galore + ZeRO diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 36973b240..74d35f5c5 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -214,6 +214,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): moe_dp_outside: bool = True, overlap_p2p: bool = True, overlap_allgather: bool = False, + fp8_communication: bool = False, + use_fp8: bool = False, ) -> None: self.logger = get_dist_logger() if overlap_communication or zero_stage == 2: @@ -327,6 +329,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) else: self.sp_group = 
self.pg_mesh.get_group_along_axis(self.sp_axis) + self.use_fp8 = use_fp8 self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, @@ -345,6 +348,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): parallel_output=parallel_output, make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, + fp8_communication=fp8_communication, ) self.amp_config = dict( initial_scale=initial_scale, @@ -431,6 +435,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, + use_fp8=self.use_fp8, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.ep_size > 1: diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 8a807970c..61d785a4c 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -179,6 +179,7 @@ class TorchDDPPlugin(DPPluginBase): check_reduction: bool = False, gradient_as_bucket_view: bool = False, static_graph: bool = False, + fp8_communication: bool = False, ) -> None: super().__init__() self.ddp_kwargs = dict( @@ -189,6 +190,7 @@ class TorchDDPPlugin(DPPluginBase): gradient_as_bucket_view=gradient_as_bucket_view, static_graph=static_graph, ) + self.fp8_communication = fp8_communication def support_no_sync(self) -> bool: return True @@ -228,6 +230,11 @@ class TorchDDPPlugin(DPPluginBase): if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = OptimizerWrapper(optimizer) + if self.fp8_communication: + from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async + + model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async) + return model, optimizer, criterion, dataloader, lr_scheduler def control_checkpoint_io(self) -> bool: diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 7b67da032..23a35bbcb 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase): ignored_modules: Optional[Iterable[torch.nn.Module]] = None, param_init_fn: Optional[Callable[[nn.Module], None]] = None, sync_module_states: bool = False, + fp8_communication: bool = False, ): super().__init__() self.fsdp_kwargs = dict( @@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase): param_init_fn=param_init_fn, sync_module_states=sync_module_states, ) + self.fp8_communication = fp8_communication self.logger = get_dist_logger() else: @@ -348,6 +350,19 @@ class TorchFSDPPlugin(DPPluginBase): # wrap the model with PyTorch FSDP fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) + if self.fp8_communication: + from colossalai.quantization.utils import patch_fsdp_params_comm_hook + + patch_fsdp_params_comm_hook() + + from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook + + fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook) + + from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook + + fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook) + if optimizer is not None: if len(optimizer.param_groups) > 1: self.logger.warning( diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index ac422a4da..ba087a03b 100644 --- a/colossalai/moe/_operation.py +++ 
b/colossalai/moe/_operation.py @@ -6,6 +6,8 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup +from colossalai.quantization.fp8 import all_to_all_single_fp8 + MOE_KERNEL = None @@ -380,6 +382,7 @@ def _all_to_all( output_split_sizes: Optional[List[int]] = None, group=None, async_op: bool = False, + fp8_communication: bool = False, ): """ Returns: @@ -392,9 +395,14 @@ def _all_to_all( outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device) inputs = inputs.contiguous() outputs = outputs.contiguous() - handle = dist.all_to_all_single( - outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op - ) + if fp8_communication: + handle = all_to_all_single_fp8( + outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False + ) + else: + handle = dist.all_to_all_single( + outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op + ) return outputs, handle @@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function): output_split_sizes=None, group=None, overlap: bool = False, + fp8_communication: bool = False, ): """ Returns: @@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function): ctx.input_split_sizes = input_split_sizes ctx.output_split_sizes = output_split_sizes ctx.group = group - return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap) + return _all_to_all( + inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication + ) @staticmethod def backward(ctx: Any, *grad_outputs): @@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function): None, None, None, + None, ) @@ -435,8 +447,9 @@ def all_to_all_uneven( output_split_sizes: Optional[List[int]] = None, group=None, overlap: bool = False, + fp8_communication: bool = False, ): assert ( inputs.requires_grad ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." 
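The MoE all-to-all change above only swaps the wire format: when `fp8_communication=True`, `_all_to_all` routes through `all_to_all_single_fp8`, which quantizes the payload to FP8 with a per-tensor scale, ships raw bytes, and dequantizes on arrival (note the FP8 branch currently forces `async_op=False`). A minimal single-process sketch of that round trip, using the `cast_to_fp8`/`cast_from_fp8` helpers introduced later in this patch and assuming a ColossalAI build with this change plus PyTorch >= 2.1 for the float8 dtypes:

```python
import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

# A bf16 payload such as the routed tokens exchanged by the MoE all-to-all.
payload = torch.randn(4, 1024, dtype=torch.bfloat16)

# Quantize with a per-tensor scale; e4m3 keeps more mantissa bits, e5m2 more range.
fp8_payload, scale_inv = cast_to_fp8(payload, fp8_format="e4m3")

# On the wire the tensor travels as raw bytes plus a one-element scale tensor.
wire_bytes = fp8_payload.view(torch.uint8)

# The receiver reinterprets the bytes and rescales back to the original dtype.
received = cast_from_fp8(wire_bytes.view(torch.float8_e4m3fn), scale_inv, torch.bfloat16)

rel_err = (received.float() - payload.float()).abs().mean() / payload.float().abs().mean()
print(f"mean relative error after FP8 round trip: {rel_err:.4f}")
```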
- return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) + return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 412f3896f..28c4eb8d8 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -11,6 +11,7 @@ from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline from colossalai.utils import get_current_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device @@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule): microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, overlap_p2p: bool = True, + fp8_communication: bool = False, ) -> None: super().__init__(stage_manager) assert ( @@ -56,6 +58,8 @@ class InterleavedSchedule(PipelineSchedule): self.tensor_metadata_recv = None self.grad_metadata_recv = None + self.fp8_communication = fp8_communication + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -191,8 +195,12 @@ class InterleavedSchedule(PipelineSchedule): """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) self.send_tensor_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor) return send_handles return [] @@ -210,10 +218,14 @@ class InterleavedSchedule(PipelineSchedule): """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) send_handles = self.comm.send_backward( input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata ) self.send_grad_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad) return send_handles return [] @@ -224,6 +236,8 @@ class InterleavedSchedule(PipelineSchedule): is_send = not self.stage_manager.is_last_stage() with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): is_recv = not self.stage_manager.is_first_stage() + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) input_tensor, wait_handles = self.comm.send_forward_recv_forward( output_tensor, is_send, @@ -237,6 +251,8 @@ class InterleavedSchedule(PipelineSchedule): if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor) return input_tensor, wait_handles def send_backward_recv_backward( @@ -246,6 +262,8 @@ class InterleavedSchedule(PipelineSchedule): is_send = not self.stage_manager.is_first_stage() with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): is_recv = not self.stage_manager.is_last_stage() + if self.fp8_communication: + 
cast_to_fp8_pipeline(input_tensor_grad) output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward( input_tensor_grad, is_send, @@ -258,6 +276,8 @@ class InterleavedSchedule(PipelineSchedule): self.send_grad_metadata = not self.enable_metadata_cache and is_send if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad) return output_tensor_grad, wait_handles def forward_step( @@ -378,6 +398,8 @@ class InterleavedSchedule(PipelineSchedule): # Wait until current input is received _wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if not last_batch: @@ -440,6 +462,8 @@ class InterleavedSchedule(PipelineSchedule): # Wait for input _wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) @@ -466,6 +490,8 @@ class InterleavedSchedule(PipelineSchedule): # Wait for input. _wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) # Add input_obj and output_obj to end of list. input_objs[model_chunk_id].append(input_obj) @@ -510,6 +536,8 @@ class InterleavedSchedule(PipelineSchedule): input_obj, fwd_wait_handles = send_forward_recv_forward() # Wait for upstream grad _wait_p2p(bwd_wait_handles) + if self.fp8_communication and output_obj_grad is not None: + cast_from_fp8_pipeline(output_obj_grad) input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv) # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html) @@ -531,6 +559,8 @@ class InterleavedSchedule(PipelineSchedule): # Wait for upstream grad _wait_p2p(bwd_wait_handles) + if self.fp8_communication and output_obj_grad is not None: + cast_from_fp8_pipeline(output_obj_grad) # backward local grads input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) if not last_batch: diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 03df67ae7..67aaa5eb1 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline from colossalai.utils import get_current_device from ._utils import ( @@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, + fp8_communication: bool = False, ) -> None: """1F1B pipeline 
schedule. @@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.tensor_metadata_recv = None self.grad_metadata_recv = None + self.fp8_communication = fp8_communication + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor) return input_tensor def recv_backward(self, next_rank: int = None) -> Any: @@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): """ if not self.stage_manager.is_last_stage(): output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor_grad) if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) @@ -157,9 +165,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): next_rank (int, optional): The rank of the recipient of the tensor. """ if not self.stage_manager.is_last_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) self.send_tensor_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor, del_metadata=False) + def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For 1F1B. @@ -169,8 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): prev_rank (int, optional): The rank of the recipient of the tensor """ if not self.stage_manager.is_first_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) self.send_grad_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False) def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any: """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. 
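Both pipeline schedules wrap every p2p call in the same pattern: `cast_to_fp8_pipeline` immediately before the send and `cast_from_fp8_pipeline` immediately after (with `del_metadata=False` when the same buffer is still needed for a later send). The sketch below shows what that does to a pipeline payload on a single process, assuming a ColossalAI build with this patch; no stage manager or p2p communication is involved:

```python
import torch

from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline

# Pipeline payloads are dicts keyed by "hidden_states"; the last dim must be even
# because one fp16/bf16 element is reused to carry two FP8 bytes after the view.
inp = {"hidden_states": torch.randn(2, 8, 1024, dtype=torch.bfloat16)}
original = inp["hidden_states"].clone()

cast_to_fp8_pipeline(inp)
# The buffer is now FP8 bytes viewed as fp16/bf16 (the view dtype encodes which
# FP8 format was chosen), with the scale and original dtype stored alongside it.
print(inp["hidden_states"].shape, inp["hidden_states"].dtype)  # last dim halved
print(inp["fp8_scale"], inp["dtype"].dtype)

cast_from_fp8_pipeline(inp)  # restores shape/dtype and drops the metadata
err = (inp["hidden_states"].float() - original.float()).abs().max()
print(f"max abs error after pipeline FP8 round trip: {err:.4f}")
```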
@@ -183,6 +200,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if not self.stage_manager.is_last_stage(): if not self.send_tensor_metadata and self.grad_metadata_recv is not None: send_first = None + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) output_tensor_grad, _ = self.comm.send_forward_recv_backward( output_tensor, send_metadata=self.send_tensor_metadata, @@ -192,6 +211,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_tensor_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor, del_metadata=False) + cast_from_fp8_pipeline(output_tensor_grad) return output_tensor_grad @@ -206,6 +228,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if not self.stage_manager.is_first_stage(): if not self.send_grad_metadata and self.tensor_metadata_recv is not None: send_first = None # must not fallback + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) input_tensor, _ = self.comm.send_backward_recv_forward( input_tensor_grad, send_metadata=self.send_grad_metadata, @@ -215,6 +239,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_grad_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor) + cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False) return input_tensor diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py new file mode 100644 index 000000000..5b606616e --- /dev/null +++ b/colossalai/quantization/fp8.py @@ -0,0 +1,738 @@ +from typing import Any, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from packaging.version import Version +from torch.distributed import ReduceOp + +SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0") + + +class Handle: + def __init__(self, handles=[], remain_ops=None) -> None: + self.handles = handles + self.remain_ops = remain_ops + + def wait(self): + for handle in self.handles: + handle.wait() + if self.remain_ops: + self.remain_ops() + + +def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. + Args: + inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor. + scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling + is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. 
+ fp8_format: e4m3 or e5m2 + + Returns: + Tuples: A tuple (fp8_tensor, scale) + """ + + if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise TypeError("Only float16, bfloat16, and float32 are allowed.") + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + fp8_max = torch.finfo(fp8_type).max + + if inp.numel() == 0: + return inp.to(fp8_type), torch.tensor([1.0], device=inp.device) + else: + if per_channel_scale: + per_channel_max = inp.abs().max(dim=-1).values.float() + per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0) + scale = fp8_max / per_channel_max[:, None] + scale_inv = per_channel_max / fp8_max + else: + per_tensor_max = inp.abs().max().float() + per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) + scale = fp8_max / per_tensor_max + scale_inv = 1.0 / scale + + ret = (scale * inp.float()).to(fp8_type) + return ret, scale_inv + + +def cast_from_fp8( + inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False +) -> torch.Tensor: + r""" + Args: + inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2]. + scale: scaling factor returned by cast_to_fp8 function. + ret_type: the datatype of the returned tensor. + Returns: + torch.Tensor + """ + if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.") + + if per_channel_scale: + ret = scale_inv[:, None] * inp.float() + else: + ret = scale_inv * inp.float() + return ret.to(ret_type) + + +def all_reduce_fp8( + tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False +) -> Optional[Handle]: + r""" + This is an in-place operation for compressed all_reduce using fp8. + It works like dist.all_reduce but during communication the data is cast to fp8 format. + + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. 
+ fp8_format: e4m3 or e5m2 + op: ReduceOp.SUM or ReduceOp.AVG + + Returns: + None + """ + + world_size = dist.get_world_size(group=group) + input_type = tensor.dtype + input_shape = tensor.shape + input_device = tensor.device + input_size = tensor.numel() + flat_padded_x = tensor.flatten() + + assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG" + + if flat_padded_x.size(0) % world_size != 0: + pad_size = world_size - flat_padded_x.size(0) % world_size + flat_padded_x = F.pad(flat_padded_x, (0, pad_size)) + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format) + + inp = ret.view(torch.uint8) + input_chunks = list(torch.chunk(inp, world_size, dim=0)) + output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0)) + dist.all_to_all(output_chunks, input_chunks, group=group) + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + dist.all_gather(scale_list, scale, group=group) + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + + for scale, out in zip(scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + + if op == ReduceOp.AVG: + summed_out.div_(world_size) + + summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) + gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op) + + tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)] + gather_tensor_handle = dist.all_gather( + tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op + ) + + def cat_op(): + for i in range(world_size): + tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] + out = torch.cat(tensor_list, dim=0) + tensor.copy_(out[:input_size].view(input_shape).to(input_type)) + + if async_op: + return Handle([gather_scale_handle, gather_tensor_handle], cat_op) + else: + cat_op() + + +def all_to_all_single_fp8( + output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False +) -> Optional[Handle]: + r""" + This is an in-place operation for compressed all_reduce using fp8. + It works like dist.all_to_all_single but during communication the data is cast to fp8 format. + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. 
+ fp8_format: e4m3 or e5m2 + Returns: + None + """ + world_size = dist.get_world_size(group=group) + input_type = input.dtype + input_shape = input.shape + input_device = input.device + input = input.flatten() + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + ret, scale = cast_to_fp8(input, fp8_format=fp8_format) + + inp = ret.view(torch.uint8) + if input_split_sizes is not None: + input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)] + input_chunks = list(torch.split(inp, input_split_sizes)) + else: + input_chunks = list(torch.chunk(inp, world_size, dim=0)) + + if output_split_sizes is not None: + output_chunks = [ + torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype) + for i in range(world_size) + ] + else: + if dist.get_rank() == world_size - 1: + output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] + else: + output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] + + chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op) + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op) + + def cast_op(): + cast_output_chunk = [ + cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks) + ] + + tensor_out = torch.cat(cast_output_chunk, dim=0) + outputs_shape = list(input_shape) + if output_split_sizes is not None: + outputs_shape[0] = sum(output_split_sizes) + else: + outputs_shape = input_shape + output.data = tensor_out.view(outputs_shape).to(input_type) + + if async_op: + return Handle([chunk_handle, scale_hanle], cast_op) + else: + cast_op() + + +def cast_to_fp8_pipeline(inp: Any) -> None: + """ + Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline. + The activations tensor is indexed by 'hidden_states' in the inp dict. + After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but the size becomes halved. + Metadata such as fp8_scale is saved into inp dict for communication. + """ + if inp is None: + return + # In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted. + if type(inp) == torch.Tensor: + return + + assert "hidden_states" in inp, "required by pipeline parallelism." + assert ( + inp["hidden_states"].size(-1) % 2 == 0 + ), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16" + inp_tensor = inp["hidden_states"] + inp_dtype = inp_tensor.dtype + + min_val, max_val = inp_tensor.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()) + + finfo = torch.finfo(torch.float8_e4m3fn) + if amax > finfo.max: + fp8_type = torch.float8_e5m2 + fp8_view_type = torch.float16 + else: + fp8_type = torch.float8_e4m3fn + fp8_view_type = torch.bfloat16 + + finfo = torch.finfo(fp8_type) + scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float() + q_tensor = inp_tensor.data.float() * scale + # Todo: Currently we use fp8_view_type to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'. + # inp_tensor needs to be a float datatype to avoid error during gradient placement. 
+ inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type) + + inp["fp8_scale"] = scale.float().reciprocal() + inp["dtype"] = torch.zeros_like(scale).to(inp_dtype) + + +def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: + """ + Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline. + del_metadata = False is useful when this function is called before p2p communication. + """ + if inp is None: + return + if type(inp) == torch.Tensor: + return + + assert "hidden_states" in inp, "required by pipeline parallelism." + inp_tensor = inp["hidden_states"] + scale = inp["fp8_scale"] + + fp8_view_type = inp_tensor.dtype + if fp8_view_type == torch.float16: + fp8_type = torch.float8_e5m2 + elif fp8_view_type == torch.bfloat16: + fp8_type = torch.float8_e4m3fn + else: + raise TypeError("Only float16, bfloat16 are implemented.") + + inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale + + if del_metadata: + del inp["fp8_scale"] + del inp["dtype"] + + +def reduce_scatter_fp8( + output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False +) -> Optional[Handle]: + r""" + This is an in-place operation for compressed reduce_scatter using fp8. + It works like dist.reduce_scatter but during communication the data is cast to fp8 format. + + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. + fp8_format: e4m3 or e5m2 + + Returns: + None + """ + + input_type = output.dtype + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + scale_list = [] + cast_input_list = [] + output_chunks = [] + output_scale_list = [] + for input in input_list: + ret, scale = cast_to_fp8(input, fp8_format=fp8_format) + scale_list.append(scale) + ret = ret.view(torch.uint8) + cast_input_list.append(ret) + output_chunks.append(torch.empty_like(ret)) + output_scale_list.append(torch.empty_like(scale)) + chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op) + scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op) + + def cast_op(): + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + for scale, out in zip(output_scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + output.data = summed_out + + if async_op: + return Handle([chunk_handle, scale_handle], cast_op) + else: + cast_op() + + +def fp8_compress_ddp_grad_comm_hook_async( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, + fp8_format: str = "e5m2", +) -> torch.futures.Future[torch.Tensor]: + """ + Compress by casting ``GradBucket`` to FP8 floating-point format divided by process group size. + + This DDP communication hook implements a simple gradient compression approach that casts ``GradBucket`` tensor + to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.bfloat16_e4m3``), and then divides it + by the process group size. + Once compressed gradient tensors are allreduced, the chained callback ``decompress`` casts it back + to the input data type (such as ``float32``). 
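At the plugin level, this hook is what `TorchDDPPlugin(fp8_communication=True)` registers on the wrapped model (see the torch_ddp_plugin.py hunk above). A minimal end-to-end sketch of that path, assuming a `torchrun` launch with NCCL and a recent ColossalAI whose `launch_from_torch` takes no config argument; the model, loss, and data are placeholders:

```python
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()  # reads rank/world size from torchrun env vars

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# fp8_communication=True makes the plugin register
# fp8_compress_ddp_grad_comm_hook_async on the wrapped DDP model.
plugin = TorchDDPPlugin(fp8_communication=True)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

x = torch.randn(8, 1024, device="cuda")
loss = criterion(model(x), torch.zeros(8, 1024, device="cuda"))
booster.backward(loss, optimizer)  # gradients are all-reduced in FP8 (e5m2)
optimizer.step()
```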
+ + Example:: + >>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + + input_tensor = bucket.buffer() + world_size = dist.get_world_size() + input_type = input_tensor.dtype + input_device = input_tensor.device + flat_padded_x = input_tensor.flatten() + + if flat_padded_x.size(0) % world_size != 0: + pad_size = world_size - flat_padded_x.size(0) % world_size + flat_padded_x = F.pad(flat_padded_x, (0, pad_size)) + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format) + + inp = ret.view(torch.uint8) + output_chunks_single = torch.empty_like(inp) + split_sizes = [inp.numel() // world_size for _ in range(world_size)] + fut0 = dist.all_to_all_single( + output_chunks_single, + inp, + output_split_sizes=split_sizes, + input_split_sizes=split_sizes, + group=group_to_use, + async_op=True, + ).get_future() + + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + fut1 = dist.all_gather_into_tensor( + torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True + ).get_future() + all_to_all_fut = torch.futures.collect_all([fut0, fut1]) + + def sum_and_allgather(fut): + output_chunks_single = fut.value()[0].wait()[0] + scale_list_single = fut.value()[1].wait()[0] + + output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0)) + scale_list = scale_list_single.chunk(world_size, dim=0) + + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + for scale, out in zip(scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + summed_out.div_(world_size) + + summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) + + tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8) + fut2 = dist.all_gather_into_tensor( + tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True + ).get_future() + + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + fut3 = dist.all_gather_into_tensor( + torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True + ).get_future() + fut_combined2 = torch.futures.collect_all([fut2, fut3]) + return fut_combined2 + + def decompress(fut): + tensor_list_single = fut.value().wait()[0].value()[0] + scale_list_single = fut.value().wait()[1].value()[0] + + tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0)) + scale_list = scale_list_single.chunk(world_size, dim=0) + + for i in range(world_size): + tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] + out = torch.cat(tensor_list, dim=0) + + input_tensor_size = input_tensor.numel() + input_shape = input_tensor.shape + out = out[:input_tensor_size] + + input_tensor.copy_(out.view(input_shape).to(input_type)) + return input_tensor + + return all_to_all_fut.then(sum_and_allgather).then(decompress) + + +def fp8_compress_ddp_grad_comm_hook_sync( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, + fp8_format="e5m2", +) -> torch.futures.Future[torch.Tensor]: + """ + Return a future that wraps the input, after the input is allreduced. However, the allreduce commnunication is synchronized. + This breaks the overlapping between allreduce communication and backward compuation. 
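The `fp8_format` defaults in this file are deliberate: the gradient hooks use e5m2, while the FP8 linear kernel later in the patch quantizes activations and weights to e4m3. The difference is the exponent/mantissa split: e4m3 has finer steps but a max of 448, e5m2 has coarser steps but reaches 57344, which better suits the wide dynamic range of gradients. A short check that only assumes PyTorch with float8 dtypes (>= 2.1):

```python
import torch

for fmt in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(fmt)
    print(f"{fmt}: max={info.max}, eps={info.eps}")
# float8_e4m3fn: max=448.0,   eps=0.125 -> finer steps, narrower range
# float8_e5m2:   max=57344.0, eps=0.25  -> coarser steps, wider range

x = torch.tensor(300.0)
print(x.to(torch.float8_e4m3fn).float().item())  # 288.0, abs error 12
print(x.to(torch.float8_e5m2).float().item())    # 320.0, abs error 20
```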
+ + This hook should **only** be used for debugging purposes, instead of the normal gradient synchronization. + For asynchronized implementation, use fp8_compress_ddp_grad_comm_hook_async instead. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync) + """ + + buffer = bucket.buffer() + all_reduce_fp8(buffer, fp8_format=fp8_format) + + fut: torch.futures.Future[torch.Tensor] = torch.futures.Future() + fut.set_result(bucket.buffer()) + + return fut + + +def fp8_compress_fsdp_grad_comm_hook( + state: object, + unsharded_gradient_flattened: torch.Tensor, + sharded_gradient: torch.Tensor, + group=None, + fp8_format="e5m2", +) -> None: + """ + This communication hook implements a simple gradient compression approach that casts unsharded_gradient_flattened tensor + to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.bfloat16_e4m3``), and then perform scatter_allreduce logic + by using all_to_all and all_gather among the process group. + + Example:: + >>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook) + """ + grad = unsharded_gradient_flattened + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + input_type = grad.dtype + input_device = grad.device + world_size = dist.get_world_size(group=group) + + grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format) + uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8) + dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group) + + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + dist.all_gather(scale_list, scale, group=group) + + buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0)) + sharded_gradient.zero_() + for tensor, scale in zip(buffer_list, scale_list): + sharded_gradient += cast_from_fp8(tensor, scale, input_type) + + +def fp8_compress_fsdp_params_comm_hook( + state: object, + padded_unsharded_flat_param: torch.Tensor, + sharded_flat_param: torch.Tensor, + group=None, + fp8_format="e5m2", +) -> None: + """ + This hook is pending the official support for parameters communication hook in FSDP, e.g. register_params_comm_hook. 
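The two FSDP hooks are wired up by `TorchFSDPPlugin(fp8_communication=True)` in the torch_fsdp_plugin.py hunk above: patch FSDP once, then register the params hook and the grad hook. A sketch of the same sequence on a raw FSDP module, assuming a `torchrun` launch with NCCL and PyTorch >= 2.2 (the patch refuses older versions); the model is a placeholder:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from colossalai.quantization.fp8 import (
    fp8_compress_fsdp_grad_comm_hook,
    fp8_compress_fsdp_params_comm_hook,
)
from colossalai.quantization.utils import patch_fsdp_params_comm_hook

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Patch FSDP once so FlatParamHandle/FSDP gain register_params_comm_hook support.
patch_fsdp_params_comm_hook()

model = FSDP(torch.nn.Linear(1024, 1024).cuda())

# FP8 all-gather of parameters during unshard, FP8 reduce-scatter of gradients.
model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
```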
+ + Example:: + >>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook) + """ + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + fp8_max = torch.finfo(fp8_type).max + inp = sharded_flat_param + out = padded_unsharded_flat_param + + per_tensor_max = inp.abs().max().float() + per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) + dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group) + + scale = fp8_max / per_tensor_max + fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8) + + fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device) + dist.all_gather_into_tensor( + fp8_out, + fp8_sharded_flat_param, + group=group, + ) + padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype)) + + +def split_chunk_by_channel( + chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1 +): + offset = chunk.numel() * rank + end = offset + chunk.numel() + break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end] + if len(break_points) == 0 or break_points[0] > offset: + break_points.insert(0, offset) + if break_points[-1] < end: + break_points.append(end) + sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])] + return chunk.split(sizes) + + +def all_gather_into_tensor_flat_fp8( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + output_shape: torch.Size, + group: dist.ProcessGroup, + fp8_format: str = "e4m3", + async_op: bool = False, +) -> Optional[Handle]: + """all gather into tensor in fp8 format + + Args: + output_tensor (torch.Tensor): output tensor, which is flattened + input_tensor (torch.Tensor): input tensor, which is flattened + group (dist.ProcessGroup): process group + fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3". 
+ """ + assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened" + world_size = dist.get_world_size(group) + assert ( + output_tensor.numel() == input_tensor.numel() * world_size + ), "output tensor size should be world_size times of input tensor size" + + input_type = output_tensor.dtype + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + fp8_max = torch.finfo(fp8_type).max + + if len(output_shape) == 2: + per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float) + num_channels, channel_size = output_shape + rank = dist.get_rank(group) + channel_start_idx = (input_tensor.numel() * rank) // channel_size + per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size) + for i, per_channel_split in enumerate(per_channel_splits): + idx = i + channel_start_idx + if idx < num_channels: + per_channel_max[idx] = per_channel_split.abs().max().float() + dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group) + per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0) + scale = fp8_max / per_channel_max + fp8_input = input_tensor.float() + fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size) + for i, per_channel_split in enumerate(fp8_per_channel_splits): + idx = i + channel_start_idx + if idx < num_channels: + per_channel_split.mul_(scale[idx]) + fp8_input = fp8_input.to(fp8_type) + else: + per_tensor_max = input_tensor.abs().max().float() + dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group) + per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) + scale = fp8_max / per_tensor_max + fp8_input = (scale * input_tensor.float()).to(fp8_type) + scale_inv = 1.0 / scale + + buffer = torch.empty_like(output_tensor, dtype=fp8_type) + tensor_handle = dist.all_gather_into_tensor( + buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op + ) + + def cast_op(): + numel = output_shape.numel() + valid_buffer = buffer[:numel].reshape(output_shape) + valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2)) + output_tensor[:numel].copy_(valid_buffer.view(-1)) + + if async_op: + return Handle([tensor_handle], cast_op) + else: + cast_op() + + +def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False): + + world_size = dist.get_world_size(group) + + input_type = input_list[0].dtype + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + scale_list = [] + tensor_list = [] + + for i in range(world_size): + input_tensor = input_list[i] + ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format) + scale_list.append(scale) + ret = ret.view(torch.uint8) + tensor_list.append(ret) + + output_scale_list = [torch.empty_like(x) for x in scale_list] + output_tensor_list = [torch.empty_like(x) for x in tensor_list] + tensor_hanle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op) + scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op) + + def cast_op(): + for i in range(world_size): + scale = output_scale_list[i] + tensor = output_tensor_list[i] + tensor = tensor.view(fp8_type) + output_list[i].copy_(cast_from_fp8(tensor, scale, input_type)) + + if async_op: + return Handle([tensor_hanle, scale_handle], cast_op) + else: + cast_op() + + +def 
gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]: + + world_size = dist.get_world_size(group) + + input_type = input_.dtype + ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) + fp8_type = ret.dtype + input_ = ret.view(torch.uint8) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)] + chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op) + scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op) + + def cast_op(): + for i in range(world_size): + output = tensor_list[i].view(fp8_type) + scale = scale_list[i] + output_list[i].copy_(cast_from_fp8(output, scale, input_type)) + + if async_op: + return Handle([chunk_handle, scale_hanle], cast_op) + else: + cast_op() + + +class _LinearFp8(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + x: torch.Tensor, + w: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> Any: + assert ( + x.dtype in (torch.bfloat16, torch.float16) and x.dtype == w.dtype + ), "Only float16 and bfloat16 are allowed." + if bias is not None: + assert bias.dtype == x.dtype, "Bias should have the same dtype as input." + # ensure x and w are row-major + x = x.contiguous() + w = w.contiguous() + ctx.x_shape = x.shape + ctx.has_bias = bias is not None + ctx.out_dtype = x.dtype + x = x.reshape(-1, x.shape[-1]) + + x_fp8, inv_scale_x = cast_to_fp8(x, fp8_format="e4m3") + w_fp8, inv_scale_w = cast_to_fp8(w, fp8_format="e4m3") + ctx.x_fp8 = x_fp8 + ctx.w_fp8_t = w_fp8.t() + ctx.inv_scale_x = inv_scale_x + ctx.inv_scale_w = inv_scale_w + out = torch._scaled_mm( + x_fp8, + ctx.w_fp8_t, + bias=bias, + out_dtype=ctx.out_dtype, + scale_a=inv_scale_x, + scale_b=inv_scale_w, + use_fast_accum=True, + )[0] + return out.reshape(*ctx.x_shape[:-1], w.shape[0]) + + @staticmethod + def backward(ctx: Any, out_grad) -> Any: + out_grad = out_grad.reshape(-1, out_grad.shape[-1]) + out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2") + x_grad = torch._scaled_mm( + out_grad_fp8, + ctx.w_fp8_t.contiguous().t(), + out_dtype=ctx.out_dtype, + scale_a=out_grad_scale, + scale_b=ctx.inv_scale_w, + use_fast_accum=True, + )[0] + w_grad = torch._scaled_mm( + out_grad_fp8.t().contiguous(), + ctx.x_fp8.t().contiguous().t(), + out_dtype=ctx.out_dtype, + scale_a=out_grad_scale, + scale_b=ctx.inv_scale_x, + use_fast_accum=True, + )[0] + bias_grad = None + if ctx.has_bias: + bias_grad = out_grad.sum(0) + return x_grad.reshape(ctx.x_shape), w_grad, bias_grad + + +@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False) +def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _LinearFp8.apply(input, weight, bias) + + +def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + out = _linear_fp8(input, weight, bias) + return out diff --git a/colossalai/quantization/fp8_hook.py b/colossalai/quantization/fp8_hook.py new file mode 100644 index 000000000..6171dd755 --- /dev/null +++ b/colossalai/quantization/fp8_hook.py @@ -0,0 +1,23 @@ +import torch.nn.functional as F + +from colossalai.quantization.fp8 import linear_fp8 +from colossalai.tensor.param_op_hook import ColoParamOpHook + + +class FP8Hook(ColoParamOpHook): + def pre_forward(self, params) -> None: + pass + + def post_forward(self, 
params) -> None: + pass + + def pre_backward(self, params) -> None: + pass + + def post_backward(self, params) -> None: + pass + + def rewrite_op(self, func): + if func is F.linear: + return linear_fp8 + return func diff --git a/colossalai/quantization/utils.py b/colossalai/quantization/utils.py new file mode 100644 index 000000000..5b1e11c9f --- /dev/null +++ b/colossalai/quantization/utils.py @@ -0,0 +1,112 @@ +import torch +import torch.distributed as dist +from packaging import version +from torch import Tensor +from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream +from torch.distributed.utils import _p_assert + + +def _all_gather_flat_param( + self, + padded_unsharded_flat_param: Tensor, +) -> Tensor: + """ + All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``. + + Then switch to use the all-gathered tensor. + """ + _p_assert( + hasattr(self, "process_group") and hasattr(self, "world_size"), + "Expects a process group and world size to have been set via `shard()`", + ) + sharded_flat_param = self.flat_param.data + expected_numel = sharded_flat_param.numel() * self.world_size + _p_assert( + padded_unsharded_flat_param.numel() == expected_numel, + f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}", + ) + + pg = self._fake_process_group if self._use_fake_all_gather else self.process_group + + # HACK this should be handled by C10D + if sharded_flat_param.is_cpu: # type: ignore[attr-defined] + tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))) + work = dist.all_gather(tensor_list, sharded_flat_param, group=pg) + else: + if self._comm_hook is None: + dist.all_gather_into_tensor( + padded_unsharded_flat_param, + sharded_flat_param, + pg, + ) + else: + self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg) + + if self._offload_params: + # In case of offloading, `flat_param.data` (i.e. sharded param) is + # created on the pre-unshard stream. We need to hand it over to the + # unshard stream for all-gather + _no_dispatch_record_stream( + sharded_flat_param, + self._device_handle.current_stream(), # unshard_stream + ) + return padded_unsharded_flat_param + + +def register_params_comm_hook(self, state: object, hook: callable): + """Register a communication hook for FlatParamHandle. + + This is an enhancement that provides a flexible hook to users where they can specify how FSDP unshards + parameters across multiple workers. + + .. warning :: + FSDP communication hook should be registered before running an initial forward pass + and only once. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + hook (Callable): Callable, which has one of the following signatures: + 1) ``hook: Callable[torch.Tensor] -> None``: + This function takes in a Python tensor, which represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). + It then performs all necessary processing and returns ``None``; + 2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``: + This function takes in two Python tensors, the first one represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). 
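The patched `_all_gather_flat_param` above invokes the registered hook as `hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)`, so any callable with that shape can take over the parameter all-gather; `fp8_compress_fsdp_params_comm_hook` is the FP8 variant of exactly this contract. A behavior-preserving sketch (the hook name is illustrative) that just reproduces the stock full-precision all-gather:

```python
import torch
import torch.distributed as dist


def passthrough_fsdp_params_comm_hook(
    state: object,
    padded_unsharded_flat_param: torch.Tensor,
    sharded_flat_param: torch.Tensor,
    group: dist.ProcessGroup = None,
) -> None:
    # Same signature the patched _all_gather_flat_param calls into; this version
    # simply performs the default full-precision all-gather.
    dist.all_gather_into_tensor(padded_unsharded_flat_param, sharded_flat_param, group=group)


# After patch_fsdp_params_comm_hook(), registration looks like:
# fsdp_model.register_params_comm_hook(None, passthrough_fsdp_params_comm_hook)
```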
The latter + represents a pre-sized tensor to store a chunk of a sharded gradient after + reduction. + In both cases, callable performs all necessary processing and returns ``None``. + Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case. + Callables with signature 2 are expected to handle gradient communication for sharded cases. + + """ + if not self.check_is_root(): + raise AssertionError("register_comm_hook can only be called on a root instance.") + + # if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + # raise AssertionError( + # f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}" + # ) + if self._handle._comm_hook is not None: + raise AssertionError("A communication hook is already registered") + if not callable(hook): + raise ValueError(f"The communication hook must be callable but got {hook}") + self._handle._comm_hook = hook + self._handle._comm_hook_state = state + + +def patch_fsdp_params_comm_hook(): + if version.parse(torch.__version__) >= version.parse("2.2.0"): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp._flat_param import FlatParamHandle + + FlatParamHandle._comm_hook = None + FlatParamHandle._comm_hook_state = None + FlatParamHandle._all_gather_flat_param = _all_gather_flat_param + FSDP.register_params_comm_hook = register_params_comm_hook + else: + raise RuntimeError("This fsdp_params_comm_hook patch is not supported while torch version under 2.2.0.") diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 25983e0a9..f58a21df3 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -16,6 +16,14 @@ try: except ImportError: _grad_accum_fusion_available = False +from colossalai.quantization.fp8 import ( + all_reduce_fp8, + all_to_all_fp8, + all_to_all_single_fp8, + gather_fp8, + reduce_scatter_fp8, +) + class FusedLayerNormAffineFunction1D(torch.autograd.Function): r"""Layernorm @@ -61,11 +69,12 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce + ctx.fp8_communication = fp8_communication output = torch.matmul(input_, weight) @@ -78,6 +87,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. 
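Taken together, patch_fsdp_params_comm_hook exposes a per-parameter all-gather hook on FSDP for torch >= 2.2. A hedged sketch of wiring it up; the hook signature mirrors the self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg) call site in _all_gather_flat_param above, and the model and hook names are placeholders rather than part of this diff:

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from colossalai.quantization.utils import patch_fsdp_params_comm_hook

    patch_fsdp_params_comm_hook()  # monkey-patches FlatParamHandle and FSDP

    def allgather_params_hook(state, unsharded_out, sharded_in, group):
        # Placeholder hook: a plain all-gather; an FP8 hook would compress sharded_in first.
        dist.all_gather_into_tensor(unsharded_out, sharded_in, group=group)

    # Assumes torch.distributed is already initialized with one GPU per rank.
    fsdp_model = FSDP(torch.nn.Linear(1024, 1024).cuda())
    fsdp_model.register_params_comm_hook(None, allgather_params_hook)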
weight = weight.view(weight.shape) @@ -92,7 +102,9 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) - if ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and fp8_communication: + _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2") + elif ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have @@ -101,10 +113,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and not fp8_communication: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None class LinearWithAsyncCommunication(torch.autograd.Function): @@ -113,11 +125,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce + ctx.fp8_communication = fp8_communication if bias is not None: output = F.linear(input_, weight, bias) else: @@ -129,6 +142,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. 
if use_bias: @@ -144,10 +158,11 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce - handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) - _ = torch.zeros(1, device=grad_input.device) - - # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + if fp8_communication: + all_reduce_fp8(grad_input, group=ctx.process_group) + else: + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py if _grad_accum_fusion_available and weight.grad is not None: @@ -165,10 +180,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function): grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and not fp8_communication: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): @@ -236,17 +251,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, dim): + def forward(ctx, input_, process_group, dim, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim + ctx.fp8_communication = fp8_communication - return _gather(input_, dim, process_group) + return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3") @staticmethod def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group - + fp8_communication = ctx.fp8_communication # do reduce-scatter new_shape = list(grad_output.shape) assert ( @@ -257,9 +273,13 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) ] output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) - dist.reduce_scatter(output, grad_list, group=process_group) - return output, None, None + if fp8_communication: + reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2") + else: + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None, None class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -550,9 +570,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, dim): + def forward(ctx, input_, process_group, dim, fp8_communication=False): ctx.dim = dim ctx.process_group = process_group + ctx.fp8_communication = fp8_communication # do reduce-scatter new_shape = list(input_.shape) @@ -562,7 +583,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function): new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) - dist.reduce_scatter(output, input_list, group=process_group) + if fp8_communication: + reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3") + else: + dist.reduce_scatter(output, input_list, group=process_group) return output @@ -570,8 +594,9 @@ class 
_ReduceScatterForwardGatherBackward(torch.autograd.Function): def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group + fp8_communication = ctx.fp8_communication - return _gather(grad_output, dim, process_group), None, None + return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -586,13 +611,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring): + def forward( + ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication + ): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim ctx.overlap = overlap + ctx.fp8_communication = fp8_communication if ring is True: input_to_gather = {} @@ -609,7 +637,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): ) else: - input_parallel = _gather(input_, dim, process_group) + input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3") output = torch.matmul(input_parallel, weight) @@ -624,6 +652,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm weight = weight.view(weight.shape) @@ -631,7 +660,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): bias = bias.view(bias.shape) if not overlap: - input_parallel = _gather(input_, dim, process_group) + input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2") total_input = input_parallel grad_input = grad_output.matmul(weight.T) @@ -691,7 +720,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -706,17 +735,25 @@ class _SplitForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group, grad_scale=None): + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim ctx.grad_scale = grad_scale + ctx.fp8_communication = fp8_communication return _split(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: grad_output = grad_output * ctx.grad_scale - return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None + + return ( + _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"), + None, + None, + None, + None, + ) class _ReduceForward(torch.autograd.Function): @@ -730,15 +767,15 @@ class _ReduceForward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, grad_scale=None): + def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False): ctx.grad_scale = grad_scale - return 
_reduce(input_, process_group) + return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3") @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: grad_output = grad_output * ctx.grad_scale - return grad_output, None, None + return grad_output, None, None, None class _ReduceBackward(torch.autograd.Function): @@ -751,13 +788,15 @@ class _ReduceBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group): + def forward(ctx, input_, process_group, fp8_communication=False): ctx.process_group = process_group + ctx.fp8_communication = fp8_communication return input_ @staticmethod def backward(ctx, grad_output): - return _reduce(grad_output, ctx.process_group), None + fp8_communication = ctx.fp8_communication + return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None class _GatherForwardSplitBackward(torch.autograd.Function): @@ -770,17 +809,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group, grad_scale=None): + def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim ctx.grad_scale = grad_scale - return _gather(input_, dim, process_group) + + return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3") @staticmethod def backward(ctx, grad_output): if ctx.grad_scale is not None: grad_output = grad_output * ctx.grad_scale - return _split(grad_output, ctx.dim, ctx.process_group), None, None, None + return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None class _AllToAll(torch.autograd.Function): @@ -794,26 +834,67 @@ class _AllToAll(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, scatter_dim, gather_dim): + def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False): ctx.process_group = process_group ctx.scatter_dim = scatter_dim ctx.gather_dim = gather_dim + ctx.fp8_communication = fp8_communication world_size = dist.get_world_size(process_group) bsz, _, _ = input_.shape # using all_to_all_single when batch size is 1 if bsz == 1: - return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim) + return _all_to_all_single( + input_, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e4m3", + ) else: - return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) + return _all_to_all( + input_, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e4m3", + ) @staticmethod - def backward(ctx, *grad_output): + def backward(ctx, grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim - return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) - return (return_grad, None, None, None) + fp8_communication = ctx.fp8_communication + world_size = dist.get_world_size(process_group) + bsz, _, _ = grad_output.shape + + if bsz == 1: + return_grad = _all_to_all_single( + grad_output, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", + ) + else: + return_grad = _all_to_all( + grad_output, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", + ) + + return 
(return_grad, None, None, None, None) class HookParameter(torch.autograd.Function): @@ -839,12 +920,15 @@ def hook_parameter_in_backward(input, weight=None, bias=None): return HookParameter.apply(input, weight, bias) -def _reduce(input_, process_group): +def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"): # skip if only one rank involved if dist.get_world_size(process_group) == 1: return input_ else: - dist.all_reduce(input_, group=process_group) + if fp8_communication: + all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format) + else: + dist.all_reduce(input_, group=process_group) return input_ @@ -868,18 +952,19 @@ def _split(input_, dim=-1, process_group=None): return output -def _gather(input_, dim=-1, process_group=None): +def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"): # skip if only one rank involved world_size = dist.get_world_size(process_group) if world_size == 1: return input_ - # all gather input_ = input_.contiguous() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - torch.distributed.all_gather(tensor_list, input_, group=process_group) + if fp8_communication: + gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group) + else: + dist.all_gather(tensor_list, input_, group=process_group) - # concat output = torch.cat(tensor_list, dim=dim).contiguous() return output @@ -909,14 +994,19 @@ def _reduce_scatter(input_, dim=1, process_group=None): return output -def _all_to_all(input_, world_size, group, scatter_dim, gather_dim): +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"): input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] - dist.all_to_all(output_list, input_list, group=group) + if fp8_communication: + all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format) + else: + dist.all_to_all(output_list, input_list, group=group) return torch.cat(output_list, dim=gather_dim).contiguous() -def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): +def _all_to_all_single( + input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2" +): inp_shape = list(input_.shape) inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size if scatter_dim < 2: @@ -929,7 +1019,11 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): ) output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) + if fp8_communication: + all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format) + else: + + dist.all_to_all_single(output, input_t, group=group) if scatter_dim < 2: output = output.transpose(0, 1).contiguous() @@ -943,12 +1037,16 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): ).contiguous() -def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): - return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + return MatmulWithAsyncCommunication.apply( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + ) -def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): - return 
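The FP8 collectives used above (all_reduce_fp8, gather_fp8 and friends from colossalai.quantization.fp8) behave as compressed counterparts of the matching torch.distributed calls. A sketch, assuming a default process group is already initialized (for example via colossalai.launch):

    import torch
    import torch.distributed as dist
    from colossalai.quantization.fp8 import all_reduce_fp8, gather_fp8

    group = dist.group.WORLD
    t = torch.randn(1024, dtype=torch.bfloat16, device="cuda")

    # In-place FP8-compressed all-reduce; e5m2 is the format the backward paths above use.
    all_reduce_fp8(t, group=group, fp8_format="e5m2")

    # FP8-compressed all-gather into pre-allocated full-precision buffers.
    buffers = [torch.empty_like(t) for _ in range(dist.get_world_size(group))]
    gather_fp8(buffers, t, group=group, fp8_format="e4m3")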
LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + return LinearWithAsyncCommunication.apply( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + ) def linear_gather_forward_reducescatter_backward( @@ -959,12 +1057,12 @@ def linear_gather_forward_reducescatter_backward( ) -def gather_forward_reducescatter_backward(input_, process_group, dim): - return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) +def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication) -def reducescatter_forward_gather_backward(input_, process_group, dim): - return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) +def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False): + return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication) def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): @@ -972,38 +1070,40 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication ) -def gather_forward_split_backward(input_, dim, process_group, grad_scale=None): - return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale) +def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False): + return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) -def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): - return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) +def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False): + return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) -def reduce_forward(input_, process_group, grad_scale=None): - return _ReduceForward.apply(input_, process_group, grad_scale) +def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False): + return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication) -def reduce_backward(input_, process_group): - return _ReduceBackward.apply(input_, process_group) +def reduce_backward(input_, process_group, fp8_communication=False): + return _ReduceBackward.apply(input_, process_group, fp8_communication) -def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): - return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication) -def gather_sp_output(hidden_states, 
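Each of the wrappers above now accepts an optional fp8_communication flag that compresses forward collectives to e4m3 and backward collectives to e5m2. A usage sketch for the sequence-parallel helpers; the shapes and the group are illustrative only:

    import torch
    import torch.distributed as dist
    from colossalai.shardformer.layer._operation import (
        all_to_all_comm,
        gather_forward_split_backward,
        split_forward_gather_backward,
    )

    group = dist.group.WORLD  # assumed sequence-parallel group
    x = torch.randn(2, 8, 16, device="cuda", requires_grad=True)

    # Split the sequence dim now; the backward all-gather runs in FP8 (e5m2).
    x_local = split_forward_gather_backward(x, 1, group, fp8_communication=True)

    # Gather the sequence dim now (e4m3); the backward split needs no communication.
    x_full = gather_forward_split_backward(x_local, 1, group, fp8_communication=True)

    # Scatter heads / gather sequence with FP8 all-to-all in both directions.
    y = all_to_all_comm(x_local, group, scatter_dim=2, gather_dim=1, fp8_communication=True)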
sp_group, sp_mode): +def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False): """ Gather the output of the last layer for cross entropy computation """ # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group) scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group) - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication + ) return hidden_states diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 9b77774aa..18efb0ec5 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -68,6 +68,7 @@ class Embedding1D(ParallelModule): gather_output: bool = True, weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), + fp8_communication: bool = False, *args, **kwargs, ): @@ -81,6 +82,7 @@ class Embedding1D(ParallelModule): self.embed_args = args self.embed_kwargs = kwargs self.gather_output = gather_output + self.fp8_communication = fp8_communication # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -155,7 +157,9 @@ class Embedding1D(ParallelModule): def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) if self.gather_output: - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + output = gather_forward_split_backward( + output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) return output else: return output_parallel @@ -274,6 +278,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule): weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), make_vocab_size_divisible_by: int = 64, + fp8_communication: bool = False, *args, **kwargs, ): @@ -282,6 +287,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule): self.embed_args = args self.embed_kwargs = kwargs self.process_group = process_group + self.fp8_communication = fp8_communication tensor_parallel_size = dist.get_world_size(group=process_group) tensor_parallel_rank = dist.get_rank(group=process_group) @@ -390,5 +396,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule): embedding_output = output_parallel.clone() embedding_output[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. 
- output = reduce_forward(embedding_output, self.process_group) + output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication) return output diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 020e793af..d77dd4965 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -84,6 +84,7 @@ class Linear1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, **kwargs, ): super().__init__(weight=weight, bias_=bias_, **kwargs) @@ -98,6 +99,7 @@ class Linear1D_Col(ParallelModule): self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -202,19 +204,25 @@ class Linear1D_Col(ParallelModule): if self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( - input_parallel, self.process_group, self.seq_parallel_dim + input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication + ) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication ) - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) elif self.seq_parallel_mode == "ring": output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) else: - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + ) if self.gather_output: # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + output = gather_forward_split_backward( + output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) else: output = output_parallel @@ -264,6 +272,7 @@ class Linear1D_Row(ParallelModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, + fp8_communication: bool = False, ): super().__init__() @@ -278,6 +287,7 @@ class Linear1D_Row(ParallelModule): self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -398,7 +408,9 @@ class Linear1D_Row(ParallelModule): ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions ) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + input_ = split_forward_gather_backward( + input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) if self.stream_chunk_num > 1: if self.training: @@ -416,10 +428,13 @@ class Linear1D_Row(ParallelModule): handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - if self.seq_parallel_mode == "split_gather": + if self.seq_parallel_mode is None: + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) + elif self.seq_parallel_mode == "split_gather": output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim + output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication ) elif self.seq_parallel_mode == "ring": output = linear_reducescatter_forward_gather_backward( @@ -562,6 +577,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule): weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, make_vocab_size_divisible_by: int = 64, + fp8_communication: bool = False, **kwargs, ): # create weight and bias @@ -592,6 +608,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule): **kwargs, new_num_embeddings=new_out_features, old_num_embeddings=out_features, + fp8_communication=fp8_communication, ) # get the length of valid embeddings tp_rank = dist.get_rank(process_group) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 000934ad9..f9a41a467 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -183,6 +183,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, ): super().__init__() @@ -197,6 +198,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.n_fused = n_fused self.process_group = process_group self.async_communication = async_communication + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -314,14 +316,26 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.seq_parallel_mode is None: # Set up backprop all-reduce. 
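The sharded linear layers accept the same flag. A hedged conversion sketch that mirrors the from_native_module pattern used for the MoE experts later in this diff; the TP group and layer sizes are illustrative:

    import torch.nn as nn
    import torch.distributed as dist
    from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

    tp_group = dist.group.WORLD  # assumed tensor-parallel group

    up_proj = nn.Linear(1024, 4096, bias=False).cuda()
    down_proj = nn.Linear(4096, 1024, bias=False).cuda()

    # Column/row-parallel replacements whose gather and reduce collectives run in FP8.
    up_proj = Linear1D_Col.from_native_module(up_proj, tp_group, fp8_communication=True)
    down_proj = Linear1D_Row.from_native_module(down_proj, tp_group, fp8_communication=True)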
- input_parallel = reduce_backward(input_, self.process_group) + input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication) output_parallel = matmul_with_async_comm( - input_parallel, self.weight, bias, self.process_group, self.async_communication + input_parallel, + self.weight, + bias, + self.process_group, + self.async_communication, + fp8_communication=self.fp8_communication, ) elif self.seq_parallel_mode == "split_gather": input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + input_parallel, + self.weight, + bias, + self.process_group, + True, + 1, + self.overlap, + fp8_communication=self.fp8_communication, ) elif self.seq_parallel_mode == "ring": input_parallel = input_ @@ -331,7 +345,9 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.gather_output: # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + output = gather_forward_split_backward( + output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) else: output = output_parallel @@ -379,6 +395,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, + fp8_communication: bool = False, ): super().__init__() @@ -392,6 +409,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): self.process_group = process_group self.seq_parallel_mode = seq_parallel_mode self.num_partitions = dist.get_world_size(self.process_group) + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -514,7 +532,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions ) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + input_ = split_forward_gather_backward( + input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) if self.stream_chunk_num > 1: if self.training: @@ -535,13 +555,20 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): else: if self.seq_parallel_mode is None: output_parallel = torch.matmul(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) elif self.seq_parallel_mode == "split_gather": output_parallel = torch.matmul(input_, self.weight) - output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = reducescatter_forward_gather_backward( + output_parallel, + self.process_group, + 1, + self.fp8_communication, + ) elif self.seq_parallel_mode == "ring": output_parallel = torch.matmul(input_, self.weight) - output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = reducescatter_forward_gather_backward( + output_parallel, self.process_group, 1, self.fp8_communication + ) if not self.skip_bias_add: if self.bias is not None: @@ -600,6 +627,7 @@ class FusedLinear1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, ): super().__init__() # Keep input parameters @@ -611,6 +639,7 @@ class FusedLinear1D_Col(ParallelModule): self.n_fused = n_fused self.process_group = process_group self.async_communication = async_communication + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -740,7 +769,9 @@ class FusedLinear1D_Col(ParallelModule): if self.gather_output: # All-gather across the partitions. 
- output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + output = gather_forward_split_backward( + output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) else: output = output_parallel diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 7710b56e7..580f3618c 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -187,11 +187,17 @@ class BertPipelineForwards: if shard_config is not None and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) if encoder_hidden_states is not None: encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + encoder_hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): @@ -242,7 +248,10 @@ class BertPipelineForwards: if shard_config is not None and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) if output_hidden_states: @@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] embedding_output = split_forward_gather_backward( - embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group + embedding_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) if encoder_hidden_states is not None: encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + encoder_hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) encoder_outputs = self.encoder( @@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): # When sequence parallelism done, gather the output tensor in forward and split it in backward sequence_output = gather_forward_split_backward( - sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group + sequence_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 26ffef6c5..f8fd4665f 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -221,7 +221,10 @@ class BloomPipelineForwards: if shard_config and 
shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -264,7 +267,10 @@ class BloomPipelineForwards: if shard_config and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) if stage_manager.is_last_stage(): @@ -922,7 +928,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -960,7 +969,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): # When sequence parallelism done, gather the output tensor in forward and split it in backward hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) # Add last hidden state hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 34d900d8d..a761968af 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -206,6 +206,15 @@ class ChatGLMPipelineForwards: hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=0, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=1 / shard_config.sequence_parallel_size, + fp8_communication=shard_config.fp8_communication, ) elif shard_config.sequence_parallelism_mode == "all_to_all": hidden_states = split_forward_gather_backward( @@ -245,6 +254,15 @@ class ChatGLMPipelineForwards: hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=0, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=shard_config.sequence_parallel_size, + fp8_communication=shard_config.fp8_communication, ) elif shard_config.sequence_parallelism_mode == "all_to_all": hidden_states = gather_forward_split_backward( @@ -401,6 +419,12 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, rotary_pos_emb = 
rotary_pos_emb[None, :seq_length] rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + if sp_mode in ["all_to_all"] and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..." + ) + use_cache = False if sp_mode in ["all_to_all"] and self.training: if use_cache: logger.warning_once( @@ -414,6 +438,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, inputs_embeds, dim=0, process_group=sp_group, + fp8_communication=shard_config.fp8_communication, ) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward( @@ -421,6 +446,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, dim=0, process_group=sp_group, grad_scale=1 / sp_size, + fp8_communication=shard_config.fp8_communication, ) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, @@ -436,6 +462,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) elif sp_mode == "all_to_all": hidden_states = gather_forward_split_backward( @@ -443,6 +470,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, dim=0, process_group=sp_group, grad_scale=sp_size, + fp8_communication=shard_config.fp8_communication, ) if not return_dict: @@ -532,9 +560,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s key_layer = key_layer.reshape(sq, bs, -1) value_layer = value_layer.reshape(sq, bs, -1) - query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0) - key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0) - value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0) + query_layer = all_to_all_comm( + query_layer, + sp_group, + gather_dim=0, + fp8_communication=shard_config.fp8_communication, + ) + key_layer = all_to_all_comm( + key_layer, + sp_group, + gather_dim=0, + fp8_communication=shard_config.fp8_communication, + ) + value_layer = all_to_all_comm( + value_layer, + sp_group, + gather_dim=0, + fp8_communication=shard_config.fp8_communication, + ) query_layer = query_layer.view( sq * sp_size, @@ -610,7 +653,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) if sp_mode == "all_to_all": - context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0) + context_layer = all_to_all_comm( + context_layer, + sp_group, + gather_dim=2, + scatter_dim=0, + fp8_communication=shard_config.fp8_communication, + ) # ================= # Output. 
[sq, b, h] diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 67c20eed8..530338394 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -142,6 +142,7 @@ class CommandPipelineForwards: hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) elif shard_config.sequence_parallelism_mode == "all_to_all": hidden_states = split_forward_gather_backward( @@ -149,6 +150,7 @@ class CommandPipelineForwards: dim=1, process_group=shard_config.sequence_parallel_process_group, grad_scale=1 / shard_config.sequence_parallel_size, + fp8_communication=shard_config.fp8_communication, ) # decoder layers @@ -213,6 +215,7 @@ class CommandPipelineForwards: hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) elif shard_config.sequence_parallelism_mode == "all_to_all": hidden_states = gather_forward_split_backward( @@ -220,6 +223,7 @@ class CommandPipelineForwards: dim=1, process_group=shard_config.sequence_parallel_process_group, grad_scale=shard_config.sequence_parallel_size, + fp8_communication=shard_config.fp8_communication, ) # add hidden states from the last decoder layer @@ -384,9 +388,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -448,7 +452,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -528,9 +534,13 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) hidden_states = inputs_embeds # decoder layers @@ -575,9 +585,13 @@ def get_command_flash_attention_model_forward(shard_config: 
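The modeling forwards in this part of the diff read the flag from shard_config.fp8_communication, so turning it on is purely a configuration change. A sketch under that assumption; apart from fp8_communication, the field names are the ones the forwards above access, and the process group is a placeholder:

    import torch.distributed as dist
    from colossalai.shardformer import ShardConfig

    shard_config = ShardConfig(
        tensor_parallel_process_group=dist.group.WORLD,  # placeholder TP group
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="split_gather",
        fp8_communication=True,
    )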
ShardConfig, sp_mode hidden_states = self.norm(hidden_states) if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication + ) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index 429c4350c..7ec390d6a 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -24,6 +24,7 @@ from colossalai.moe._operation import ( all_to_all_uneven, ) from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import all_reduce_fp8 from colossalai.shardformer.layer._operation import ( all_to_all_comm, gather_forward_split_backward, @@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module): def __init__(self): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup): + def setup_process_groups( + self, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + fp8_communication: bool = False, + ): assert tp_group is not None assert moe_dp_group is not None assert ep_group is not None @@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module): self.ep_rank = dist.get_rank(ep_group) self.num_experts = self.config.n_routed_experts assert self.num_experts % self.ep_size == 0 + self.fp8_communication = fp8_communication self.ep_group = ep_group self.num_experts_per_ep = self.num_experts // self.ep_size @@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module): self.tp_group = tp_group if self.tp_group.size() > 1: for expert in held_experts: - expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group) - expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group) - expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group) + expert.gate_proj = Linear1D_Col.from_native_module( + expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication + ) + expert.up_proj = Linear1D_Col.from_native_module( + expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication + ) + expert.down_proj = Linear1D_Row.from_native_module( + expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication + ) for p in self.experts.parameters(): set_moe_tensor_ep_group(p, ep_group) @@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module): if module.__class__.__name__ == "DeepseekMLP": return module module.__class__ = EPDeepseekMoE - module.setup_process_groups(tp_group, moe_dp_group, ep_group) + fp8_communication = kwargs.get("fp8_communication", False) + module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -130,18 +145,32 @@ class EPDeepseekMoE(nn.Module): output_split_sizes = torch.zeros_like(input_split_sizes) # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3] - 
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + dist.all_to_all_single( + output_split_sizes, + input_split_sizes, + group=self.ep_group, + ) with torch.no_grad(): activate_experts = output_split_sizes[: self.num_experts_per_ep].clone() for i in range(1, self.ep_size): activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep] activate_experts = (activate_experts > 0).float() - dist.all_reduce(activate_experts, group=self.moe_dp_group) + + if self.fp8_communication: + all_reduce_fp8(activate_experts, group=self.moe_dp_group) + else: + dist.all_reduce(activate_experts, group=self.moe_dp_group) input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + output_states, _ = all_to_all_uneven( + dispatch_states, + input_split_list, + output_split_list, + self.ep_group, + fp8_communication=self.fp8_communication, + ) output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: @@ -167,7 +196,9 @@ class EPDeepseekMoE(nn.Module): output_states_list.append(split_states) output_states = torch.cat(output_states_list) output_states = EPGradScalerOut.apply(output_states, self.ep_size) - dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + dispatch_states, _ = all_to_all_uneven( + output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication + ) recover_token_idx = torch.empty_like(flat_topk_token_idx) recover_token_idx[flat_topk_token_idx] = torch.arange( flat_topk_token_idx.size(0), device=flat_topk_token_idx.device @@ -534,9 +565,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -595,7 +626,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128) - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) # (1, 4, 256) else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -669,6 +702,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code. 
self._use_flash_attention_2 = shard_config.enable_flash_attention self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None @@ -688,9 +722,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si ) if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) # embed positions hidden_states = inputs_embeds @@ -734,9 +772,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si hidden_states = self.norm(hidden_states) if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication + ) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 6ecda91c4..97544e110 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -221,6 +221,7 @@ class GPT2PipelineForwards: hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) # Going through held blocks. @@ -276,6 +277,7 @@ class GPT2PipelineForwards: hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) if stage_manager.is_last_stage(): @@ -1119,6 +1121,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, dim=1, process_group=shard_config.sequence_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -1186,6 +1189,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, dim=1, process_group=shard_config.sequence_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index facd2fcaf..51b228712 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -185,6 +185,7 @@ class GPTJPipelineForwards: hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) # Going through held blocks. 
@@ -236,6 +237,7 @@ class GPTJPipelineForwards: hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) if stage_manager.is_last_stage(): @@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index af610500a..4e8f67407 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -162,9 +162,13 @@ class LlamaPipelineForwards: hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) elif is_share_sp_tp(sp_mode): - hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) + hidden_states = split_forward_gather_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) + hidden_states = split_forward_gather_backward( + hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) if self.gradient_checkpointing and self.training and use_cache: if use_cache: @@ -227,7 +231,9 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): - hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) + hidden_states = gather_sp_output( + hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication + ) # add hidden states from the last decoder layer if output_hidden_states: @@ -532,9 +538,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -605,7 +611,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) else: attn_output = attn_output.reshape(bsz, q_len, 
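Across GPT-2, GPT-J and LLaMA the edit is the same: each forward reads one shard_config.fp8_communication switch and forwards it to whatever communication primitive it calls, with no per-layer state added. A stripped-down sketch of that plumbing; MiniShardConfig and the no-op all_to_all_comm here are placeholders, not the real ShardConfig or primitive.

from dataclasses import dataclass

@dataclass
class MiniShardConfig:
    # The single switch the hunks above thread through every call site.
    fp8_communication: bool = False

def all_to_all_comm(tensor, group=None, scatter_dim=1, gather_dim=2, fp8_communication=False):
    # Stand-in: a real implementation would quantize to FP8 before dist.all_to_all
    # when the flag is set; here we only show the flag reaching the primitive.
    print("all_to_all in", "fp8" if fp8_communication else "full precision")
    return tensor

def attention_forward(hidden_states, shard_config: MiniShardConfig, sp_group=None):
    return all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)

attention_forward([1.0, 2.0], MiniShardConfig(fp8_communication=True))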
self.hidden_size) @@ -707,9 +715,13 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors elif is_share_sp_tp(sp_mode): - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) hidden_states = inputs_embeds # decoder layers @@ -754,7 +766,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N hidden_states = self.norm(hidden_states) # Cases that don't support parallelizing cross entropy computation along sequence if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather: - hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) + hidden_states = gather_sp_output( + hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication + ) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index d30ce5ea8..4850ef1b6 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): def __init__(self, *args, **kwargs): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup): + def setup_process_groups( + self, + tp_group: ProcessGroup, + moe_dp_group: ProcessGroup, + ep_group: ProcessGroup, + fp8_communication: bool = False, + ): assert tp_group is not None assert moe_dp_group is not None assert ep_group is not None @@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): self.ep_size = dist.get_world_size(ep_group) self.ep_rank = dist.get_rank(ep_group) self.ep_group = ep_group + self.fp8_communication = fp8_communication if self.num_experts % self.ep_size != 0: raise ValueError("The number of experts must be divisible by the number of expert parallel groups.") @@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): self.tp_group = tp_group if self.tp_group.size() > 1: for expert in held_experts: - expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group) - expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group) - expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group) + expert.w1 = Linear1D_Col.from_native_module( + expert.w1, self.tp_group, fp8_communication=self.fp8_communication + ) + expert.w3 = Linear1D_Col.from_native_module( + expert.w3, self.tp_group, fp8_communication=self.fp8_communication + ) + expert.w2 = Linear1D_Row.from_native_module( + expert.w2, self.tp_group, fp8_communication=self.fp8_communication + ) for p in self.experts.parameters(): set_moe_tensor_ep_group(p, ep_group) @@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): # TODO: better init LazyInitContext.materialize(module) module.__class__ = EPMixtralSparseMoeBlock - module.setup_process_groups(tp_group, moe_dp_group, 
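In the Mixtral expert-parallel block above, each held expert's w1 and w3 become Linear1D_Col and w2 becomes Linear1D_Row, now with fp8_communication forwarded as well. The column/row pairing is what keeps communication down: the column-parallel layer splits output features, the following row-parallel layer splits input features, so only one all-reduce of partial sums is needed at the end (the elementwise activation between them does not change this). A single-process check of that identity with plain tensors:

import torch

torch.manual_seed(0)
tp_size, hidden, inter = 2, 8, 16
x = torch.randn(4, hidden)
w1 = torch.randn(inter, hidden)    # column-parallel: sharded along output features
w2 = torch.randn(hidden, inter)    # row-parallel: sharded along input features

reference = x @ w1.t() @ w2.t()    # unsharded result

partials = []
for rank in range(tp_size):
    w1_shard = w1.chunk(tp_size, dim=0)[rank]    # this rank's slice of the intermediate features
    w2_shard = w2.chunk(tp_size, dim=1)[rank]    # ...and the matching input slice of the second matmul
    partials.append((x @ w1_shard.t()) @ w2_shard.t())
sharded = sum(partials)                           # what dist.all_reduce would produce
print(torch.allclose(reference, sharded, atol=1e-5))  # True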
ep_group) + fp8_communication = kwargs.get("fp8_communication", False) + module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): input_split_sizes = selected_experts.bincount(minlength=self.num_experts) output_split_sizes = torch.zeros_like(input_split_sizes) + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) with torch.no_grad(): @@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + output_states, _ = all_to_all_uneven( + dispatch_states, + input_split_list, + output_split_list, + self.ep_group, + fp8_communication=self.fp8_communication, + ) # compute expert output output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: @@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): output_states = torch.cat(output_states_list) output_states = EPGradScalerOut.apply(output_states, self.ep_size) - dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + dispatch_states, _ = all_to_all_uneven( + output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication + ) recover_experts_idx = torch.empty_like(selected_experts_idx) recover_experts_idx[selected_experts_idx] = torch.arange( @@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -673,7 +696,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128) - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) # (1, 4, 256) else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -780,9 +805,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz ) if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, 
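The Mixtral dispatch above derives everything all_to_all_uneven needs from the router output: per-expert token counts via bincount, exchanged with dist.all_to_all_single, then collapsed into per-EP-rank send/receive lists. The same arithmetic in isolation, with made-up routing and the peer exchange stubbed out:

import torch

ep_size, num_experts = 2, 8
num_experts_per_ep = num_experts // ep_size

selected_experts = torch.tensor([0, 0, 3, 5, 5, 5, 7, 2, 2, 6])              # toy routing result
input_split_sizes = selected_experts.bincount(minlength=num_experts)          # tokens per expert
# dist.all_to_all_single would fill output_split_sizes with the peers' counts;
# here we simply pretend the other rank routed the same way.
output_split_sizes = input_split_sizes.clone()

# Collapse per-expert counts into per-EP-rank counts for all_to_all_uneven.
input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
print(input_split_sizes.tolist(), input_split_list)    # [2, 0, 2, 1, 0, 3, 1, 1] [5, 5]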
fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) hidden_states = inputs_embeds # decoder layers @@ -831,9 +860,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz hidden_states = self.norm(hidden_states) if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication + ) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 538e96c32..a61360d17 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -175,6 +175,7 @@ class Qwen2PipelineForwards: hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) elif shard_config.sequence_parallelism_mode == "all_to_all": hidden_states = split_forward_gather_backward( @@ -182,6 +183,7 @@ class Qwen2PipelineForwards: dim=1, process_group=shard_config.sequence_parallel_process_group, grad_scale=1 / shard_config.sequence_parallel_size, + fp8_communication=shard_config.fp8_communication, ) # decoder layers @@ -246,6 +248,7 @@ class Qwen2PipelineForwards: hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, + fp8_communication=shard_config.fp8_communication, ) elif shard_config.sequence_parallelism_mode == "all_to_all": hidden_states = gather_forward_split_backward( @@ -253,6 +256,7 @@ class Qwen2PipelineForwards: dim=1, process_group=shard_config.sequence_parallel_process_group, grad_scale=shard_config.sequence_parallel_size, + fp8_communication=shard_config.fp8_communication, ) # add hidden states from the last decoder layer if output_hidden_states: @@ -516,9 +520,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s value_states = self.v_proj(hidden_states) # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -604,7 +608,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s attn_output = attn_output.transpose(1, 2).contiguous() if sp_mode == "all_to_all": attn_output = attn_output.reshape(bsz, q_len, self.num_heads * 
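The post-attention all_to_all_comm calls above use scatter_dim=1 and gather_dim=2: each rank trades its full-sequence, partial-head activation for a sequence shard carrying all heads, which is why the DeepSeek and Mixtral hunks annotate the shapes as (1, 8, 128) -> (1, 4, 256) for sp_size=2. A single-process shape simulation, with chunk/cat standing in for dist.all_to_all:

import torch

def simulate_all_to_all(x: torch.Tensor, sp_size: int, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    # Split `scatter_dim` into sp_size pieces and concatenate them along `gather_dim`,
    # as if each rank had exchanged its shard with every other rank.
    pieces = x.chunk(sp_size, dim=scatter_dim)
    return torch.cat(pieces, dim=gather_dim)

attn_output = torch.randn(1, 8, 128)                                   # (bsz, full seq, local heads * head_dim)
out = simulate_all_to_all(attn_output, sp_size=2, scatter_dim=1, gather_dim=2)
print(out.shape)                                                        # torch.Size([1, 4, 256])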
self.head_dim) - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -702,9 +708,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No next_decoder_cache = None if sp_mode in ["ring", "split_gather"]: - hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) + hidden_states = split_forward_gather_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) + hidden_states = split_forward_gather_backward( + hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) for decoder_layer in self.layers: if output_hidden_states: @@ -741,9 +751,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No hidden_states = self.norm(hidden_states) if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication + ) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index b84a372a5..4c33e14bc 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -98,6 +98,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( @@ -106,6 +107,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( @@ -114,6 +116,7 @@ class BertPolicy(Policy): kwargs={ "seq_parallel_mode": sp_mode, "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( @@ -123,7 +126,10 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel_mode": sp_mode}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -136,12 +142,16 @@ class BertPolicy(Policy): "seq_parallel_mode": sp_mode, "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel_mode": sp_mode}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -180,6 +190,13 @@ class BertPolicy(Policy): SubModuleReplacementDescription( 
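On the policy side (BERT here, and every policy that follows), the flag is not read at runtime: it is baked into the kwargs of each SubModuleReplacementDescription, so the parallel layer receives fp8_communication in its constructor when ShardFormer swaps modules in. A schematic of that descriptor pattern; the dataclass, fake layer and apply loop below are illustrative, not ShardFormer's real machinery.

from dataclasses import dataclass, field

@dataclass
class SubModuleReplacement:
    suffix: str                  # attribute path inside the block, e.g. "attention.self.query"
    target_module: type          # parallel layer that replaces the original submodule
    kwargs: dict = field(default_factory=dict)

class FakeLinear1D_Col:
    def __init__(self, fp8_communication=False, **extra):
        self.fp8_communication = fp8_communication

fp8 = True  # would come from self.shard_config.fp8_communication
descriptions = [
    SubModuleReplacement("attention.self.query", FakeLinear1D_Col, {"fp8_communication": fp8}),
    SubModuleReplacement("attention.self.key", FakeLinear1D_Col, {"fp8_communication": fp8}),
]
replacements = {d.suffix: d.target_module(**d.kwargs) for d in descriptions}
print(replacements["attention.self.query"].fp8_communication)  # True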
suffix="word_embeddings", target_module=embedding_cls, + kwargs=( + { + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {} + ), ) ], policy=policy, @@ -249,6 +266,7 @@ class BertPolicy(Policy): kwargs={ "gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ), policy=base_policy, diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 32d4edadb..da798f6a0 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -72,20 +72,30 @@ class BlipPolicy(Policy): target_module=col_nn.FusedLinear1D_Col, kwargs={ "n_fused": 3, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="self_attn.projection", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.fc1", target_module=col_nn.Linear1D_Col, - kwargs={"skip_bias_add": self.enable_bias_gelu_fused}, + kwargs={ + "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.fc2", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -114,14 +124,23 @@ class BlipPolicy(Policy): SubModuleReplacementDescription( suffix="attention.attention.query", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.attention.key", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.attention.value", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.attention.dropout", @@ -130,6 +149,9 @@ class BlipPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -138,14 +160,23 @@ class BlipPolicy(Policy): SubModuleReplacementDescription( suffix="crossattention.attention.query", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="crossattention.attention.key", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="crossattention.attention.value", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="crossattention.attention.dropout", @@ -154,6 +185,9 @@ class BlipPolicy(Policy): SubModuleReplacementDescription( suffix="crossattention.output.dense", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="crossattention.output.dropout", @@ -162,10 +196,16 @@ class BlipPolicy(Policy): SubModuleReplacementDescription( suffix="intermediate_query.dense", 
target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="output_query.dense", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="output_query.dropout", @@ -185,26 +225,44 @@ class BlipPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc1", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc2", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -225,7 +283,14 @@ class BlipPolicy(Policy): SubModuleReplacementDescription( suffix="model.decoder.embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), ], policy=policy, @@ -241,6 +306,7 @@ class BlipPolicy(Policy): kwargs={ "gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ), ], diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index d80adb84a..a43ac02d0 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -76,12 +76,19 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, + kwargs={ + "seq_parallel_mode": sp_mode, + "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel_mode": sp_mode}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", @@ -90,12 +97,19 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, + kwargs={ + "seq_parallel_mode": sp_mode, + "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, - 
kwargs={"seq_parallel_mode": sp_mode}, + kwargs={ + "seq_parallel_mode": sp_mode, + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -115,7 +129,14 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="word_embeddings", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), ], policy=policy, @@ -279,6 +300,7 @@ class BloomForCausalLMPolicy(BloomPolicy): kwargs=dict( gather_output=not self.shard_config.parallel_output, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, + fp8_communication=self.shard_config.fp8_communication, ), ), policy=policy, @@ -337,7 +359,9 @@ class BloomForSequenceClassificationPolicy(BloomPolicy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ), policy=policy, target_key=BloomForSequenceClassification, @@ -374,7 +398,9 @@ class BloomForTokenClassificationPolicy(BloomPolicy): self.append_or_create_submodule_replacement( description=[ SubModuleReplacementDescription( - suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="classifier", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="dropout", diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 3877bdac3..16c13085a 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -128,12 +128,17 @@ class ChatGLMPolicy(Policy): "seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0}, + kwargs={ + "seq_parallel_mode": sp_mode, + "seq_parallel_dim": 0, + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attention.core_attention.attention_dropout", @@ -148,7 +153,14 @@ class ChatGLMPolicy(Policy): SubModuleReplacementDescription( suffix="embedding.word_embeddings", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), ], policy=policy, diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 1efd3d017..323480d6d 100644 --- a/colossalai/shardformer/policies/command.py +++ 
b/colossalai/shardformer/policies/command.py @@ -128,37 +128,37 @@ class CommandPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), ], ) @@ -168,7 +168,14 @@ class CommandPolicy(Policy): description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=CohereModel, @@ -306,6 +313,7 @@ class CommandForCausalLMPolicy(CommandPolicy): kwargs={ "gather_output": not self.shard_config.parallel_output, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ) ], diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index ea68649d5..0b8a602d1 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -118,18 +118,22 @@ class DeepseekPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), ], ) @@ -138,7 +142,10 @@ class DeepseekPolicy(Policy): 
description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs={ + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, ), policy=policy, target_key="DeepseekModel", @@ -155,6 +162,7 @@ class DeepseekPolicy(Policy): "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, + "fp8_communication": self.shard_config.fp8_communication, }, ) ], @@ -298,14 +306,14 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy): policy = super().module_policy() # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: - # add a new item for causal lm + # add a new item for casual lm new_item = { "DeepseekForCausalLM": ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True), + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index e5c167337..e20fb1568 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -105,7 +105,14 @@ class FalconPolicy(Policy): SubModuleReplacementDescription( suffix="word_embeddings", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), ], policy=policy, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index cfe20000a..0a1949d85 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -110,14 +110,13 @@ class GPT2Policy(Policy): "n_fused": 3, "seq_parallel_mode": sp_mode, "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="mlp.c_fc", @@ -127,14 +126,13 @@ class GPT2Policy(Policy): "seq_parallel_mode": sp_mode, "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel_mode": sp_mode, - }, + kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -164,7 +162,14 @@ class GPT2Policy(Policy): description=SubModuleReplacementDescription( suffix="wte", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, 
+ "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=GPT2Model, @@ -334,6 +339,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy): kwargs={ "gather_output": False, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ) ], @@ -404,6 +410,7 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): kwargs={ "gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ) ] diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index c394d911e..6f0c8803c 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -77,6 +77,7 @@ class GPTJPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( @@ -84,6 +85,7 @@ class GPTJPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( @@ -91,19 +93,29 @@ class GPTJPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "overlap": overlap, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.fc_in", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.fc_out", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -125,7 +137,14 @@ class GPTJPolicy(Policy): description=SubModuleReplacementDescription( suffix="wte", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=GPTJModel, @@ -264,6 +283,7 @@ class GPTJForCausalLMPolicy(GPTJPolicy): kwargs={ "gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ) ] diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 67e6e92d1..78ac5eaae 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -133,37 +133,37 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), 
+ kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), ], ) @@ -173,7 +173,14 @@ class LlamaPolicy(Policy): description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=LlamaModel, @@ -316,6 +323,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy): kwargs={ "gather_output": not self.shard_config.parallel_output, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ) ], @@ -384,7 +392,12 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): LlamaForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", + target_module=Linear1D_Col, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + ), ) ] ) diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 6ea27e210..4d16038c1 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -88,30 +88,51 @@ class MistralPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.gate_proj", 
target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -121,7 +142,14 @@ class MistralPolicy(Policy): description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=MistralModel, @@ -281,6 +309,7 @@ class MistralForCausalLMPolicy(MistralPolicy): kwargs={ "gather_output": not self.shard_config.parallel_output, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ) ] @@ -297,7 +326,9 @@ class MistralForCausalLMPolicy(MistralPolicy): SubModuleReplacementDescription( suffix="lm_head", target_module=PaddingLMHead, - kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), + kwargs=dict( + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, + ), ) ] ) @@ -350,7 +381,9 @@ class MistralForSequenceClassificationPolicy(MistralPolicy): MistralForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index e11edae9f..3a373889c 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -114,21 +114,27 @@ class MixtralPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs={"fp8_communication": self.shard_config.fp8_communication}, ), SubModuleReplacementDescription( # or replicate? 
- suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True} + suffix="block_sparse_moe.gate", + target_module=Linear1D_Col, + kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, ), ], ) @@ -138,7 +144,14 @@ class MixtralPolicy(Policy): description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=MixtralModel, @@ -282,7 +295,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True), + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) @@ -336,7 +349,9 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): MixtralForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 524d2b8cd..dd64ce652 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -102,18 +102,30 @@ class OPTPolicy(Policy): SubModuleReplacementDescription( suffix="q_proj", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="k_proj", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="v_proj", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="out_proj", target_module=Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -123,7 +135,14 @@ class OPTPolicy(Policy): description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=OPTDecoder, @@ -272,6 +291,7 @@ class OPTForCausalLMPolicy(OPTPolicy): kwargs=dict( gather_output=not self.shard_config.parallel_output, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, + fp8_communication=self.shard_config.fp8_communication, ), ), policy=policy, diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 235dc7d56..1b066200d 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -119,37 +119,37 @@ class 
Qwen2Policy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), ], ) @@ -159,7 +159,14 @@ class Qwen2Policy(Policy): description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=Qwen2Model, @@ -313,11 +320,15 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy): setattr(self.shard_config, "causal_lm", True) if self.shard_config.enable_tensor_parallelism: - # add a new item for causal lm + # add a new item for casual lm new_item = { Qwen2ForCausalLM: ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col) + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(fp8_communication=self.shard_config.fp8_communication), + ) ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) @@ -366,7 +377,9 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy): Qwen2ForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 53faf8997..674fe5e58 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -43,19 +43,29 @@ class SamPolicy(Policy): 
target_module=col_nn.FusedLinear1D_Col, kwargs={ "n_fused": 3, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="attn.proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.lin1", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.lin2", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -68,58 +78,100 @@ class SamPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="cross_attn_token_to_image.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="cross_attn_token_to_image.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="cross_attn_token_to_image.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="cross_attn_token_to_image.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.lin1", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="mlp.lin2", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="cross_attn_image_to_token.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="cross_attn_image_to_token.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="cross_attn_image_to_token.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="cross_attn_image_to_token.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -132,18 +184,30 @@ class SamPolicy(Policy): SubModuleReplacementDescription( suffix="final_attn_token_to_image.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="final_attn_token_to_image.k_proj", 
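The FusedLinear1D_Col replacements with n_fused=3 (SAM here, BLIP-2 earlier) handle a weight that stacks three projections, typically Q, K and V, along the output dimension: a naive column split would hand one rank all of Q and another all of V, so instead each fused segment is split across ranks and the per-rank slices are restacked. A plain-tensor sketch of that re-partitioning (illustrative, not the layer's actual code):

import torch

hidden, tp_size, n_fused = 8, 2, 3
qkv_weight = torch.randn(n_fused * hidden, hidden)       # [Q; K; V] stacked along the output dim

def fused_col_shard(weight: torch.Tensor, rank: int) -> torch.Tensor:
    # Split every fused segment separately, then keep this rank's slice of each.
    segments = weight.chunk(n_fused, dim=0)               # Q, K, V
    return torch.cat([seg.chunk(tp_size, dim=0)[rank] for seg in segments], dim=0)

shard0 = fused_col_shard(qkv_weight, rank=0)
print(shard0.shape)                                        # torch.Size([12, 8]): half of Q, K and V each

# Reassembling every rank's shards segment by segment recovers the original weight.
recovered = torch.cat(
    [torch.cat([fused_col_shard(qkv_weight, r).chunk(n_fused, dim=0)[s] for r in range(tp_size)], dim=0)
     for s in range(n_fused)],
    dim=0,
)
print(torch.equal(recovered, qkv_weight))                  # True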
target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="final_attn_token_to_image.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="final_attn_token_to_image.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 0b594678c..84b5d9594 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -117,23 +117,38 @@ class T5BasePolicy(Policy): SubModuleReplacementDescription( suffix="q", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="k", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="v", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="o", target_module=Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="relative_attention_bias", target_module=Embedding1D, - kwargs=dict(gather_output=False), + kwargs=dict( + gather_output=False, + fp8_communication=self.shard_config.fp8_communication, + ), ignore_if_not_exist=True, ), ], @@ -151,13 +166,24 @@ class T5BasePolicy(Policy): SubModuleReplacementDescription( suffix="wi_0 ", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="wi_1", target_module=Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( - suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="wo", + target_module=Linear1D_Col, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + ), ), SubModuleReplacementDescription( suffix="dropout", @@ -170,10 +196,16 @@ class T5BasePolicy(Policy): SubModuleReplacementDescription( suffix="wi", target_module=Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="wo", target_module=Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="dropout", @@ -187,7 +219,14 @@ class T5BasePolicy(Policy): description=SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=T5Stack, @@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy): description=SubModuleReplacementDescription( suffix="shared", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + 
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=T5Model, @@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): description=SubModuleReplacementDescription( suffix="shared", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=T5ForConditionalGeneration, @@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): kwargs={ "gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ), policy=policy, @@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy): description=SubModuleReplacementDescription( suffix="shared", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), policy=policy, target_key=T5EncoderModel, diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 069ad0c26..07202094f 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -70,14 +70,23 @@ class ViTPolicy(Policy): SubModuleReplacementDescription( suffix="attention.attention.query", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.attention.key", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.attention.value", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.attention.dropout", @@ -86,6 +95,9 @@ class ViTPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -96,11 +108,15 @@ class ViTPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ "skip_bias_add": self.enable_bias_gelu_fused, + "fp8_communication": self.shard_config.fp8_communication, }, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -215,7 +231,9 @@ class ViTForImageClassificationPolicy(ViTPolicy): ViTForImageClassification: ModulePolicyDescription( sub_module_replacement=[ 
SubModuleReplacementDescription( - suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="classifier", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), ) ] ) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 441e512bb..7a1f146d5 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -91,26 +91,44 @@ class WhisperPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc1", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc2", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -128,42 +146,72 @@ class WhisperPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="self_attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="encoder_attn.q_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="encoder_attn.k_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="encoder_attn.v_proj", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="encoder_attn.out_proj", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc1", target_module=col_nn.Linear1D_Col, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), SubModuleReplacementDescription( suffix="fc2", target_module=col_nn.Linear1D_Row, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + }, ), ], ) @@ -174,7 +222,14 @@ class WhisperPolicy(Policy): SubModuleReplacementDescription( suffix="embed_tokens", target_module=embedding_cls, - kwargs={"make_vocab_size_divisible_by": 
self.shard_config.make_vocab_size_divisible_by}, + kwargs=( + { + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + } + if self.shard_config.enable_tensor_parallelism + else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by} + ), ), ], policy=policy, @@ -303,6 +358,7 @@ class WhisperPolicy(Policy): kwargs={ "gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, }, ), policy=base_policy, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 70eb271c9..1219119bb 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -29,6 +29,7 @@ class ShardConfig: enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. + fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False. parallel_output (bool): For TP: whether to use parallelize cross entropy computation along the feature dim. For SP: set to True to NOT gather the output along the seq dim. """ @@ -54,6 +55,7 @@ class ShardConfig: # for moe related moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None + fp8_communication: bool = False # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index acb9fc4ae..8992b89a3 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter): with torch._C.DisableTorchFunction(): new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values()) args, kwargs = replace_args(args, kwargs, new_args) + with torch._C.DisableTorchFunction(): + func = ColoParamOpHookManager.rewrite_op(func) ret = super().__torch_function__(func, types, args, kwargs) with torch._C.DisableTorchFunction(): ret = ColoParamOpHookManager.post_op(params, ret) diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index 40de43c43..c8dd5a0c8 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -30,6 +30,9 @@ class ColoParamOpHook(ABC): def post_backward(self, params: List[torch.Tensor]) -> None: pass + def rewrite_op(self, func) -> Any: + return func + class ColoParamOpHookManager: """ @@ -101,6 +104,12 @@ class ColoParamOpHookManager: def has_hook() -> bool: return len(ColoParamOpHookManager.hooks) > 0 + @staticmethod + def rewrite_op(func) -> Any: + for hook in ColoParamOpHookManager.hooks: + func = hook.rewrite_op(func) + return func + class PreFwdPostBwd(torch.autograd.Function): @staticmethod diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 969df9621..e2b7a8f56 100644 --- 
a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -166,6 +166,7 @@ class Chunk: self.grad_chunk = None # the async all-reduce/reduce-scatter work of this grad chunk (None means sync) self.grad_reduce_work = None + self.fp8_communication = False @property def memory_usage(self) -> Dict[str, int]: @@ -521,9 +522,17 @@ class Chunk: alloc_storage(self.cuda_global_chunk) assert self.cuda_global_chunk.is_contiguous() - work = dist.all_gather_into_tensor( - self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op - ) + if self.fp8_communication: + assert async_op == False, "fp8 all-gather does not support async_op!" + from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8 + + work = all_gather_into_tensor_flat_fp8( + self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg + ) + else: + work = dist.all_gather_into_tensor( + self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op + ) self.cuda_shard = None self.is_gathered = True diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index d0e1755f4..06f9b6d18 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -26,6 +26,7 @@ class ChunkManager: init_device: Optional[torch.device] = None, reuse_fp16_chunk: bool = True, max_prefetch: int = 0, + fp8_communication: bool = False, ) -> None: self.device = init_device or get_accelerator().get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() @@ -44,6 +45,7 @@ class ChunkManager: self.accumulating_grads = False self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device()) self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None + self.fp8_communication = fp8_communication def register_tensor( self, @@ -101,6 +103,8 @@ class ChunkManager: extra_dp_group=extra_dp_group, **chunk_kwargs, ) + if self.fp8_communication: + chunk.fp8_communication = True chunk_group.append(chunk) chunk.append_tensor(tensor) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 80b2c7961..dbaae6610 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_ from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger +from colossalai.quantization.fp8_hook import FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.d_tensor import ( distribute_tensor, @@ -98,6 +99,8 @@ class GeminiDDP(ModelWrapper): extra_dp_group: Optional[ProcessGroup] = None, verbose: bool = False, enable_async_reduce: bool = True, + fp8_communication: bool = False, + use_fp8: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False @@ -122,6 +125,8 @@ class GeminiDDP(ModelWrapper): verbose=verbose, max_prefetch=max_prefetch, ) + if fp8_communication: + self.chunk_manager.fp8_communication = True self.gemini_manager = GeminiManager( placement_policy, self.chunk_manager, @@ -135,6 +140,9 @@ class GeminiDDP(ModelWrapper): ) self.force_outputs_fp32 = force_outputs_fp32 self.param_op_hook = GeminiZeROHook(self.gemini_manager) + self.hooks = [self.param_op_hook] + if use_fp8: + 
self.hooks.append(FP8Hook()) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() @@ -307,7 +315,7 @@ class GeminiDDP(ModelWrapper): outputs = self._inference_forward(*args, **kwargs) else: self.gemini_manager.pre_iter(*args) - with ColoParamOpHookManager.use_hooks(self.param_op_hook): + with ColoParamOpHookManager.use_hooks(*self.hooks): outputs = self.module(*args, **kwargs) if self.force_outputs_fp32: @@ -316,7 +324,7 @@ class GeminiDDP(ModelWrapper): def _inference_forward(self, *args, **kwargs): """This function is only triggered for inference.""" - fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook) + fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks) if not self.scatter_after_inference: # gather all chunks for chunk in self.chunk_manager.get_chunks(self.fp16_params): @@ -369,7 +377,7 @@ class GeminiDDP(ModelWrapper): def backward(self, loss: torch.Tensor): self._pre_backward() - with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks): loss.backward() self._post_backward() diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index 5b09019b9..d5fd2fe51 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -4,6 +4,8 @@ import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8 + class TensorBucket: def __init__(self, size): @@ -61,11 +63,14 @@ class TensorBucket: for old, new in zip(self._bucket, unflattened_tensor_list): old.copy_(new) - def all_gather(self, group=None): + def all_gather(self, group=None, fp8_communication: bool = False): flat = self.flatten() - buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))] - dist.all_gather(buffers, flat, group=group) - unflat_buffers = [self.unflatten(buffer) for buffer in buffers] + buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype) + if fp8_communication: + all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group) + else: + dist.all_gather_into_tensor(buffer, flat, group=group) + unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))] # transpose the list of list unflat_buffers = list(map(list, zip(*unflat_buffers))) for unflat_shards, tensor in zip(unflat_buffers, self._bucket): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 51d7d1eaa..458e6e41a 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import ( ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger +from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8 from colossalai.tensor.moe_tensor.api import is_moe_tensor from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor @@ -86,6 +87,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): forced_dtype: Optional[torch.dtype] = None, 
master_weights: bool = True, # master weights overlap_allgather: bool = False, + fp8_communication: bool = False, ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) @@ -127,6 +129,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self._overlap_allgather = overlap_allgather self._reduce_bucket_size = reduce_bucket_size self._communication_dtype = communication_dtype + self._fp8_communication = fp8_communication # gradient clipping self._clip_grad_norm = clip_grad_norm @@ -330,7 +333,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): flat_grads = flat_grads.to(self._communication_dtype) if not self._partition_grads: - dist.all_reduce(flat_grads, group=bucket_store.torch_pg) + if self._fp8_communication: + all_reduce_fp8(flat_grads, group=bucket_store.torch_pg) + else: + dist.all_reduce(flat_grads, group=bucket_store.torch_pg) if flat_grads.dtype != grad_dtype: flat_grads = flat_grads.to(grad_dtype) @@ -340,7 +346,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): else: flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size)) received_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) + if self._fp8_communication: + reduce_scatter_fp8( + received_grad, + flat_grads_list, + group=bucket_store.torch_pg, + ) + else: + dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) if received_grad.dtype != grad_dtype: received_grad = received_grad.to(grad_dtype) @@ -562,18 +575,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): set_all_gather_handle(working_param, handle) else: if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: - dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) + if self._fp8_communication: + all_gather_into_tensor_flat_fp8( + padded_working_param, param_to_gather, pg, fp8_format="e4m3" + ) + else: + dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg) continue try: self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) except RuntimeError: - self.pg_to_tensor_bucket[pg].all_gather(pg) + self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication) self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] if not self._overlap_allgather: for pg, tensor_bucket in self.pg_to_tensor_bucket.items(): if not tensor_bucket.is_empty(): - tensor_bucket.all_gather(pg) + tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication) def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float: r""" diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 7e8c07fdc..f048abdd2 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -179,7 +179,7 @@ def main(): "--plugin", type=str, default="torch_ddp", - choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"], + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel", "torch_fsdp"], help="plugin to use", ) parser.add_argument( @@ -190,6 +190,7 @@ def main(): ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. 
Raise exception if not reached") parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") args = parser.parse_args() if args.model_type == "bert": @@ -214,9 +215,9 @@ def main(): if args.plugin == "torch_ddp_fp16": booster_kwargs["mixed_precision"] = "fp16" if args.plugin.startswith("torch_ddp"): - plugin = TorchDDPPlugin() + plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm) elif args.plugin == "gemini": - plugin = GeminiPlugin(initial_scale=2**5) + plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm) elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) elif args.plugin == "hybrid_parallel": @@ -232,6 +233,18 @@ def main(): zero_stage=1, precision="fp16", initial_scale=1, + fp8_communication=args.use_fp8_comm, + ) + elif args.plugin == "torch_fsdp": + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + + from colossalai.booster.plugin import TorchFSDPPlugin + + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + fp8_communication=args.use_fp8_comm, ) booster = Booster(plugin=plugin, **booster_kwargs) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index ae6d655f4..e9f7203e9 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -188,6 +188,8 @@ def main(): help="only gpt2 now", ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. 
Raise exception if not reached") + parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") args = parser.parse_args() if args.model_type == "gpt2": @@ -210,7 +212,7 @@ def main(): if args.plugin == "torch_ddp_fp16": booster_kwargs["mixed_precision"] = "fp16" if args.plugin.startswith("torch_ddp"): - plugin = TorchDDPPlugin() + plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm) elif args.plugin == "gemini": plugin = GeminiPlugin(initial_scale=2**5) elif args.plugin == "low_level_zero": @@ -226,6 +228,7 @@ def main(): zero_stage=1, precision="fp16", initial_scale=1, + fp8_communication=args.use_fp8_comm, ) booster = Booster(plugin=plugin, **booster_kwargs) diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 093377e7a..bb14378ad 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -104,6 +104,8 @@ def main(): parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") + parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") + parser.add_argument("--use_fp8", action="store_true") parser.add_argument("--overlap_allgather", action="store_true") parser.add_argument( "--sp_mode", @@ -148,6 +150,7 @@ def main(): enable_flash_attention=args.xformers, max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, + use_fp8=args.use_fp8, ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -160,6 +163,7 @@ def main(): max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, + use_fp8=args.use_fp8, ) elif args.plugin == "fsdp": if use_empty_init: @@ -170,6 +174,7 @@ def main(): buffer_dtype=torch.float16, ), param_init_fn=empty_init(), + fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -177,7 +182,8 @@ def main(): param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16, - ) + ), + fp8_communication=args.use_fp8_comm, ) elif args.plugin == "fsdp_cpu": if use_empty_init: @@ -189,6 +195,7 @@ def main(): ), cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), + fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -198,6 +205,7 @@ def main(): buffer_dtype=torch.float16, ), cpu_offload=CPUOffload(offload_params=True), + fp8_communication=args.use_fp8_comm, ) elif args.plugin == "3d": plugin = HybridParallelPlugin( @@ -215,6 +223,7 @@ def main(): precision="bf16", enable_metadata_cache=not args.no_cache, overlap_allgather=args.overlap_allgather, + use_fp8=args.use_fp8, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": @@ -230,6 +239,8 @@ def main(): microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", + overlap_p2p=args.overlap, + use_fp8=args.use_fp8, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -259,7 +270,6 @@ def main(): if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) else nullcontext() ) - init_kwargs = {} if config.model_type == "chatglm": init_kwargs["empty_init"] = False diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 
93a3690fe..3fcf53e18 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -9,7 +9,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package torchrec==0.2.0 contexttimer einops -triton==2.1.0 +triton requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece ninja diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py new file mode 100644 index 000000000..722cbce9a --- /dev/null +++ b/tests/test_fp8/test_all_to_all_single.py @@ -0,0 +1,75 @@ +import torch +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import all_to_all_single_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +@parameterize("async_op", [True, False]) +def check_all2all(shape, dtype, async_op): + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + output = torch.empty_like(x) + output_fp8 = torch.empty_like(x) + origin_hanle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op) + fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op) + if async_op: + origin_hanle.wait() + fp8_handle.wait() + assert_close(output, output_fp8, rtol=0.1, atol=0.1) + + +@parameterize("shape", [(8, 8, 16)]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +@parameterize("async_op", [True, False]) +def check_all2all_uneven(shape, dtype, async_op): + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + input_split_sizes = [3, 3, 1, 1] + if dist.get_rank() in [0, 1]: + output_split_sizes = [3, 3, 3, 3] + else: + output_split_sizes = [1, 1, 1, 1] + output_shape = list(shape) + output_shape[0] = sum(output_split_sizes) + output = torch.empty(output_shape, device=x.device, dtype=x.dtype) + output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype) + origin_hanle = dist.all_to_all_single( + output, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=_get_default_group(), + async_op=async_op, + ) + fp8_handle = all_to_all_single_fp8( + output_fp8, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=_get_default_group(), + async_op=async_op, + ) + if async_op: + origin_hanle.wait() + fp8_handle.wait() + assert_close(output, output_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_all2all() + check_all2all_uneven() + + +@rerun_if_address_is_in_use() +def test_all_to_all_single(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_to_all_single() diff --git a/tests/test_fp8/test_fp8_all_to_all.py b/tests/test_fp8/test_fp8_all_to_all.py new file mode 100644 index 000000000..884aab744 --- /dev/null +++ b/tests/test_fp8/test_fp8_all_to_all.py @@ -0,0 +1,39 @@ +import torch +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import 
get_accelerator +from colossalai.quantization.fp8 import all_to_all_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("shape", [(16, 8, 4)]) +@parameterize("scatter_dim", [0, 1, 2]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def check_4gpu(shape, scatter_dim, dtype, fp8_format): + world_size = dist.get_world_size() + input_tensor = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + input_tensor_list = list(torch.chunk(input_tensor, world_size, scatter_dim)) + input_tensor_list = [x.contiguous() for x in input_tensor_list] + output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list] + output_tensor_list = [torch.empty_like(x) for x in input_tensor_list] + all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format) + dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group()) + assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_to_all(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_to_all() diff --git a/tests/test_fp8/test_fp8_all_to_all_single.py b/tests/test_fp8/test_fp8_all_to_all_single.py new file mode 100644 index 000000000..70765f2d4 --- /dev/null +++ b/tests/test_fp8/test_fp8_all_to_all_single.py @@ -0,0 +1,37 @@ +import torch +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import all_to_all_single_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +dist.all_to_all_single + + +@parameterize("shape", [(4), (8, 7), (4, 8, 16)]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def check_4gpu(shape, dtype, fp8_format): + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + output = torch.empty_like(x) + output_fp8 = torch.empty_like(x) + all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), fp8_format=fp8_format) + dist.all_to_all_single(output, x, group=_get_default_group()) + assert_close(output, output_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_to_all_single(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_to_all_single() diff --git a/tests/test_fp8/test_fp8_allgather_flat.py b/tests/test_fp8/test_fp8_allgather_flat.py new file mode 100644 index 000000000..2d43e5bd5 --- /dev/null +++ b/tests/test_fp8/test_fp8_allgather_flat.py @@ -0,0 +1,43 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 
2), (5,), (4,), (2,)]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +@parameterize("async_op", [True, False]) +def check_4gpu(shape, dtype, async_op): + world_size = dist.get_world_size() + rank = dist.get_rank() + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + flat_padded_x = x.view(-1) + if flat_padded_x.size(0) % world_size != 0: + pad_size = world_size - flat_padded_x.size(0) % world_size + flat_padded_x = F.pad(flat_padded_x, (0, pad_size)) + output = torch.empty_like(flat_padded_x) + chunk = flat_padded_x.chunk(world_size)[rank].clone() + handle = all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group(), async_op=async_op) + if async_op: + handle.wait() + assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_gather_flat(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_gather_flat() diff --git a/tests/test_fp8/test_fp8_allreduce.py b/tests/test_fp8/test_fp8_allreduce.py new file mode 100644 index 000000000..ccc43ed29 --- /dev/null +++ b/tests/test_fp8/test_fp8_allreduce.py @@ -0,0 +1,55 @@ +import torch +import torch.distributed as dist +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import all_reduce_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize( + "shape", + [ + (3, 7), + (4, 7), + (7, 4), + (8, 9), + (3), + (7,), + (8,), + ], +) +@parameterize("dtype", [torch.float16, torch.bfloat16]) +@parameterize("fp8_format", ["e4m3", "e5m2"]) +@parameterize("async_op", [True, False]) +def check_4gpu(shape, dtype, fp8_format, async_op): + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + x_fp8 = x.clone() + origin_handle = dist.all_reduce(x, async_op=async_op) + fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op) + if async_op: + origin_handle.wait() + fp8_handle.wait() + assert_close(x, x_fp8, rtol=0.1, atol=0.1) + + origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op) + fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op) + if async_op: + origin_handle.wait() + fp8_handle.wait() + assert_close(x, x_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_reduce(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_reduce() diff --git a/tests/test_fp8/test_fp8_cast.py b/tests/test_fp8/test_fp8_cast.py new file mode 100644 index 000000000..db9a909e6 --- /dev/null +++ b/tests/test_fp8/test_fp8_cast.py @@ -0,0 +1,26 @@ +import torch +from torch.testing import assert_close + +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline +from colossalai.testing import parameterize + + +@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)]) +@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def test_fp8_cast(shape, dtype, fp8_format): + x = 
torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format) + out = cast_from_fp8(ret, scale_inv, x.dtype) + assert_close(out, x, rtol=0.1, atol=0.1) + + if x.size(-1) % 2 == 0: + inp_dict = {"hidden_states": x.clone()} + cast_to_fp8_pipeline(inp_dict) + cast_from_fp8_pipeline(inp_dict) + assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1) + + +if __name__ == "__main__": + test_fp8_cast() diff --git a/tests/test_fp8/test_fp8_ddp_comm_hook.py b/tests/test_fp8/test_fp8_ddp_comm_hook.py new file mode 100644 index 000000000..9bdfe17a1 --- /dev/null +++ b/tests/test_fp8/test_fp8_ddp_comm_hook.py @@ -0,0 +1,87 @@ +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html + + +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +def demo_basic(rank, world_size): + print(f"Running basic DDP example on rank {rank}.") + setup(rank, world_size) + + def get_grads_after_one_iteration(hook=None): + torch.manual_seed(0) + # create model and move it to GPU with id rank + model = ToyModel().to(rank) + + ddp_model = DDP(model, device_ids=[rank]) + + if hook is not None: + ddp_model.register_comm_hook(None, hook) + + loss_fn = nn.MSELoss() + optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) + + optimizer.zero_grad() + outputs = ddp_model(torch.randn(20, 10)) + labels = torch.randn(20, 5).to(rank) + loss_fn(outputs, labels).backward() + optimizer.step() + + torch.distributed.barrier() + + grad_dict = {} + for name, params in ddp_model.named_parameters(): + grad_dict[name] = params.grad + return grad_dict + + from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync + + grad_dict = get_grads_after_one_iteration() + for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]: + grad_dict_w_hook = get_grads_after_one_iteration(hook) + if dist.get_rank() == 0: + for name in grad_dict: + assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1) + + cleanup() + + +def run_demo(demo_fn, world_size): + mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True) + + +if __name__ == "__main__": + n_gpus = torch.cuda.device_count() + assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + world_size = n_gpus + run_demo(demo_basic, world_size) diff --git a/tests/test_fp8/test_fp8_fsdp_comm_hook.py b/tests/test_fp8/test_fp8_fsdp_comm_hook.py new file mode 100644 index 000000000..3d0660961 --- /dev/null +++ b/tests/test_fp8/test_fp8_fsdp_comm_hook.py @@ -0,0 +1,107 @@ +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +from packaging import version +from torch.distributed.fsdp import FullyShardedDataParallel as 
FSDP +from torch.testing import assert_close + +from colossalai import launch +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html + + +def cleanup(): + dist.destroy_process_group() + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(100, 100) + self.relu = nn.ReLU() + self.net2 = nn.Linear(100, 50) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +@parameterize("mode", ["grad", "params"]) +def run_model(mode): + rank = dist.get_rank() + + from colossalai.quantization.utils import patch_fsdp_params_comm_hook + + patch_fsdp_params_comm_hook() + + def get_grads_after_one_iteration(grad_hook=None, params_hook=None): + torch.manual_seed(0) + # create model and move it to GPU with id rank + model = ToyModel().to(rank) + fsdp_model = FSDP(model) + + if grad_hook is not None: + fsdp_model.register_comm_hook(None, grad_hook) + + if params_hook is not None: + fsdp_model.register_params_comm_hook(None, params_hook) + + loss_fn = nn.MSELoss() + optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001) + + optimizer.zero_grad() + outputs = fsdp_model(torch.randn(20, 100)) + labels = torch.randn(20, 50).to(rank) + loss_fn(outputs, labels).backward() + optimizer.step() + + torch.distributed.barrier() + + grad_dict = {} + for name, params in fsdp_model.named_parameters(): + grad_dict[name] = params.grad + return grad_dict + + from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook + + if mode == "grad": + grad_dict = get_grads_after_one_iteration() + for hook in [ + fp8_compress_fsdp_grad_comm_hook, + ]: + grad_dict_w_hook = get_grads_after_one_iteration(grad_hook=hook) + if dist.get_rank() == 0: + for name in grad_dict: + assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1) + elif mode == "params": + grad_dict = get_grads_after_one_iteration() + for hook in [ + fp8_compress_fsdp_params_comm_hook, + ]: + grad_dict_w_hook = get_grads_after_one_iteration(params_hook=hook) + if dist.get_rank() == 0: + for name in grad_dict: + assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1) + else: + raise NotImplementedError + + +def demo_basic(rank, world_size, port): + print(f"Running basic FSDP example on rank {rank}.") + launch(rank=rank, world_size=world_size, port=port, host="localhost") + run_model() + cleanup() + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("2.2.0"), reason="torch version < 2.2.0.") +@rerun_if_address_is_in_use() +def test_fsdp(): + n_gpus = torch.cuda.device_count() + assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + spawn(demo_basic, n_gpus) + + +if __name__ == "__main__": + test_fsdp() diff --git a/tests/test_fp8/test_fp8_gather.py b/tests/test_fp8/test_fp8_gather.py new file mode 100644 index 000000000..40c2ccb9a --- /dev/null +++ b/tests/test_fp8/test_fp8_gather.py @@ -0,0 +1,52 @@ +import torch +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import gather_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize( + "shape", + [ + (3, 7), + (2, 1), + (1, 2), + (2, 2), + (4, 2), + (5,), + (4,), + 
(2,), + ], +) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +@parameterize("fp8_format", ["e4m3", "e5m2"]) +@parameterize("async_op", [True, False]) +def check_4gpu(shape, dtype, fp8_format, async_op): + world_size = dist.get_world_size() + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + output_list = [torch.empty_like(x) for _ in range(world_size)] + output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)] + fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op) + origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op) + if async_op: + fp8_handle.wait() + origin_hanle.wait() + assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_gather(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_gather() diff --git a/tests/test_fp8/test_fp8_hook.py b/tests/test_fp8/test_fp8_hook.py new file mode 100644 index 000000000..abd5d09e1 --- /dev/null +++ b/tests/test_fp8/test_fp8_hook.py @@ -0,0 +1,50 @@ +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import linear_fp8 +from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import get_current_device + +REPLACED = False +TRIGGERED = False + + +def new_linear_fp8(x, w, bias=None): + global TRIGGERED + TRIGGERED = True + return linear_fp8(x, w, bias) + + +class FP8TestHook(FP8Hook): + def rewrite_op(self, func): + func = super().rewrite_op(func) + if func is linear_fp8: + global REPLACED + REPLACED = True + return new_linear_fp8 + return func + + +D_IN, D_OUT = 16, 32 +B, S = 2, 64 +DTYPE = torch.bfloat16 + + +@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0") +def test_fp8_hook(): + # create tensors + w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE)) + x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True) + w.__class__ = ColoParameter + w.__init__(w, requires_grad=True) + hook = FP8TestHook() + with ColoParamOpHookManager.use_hooks(hook): + o = F.linear(x, w) + assert o.shape == (B, S, D_OUT) + assert REPLACED + assert TRIGGERED diff --git a/tests/test_fp8/test_fp8_linear.py b/tests/test_fp8/test_fp8_linear.py new file mode 100644 index 000000000..d035957f2 --- /dev/null +++ b/tests/test_fp8/test_fp8_linear.py @@ -0,0 +1,45 @@ +import pytest +import torch +import torch.nn.functional as F +from torch.testing import assert_close + +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import linear_fp8 +from colossalai.utils import get_current_device + +D_IN, D_OUT = 16, 32 +B, S = 2, 64 +DTYPE = torch.bfloat16 + + +@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0") +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_batch", [True, False]) +def test_fp8_linear(use_bias: bool, use_batch: bool): + # create tensors + w = torch.rand(D_OUT, D_IN, device=get_current_device(), 
dtype=DTYPE, requires_grad=True) + ref_w = w.clone().detach().requires_grad_() + if use_batch: + x_shape = (B, S, D_IN) + else: + x_shape = (S, D_IN) + x = torch.rand(x_shape, device=get_current_device(), dtype=DTYPE, requires_grad=True) + ref_x = x.clone().detach().requires_grad_() + if use_bias: + bias = torch.rand(D_OUT, device=get_current_device(), dtype=DTYPE, requires_grad=True) + ref_bias = bias.clone().detach().requires_grad_() + else: + bias = None + ref_bias = None + + out = linear_fp8(x, w, bias) + assert out.shape == x_shape[:-1] + (D_OUT,) + out.sum().backward() + ref_out = F.linear(ref_x, ref_w, ref_bias) + ref_out.sum().backward() + + assert_close(out, ref_out, rtol=0.2, atol=0.1) + assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1) + assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1) + if use_bias: + assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1) diff --git a/tests/test_fp8/test_fp8_reduce_scatter.py b/tests/test_fp8/test_fp8_reduce_scatter.py new file mode 100644 index 000000000..e0b558a25 --- /dev/null +++ b/tests/test_fp8/test_fp8_reduce_scatter.py @@ -0,0 +1,44 @@ +import torch +from torch.distributed import reduce_scatter +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import reduce_scatter_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("shape", [(16, 8, 4)]) +@parameterize("scatter_dim", [0, 1, 2]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +@parameterize("fp8_format", ["e4m3", "e5m2"]) +@parameterize("async_op", [True, False]) +def check_4gpu(shape, scatter_dim, dtype, fp8_format, async_op): + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4)) + input_list = [t.contiguous() for t in input_list] + output_origin = torch.empty_like(input_list[0]) + output_fp8 = torch.empty_like(input_list[0]) + origin_handle = reduce_scatter(output_origin, input_list, group=_get_default_group(), async_op=async_op) + fp8_handle = reduce_scatter_fp8( + output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op + ) + if async_op: + origin_handle.wait() + fp8_handle.wait() + assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_reduce_scatter(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_reduce_scatter() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index c376c50e0..368c782fe 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size): return splited_grad -def exam_zero_1_2(): +@parameterize("fp8_communication", [True, False]) +def exam_zero_1_2(fp8_communication: bool): """ In this test, we want to test whether zero stage 1 and 2 deliver the same numerical results despite different communication @@ -73,10 +74,18 @@ def exam_zero_1_2(): zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) zero1_optimizer = LowLevelZeroOptimizer( - zero1_optimizer, 
overlap_communication=True, initial_scale=128, verbose=True + zero1_optimizer, + overlap_communication=True, + initial_scale=128, + verbose=True, + fp8_communication=fp8_communication, ) zero2_optimizer = LowLevelZeroOptimizer( - zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128 + zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=128, + fp8_communication=fp8_communication, ) # create data seed_all(2001 + local_rank) @@ -97,7 +106,10 @@ def exam_zero_1_2(): if g1 is None or g2 is None: assert g1 is None and g2 is None continue - assert torch.allclose(g1, g2) + if fp8_communication: + loose_close(g1, g2, dtype=torch.float16) + else: + assert torch.allclose(g1, g2) # step zero1_optimizer.step() @@ -105,7 +117,8 @@ def exam_zero_1_2(): # check updated param for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): - assert torch.allclose(z1p, z2p) + if not fp8_communication: + assert torch.allclose(z1p, z2p) @parameterize("dtype", [torch.float16, torch.bfloat16])
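Taken together, the changes in this patch expose FP8 purely through plugin-level switches: `fp8_communication` FP8-compresses the collectives wired through `ShardConfig`, the Gemini chunks and the ZeRO optimizer, while `use_fp8` installs `FP8Hook` so that eligible linear layers are rewritten to `linear_fp8`. The sketch below is illustrative only and is not part of the diff; it mirrors the example scripts touched above (bert/gpt finetune, llama benchmark), and `build_model()` / `build_optimizer()` are hypothetical helpers standing in for whatever the application provides.

# Minimal usage sketch (not from the patch): enable the new FP8 switches on the
# HybridParallelPlugin. Assumes the distributed process group has already been
# initialized by the usual launcher (e.g. colossalai.launch_from_torch under torchrun).
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,
    precision="bf16",
    fp8_communication=True,  # FP8-compress TP/PP/ZeRO communication (new ShardConfig flag)
    use_fp8=True,            # register FP8Hook so matmuls go through linear_fp8
)
booster = Booster(plugin=plugin)
# build_model()/build_optimizer() are placeholders for the user's own model and optimizer.
model, optimizer, *_ = booster.boost(build_model(), build_optimizer())

As the updated `test_zero1_2.py` above indicates, results produced with `fp8_communication=True` are only expected to match the full-precision path within fp16-level tolerance (`loose_close` rather than `torch.allclose`), since each FP8 collective introduces quantization error.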