From 5c2ebbfd48ad7590b0278687db2e41ab99e398d4 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Fri, 15 Nov 2024 05:58:56 +0000
Subject: [PATCH] [fix] fix mixtral modeling & policy; update wait handles;
 benchmarking llama hybrid;

---
 colossalai/pipeline/schedule/zero_bubble_pp.py | 13 +++++++++++--
 colossalai/shardformer/modeling/mixtral.py     |  1 -
 colossalai/shardformer/policies/mixtral.py     |  2 --
 examples/language/mixtral/benchmark.py         |  2 +-
 4 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 31e6cfb38..97ad9d5f5 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -46,7 +46,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
-        overlap_p2p: bool = False,
+        overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
         # Not support overlap_p2p so far
@@ -879,12 +879,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
+            # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
                 self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_f(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
@@ -894,6 +898,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     outputs=outputs,
                 )
             elif scheduled_node.type == "B":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_b(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
@@ -907,7 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 3687cfb99..a88db87bc 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -381,7 +381,6 @@ class MixtralPipelineForwards:
                     output_router_logits,
                     use_cache,
                 )
-
                 hidden_states = layer_outputs[0]
 
                 if use_cache:
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 54cd612f9..fab437c01 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -214,7 +214,6 @@ class MixtralPolicy(Policy):
                         suffix="block_sparse_moe.gate",
                         target_module=LinearWithGradAccum,
                         kwargs={
-                            "gather_output": True,
                             "fp8_communication": self.shard_config.fp8_communication,
                             "use_zbv": use_zbv,
                         },
@@ -414,7 +413,6 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                     suffix="lm_head",
                     target_module=LinearWithGradAccum,
                     kwargs=dict(
-                        gather_output=True,
                         fp8_communication=self.shard_config.fp8_communication,
                         use_zbv=use_zbv,
                     ),
diff --git a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py
index 0334bd81c..dbffd0c2a 100644
--- a/examples/language/mixtral/benchmark.py
+++ b/examples/language/mixtral/benchmark.py
@@ -122,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            # "pp_style": "interleaved",
+            "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
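Note on the zero_bubble_pp.py hunks: with overlap_p2p now defaulting to True,
each communication node launches async P2P and appends the returned handles to
self.wait_handles, and every handle must complete before the compute that
depends on it runs. Each entry is itself a list of work objects (a batched
isend/irecv in torch.distributed returns one handle per op), hence the nested
wait loop. Below is a minimal sketch of the drain-before-compute pattern, not
part of the patch, assuming torch.distributed-style handles; names other than
wait() are illustrative rather than the actual ColossalAI API:

def drain(wait_handles):
    # Each entry is the handle list from one batched isend/irecv, hence the
    # nested loop; wait() blocks until that async send/recv has completed.
    for handle_group in wait_handles:
        for work in handle_group:
            work.wait()

def run_schedule(schedule, communication_map, schedule_f, schedule_b):
    wait_handles = []
    for node in schedule:
        if node.type in {"SEND_FORWARD", "RECV_FORWARD", "SEND_BACKWARD", "RECV_BACKWARD"}:
            # Launch async P2P and keep the handles instead of blocking here,
            # so communication overlaps with whatever compute comes next.
            wait_handles.append(communication_map[node.type](node.chunk))
        elif node.type == "F":
            drain(wait_handles)  # activations must have arrived before forward
            schedule_f(node)
        elif node.type == "B":
            drain(wait_handles)  # incoming grads must have arrived before backward
            schedule_b(node)
    drain(wait_handles)  # leave no P2P in flight when the step returns

The trailing drain (added where the old debug print was removed) ensures the
step does not return with outstanding recvs, so the next iteration cannot race
against leftover communication.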
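Note on the two gather_output deletions in policies/mixtral.py: gather_output
is an option of the tensor-parallel column linear (whether to all-gather the
output shards across ranks), while LinearWithGradAccum is the plain, unsharded
replacement, so it has no shards to gather and the stale kwarg breaks module
replacement. A toy illustration with hypothetical signatures (the real classes
live in colossalai.shardformer.layer):

class Linear1D_Col:
    # Column-parallel linear: each rank holds an output shard, so the caller
    # may ask for the shards to be all-gathered after the matmul.
    def __init__(self, gather_output=False, fp8_communication=False, use_zbv=False):
        ...

class LinearWithGradAccum:
    # Full (unsharded) linear with gradient accumulation; nothing to gather.
    def __init__(self, fp8_communication=False, use_zbv=False):
        ...

LinearWithGradAccum(fp8_communication=False, use_zbv=True)  # fine
# LinearWithGradAccum(gather_output=True)  -> TypeError: unexpected keyword argument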