From a11b4b50a78e0f7754c406c3a982863fee71ac58 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 14 Oct 2024 07:12:14 +0000 Subject: [PATCH] [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy; --- .../booster/plugin/hybrid_parallel_plugin.py | 1 + .../plugin/moe_hybrid_parallel_plugin.py | 1 + colossalai/pipeline/weight_grad_store.py | 70 ------------------- colossalai/shardformer/policies/llama.py | 42 +++++++++-- colossalai/shardformer/policies/mixtral.py | 40 ++++++++--- colossalai/shardformer/shard/shard_config.py | 1 + 6 files changed, 70 insertions(+), 85 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 5561533e1..673701017 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1217,6 +1217,7 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config=gradient_checkpoint_config, fp8_communication=fp8_communication, inner_ring_size=inner_ring_size, + use_zbv=(pp_style == "zbv"), ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 8b62a1e2b..b7e65c6a2 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -373,6 +373,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, fp8_communication=fp8_communication, + use_zbv=(pp_style == "zbv"), ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index 5d7f76649..12963350f 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -34,73 +34,3 @@ class WeightGradStore: weight.grad = grad_weight else: raise Exception("Pop empty queue.") - - # @classmethod - # def clear(cls, model, chunk=0): - # weight_grad_tasks = [] - # while cls.weight_grad_queue[chunk].qsize() > 0: - # stored_grads = cls.weight_grad_queue[chunk].get() - # if len(weight_grad_tasks) == 0: - # for _ in stored_grads: - # weight_grad_tasks.append([]) - # else: - # assert len(weight_grad_tasks) == len(stored_grads) - # for i, task in enumerate(stored_grads): - # weight_grad_tasks[i].append(task) - # weight_params = [] - # handles = [] - # if get_args().overlap_grad_reduce: - # handles += model.async_reduce_grad() - - # output_layer_weight = None - # if parallel_state.is_pipeline_last_stage(): - # assert len(weight_grad_tasks) > 0 - # output_layer_grads = weight_grad_tasks[0] - # for j in range(len(output_layer_grads)): - # total_input, grad_output, weight, func = output_layer_grads[j] - # if output_layer_weight is None: - # output_layer_weight = weight - # assert output_layer_weight is weight - # func(total_input, grad_output, weight.main_grad) - # output_layer_grads[j] = None # release memory - # weight_grad_tasks = weight_grad_tasks[1:] - # if get_args().overlap_grad_reduce: - # handles += model.async_reduce_grad(output_layer_weight) - - # if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage(): - # model_module = get_attr_wrapped_model(model, 'pre_process', return_model_obj=True) - # if model_module.share_embeddings_and_output_weights: - # # if 
share_embeddings_and_output_weights, wait all-reduce for embeddings - # for handle in handles: - # if handle is not None: - # handle.wait() - # handles = [] - - # config = get_model_config(model) - # # Do async all-reduce for embedding grads firstly, so that the rank 0 won't - # # be blocked - # embedding_handles = _allreduce_embedding_grads([model], config, async_op=True) - # handles += embedding_handles - - # for i in range(len(weight_grad_tasks)): - # tasks = weight_grad_tasks[i] - # param = None - # for j in range(len(tasks)): - # total_input, grad_output, weight, func = tasks[j] - # if param is None: - # param = weight - # assert param is weight - # assert not (weight is output_layer_weight) - # func(total_input, grad_output, weight.main_grad) - # tasks[j] = None # release memory - # weight_params.append(param) - # if get_args().overlap_grad_reduce: - # # All-reduce param grad here - # handles += model.async_reduce_grad(param) - # weight_grad_tasks[i] = None # release memory - - # # timers('wait_all_reduce', log_level=1).start(barrier=False) - # for handle in embedding_handles: - # if handle is not None: - # handle.wait() - # # timers('wait_all_reduce').stop() diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index f9897b8b7..5d48a16c3 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -126,37 +126,65 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - 
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + seq_parallel_mode=sp_mode, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ), ], ) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 9d8d2b54b..705f2b19f 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -124,27 +124,43 @@ class MixtralPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": self.shard_config.use_zbv, + }, ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": self.shard_config.use_zbv, + }, ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": self.shard_config.use_zbv, + }, ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs={"fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": self.shard_config.use_zbv, + }, ), SubModuleReplacementDescription( suffix="block_sparse_moe.gate", target_module=Linear1D_Col, - kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, + kwargs={ + "gather_output": True, + "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": self.shard_config.use_zbv, + }, ), ], ) @@ -322,9 +338,13 @@ class MixtralForCausalLMPolicy(MixtralPolicy): SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ) - ] + ], ) } policy.update(new_item) @@ -380,7 +400,11 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): SubModuleReplacementDescription( suffix="score", target_module=Linear1D_Col, - kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=self.shard_config.use_zbv, + ), ) ] ) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 1219119bb..33e93fa51 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -49,6 +49,7 @@ class ShardConfig: make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + use_zbv: bool = False # For ring attention inner_ring_size: Optional[int] = None
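
Note (not part of the patch): the change is mechanical, but the data flow is worth
spelling out. HybridParallelPlugin and MoeHybridParallelPlugin now derive use_zbv
from pp_style == "zbv" and hand it to ShardConfig, and every Linear1D_Col /
Linear1D_Row replacement in the Llama and Mixtral policies forwards
shard_config.use_zbv to the sharded layer. The sketch below is a hypothetical,
standalone illustration of that policy-side pattern: the helper function and the
q_proj/o_proj selection are illustrative only, not code from this patch, and it
assumes the Linear1D_Col / Linear1D_Row constructors already accept use_zbv, which
this patch relies on.

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription
from colossalai.shardformer.shard.shard_config import ShardConfig


def zbv_attention_replacements(shard_config: ShardConfig, sp_mode=None):
    """Hypothetical helper mirroring how the updated Llama policy builds its
    attention projection replacements; the real policies inline these blocks."""
    common = dict(
        seq_parallel_mode=sp_mode,
        fp8_communication=shard_config.fp8_communication,
        use_zbv=shard_config.use_zbv,  # new ShardConfig field added by this patch (defaults to False)
    )
    return [
        SubModuleReplacementDescription(
            suffix="self_attn.q_proj",  # column-parallel projection
            target_module=Linear1D_Col,
            kwargs=dict(common),
        ),
        SubModuleReplacementDescription(
            suffix="self_attn.o_proj",  # row-parallel projection
            target_module=Linear1D_Row,
            kwargs=dict(common),
        ),
    ]

End to end, a user never sets use_zbv directly: passing pp_style="zbv" to the plugin
(first two hunks) is what flips the flag, and because ShardConfig declares
use_zbv: bool = False, all other pipeline styles are unaffected.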