Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-27 04:33:04 +00:00
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests
* fix: remove send_forward_recv_forward, since all ops in a p2p op list must use the same group
* fix: make send and receive atomic
* feat: update P2PComm fn
* feat: add metadata cache in 1f1b
* feat: add metadata cache in interleaved pp
* feat: modify is_xx_stage fn
* revert: add _broadcast_object_list
* feat: add interleaved pp in llama policy
* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
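For context on the `is_xx_stage` change: the diff below replaces positional chunk arguments such as `is_first_stage(-1)` with an explicit `ignore_chunk=True` keyword. A minimal sketch of the intended semantics, assuming a simplified stage manager (the class body is illustrative, not ColossalAI's exact implementation):

```python
class PipelineStageManager:
    """Illustrative stand-in for the pipeline stage manager in the diff."""

    def __init__(self, stage: int, num_stages: int, num_model_chunks: int = 1):
        self.stage = stage
        self.num_stages = num_stages
        self.num_model_chunks = num_model_chunks
        self.is_interleave = num_model_chunks > 1
        self.model_chunk_id = 0  # updated by the schedule as chunks execute

    def is_first_stage(self, ignore_chunk: bool = False) -> bool:
        if self.is_interleave and not ignore_chunk:
            # first *virtual* stage: rank 0 while running its first chunk
            return self.stage == 0 and self.model_chunk_id == 0
        # ignore_chunk=True: the physically first rank, whatever chunk runs
        return self.stage == 0

    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        if self.is_interleave and not ignore_chunk:
            # last *virtual* stage: last rank while running its last chunk
            return (self.stage == self.num_stages - 1
                    and self.model_chunk_id == self.num_model_chunks - 1)
        return self.stage == self.num_stages - 1
```

With interleaving, each rank owns several model chunks, so `ignore_chunk=True` is what the policies below need: the embeddings and output heads live on the physically first and last ranks, not on a particular virtual stage.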
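The metadata cache added to the 1F1B and interleaved schedules rests on the observation that the tensor shapes and dtypes exchanged between stages repeat across microbatches, so the shape/dtype handshake only needs to happen once. A hedged sketch of the idea (the class and the `send_obj`/`recv_obj` primitives are placeholders, not the actual P2P API):

```python
import torch


class P2PMetadataCache:
    """Skip the shape/dtype handshake after the first microbatch."""

    def __init__(self):
        self.send_metadata_sent = False
        self.recv_metadata = None

    def send(self, tensor: torch.Tensor, send_obj, send_tensor):
        # send_obj/send_tensor stand in for object/tensor p2p primitives
        if not self.send_metadata_sent:
            send_obj((tuple(tensor.shape), tensor.dtype))  # one-time handshake
            self.send_metadata_sent = True
        send_tensor(tensor)

    def recv(self, recv_obj, recv_tensor):
        if self.recv_metadata is None:
            self.recv_metadata = recv_obj()  # one-time handshake
        shape, dtype = self.recv_metadata
        buffer = torch.empty(shape, dtype=dtype)
        recv_tensor(buffer)  # receive directly into a preallocated buffer
        return buffer
```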
@@ -309,11 +309,11 @@ class BertPolicy(Policy):
                 num_model_chunks=stage_manager.num_model_chunks,
                 num_stages=stage_manager.num_stages,
             )
-            if stage_manager.is_first_stage(-1):
+            if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.embeddings)
             for start_idx, end_idx in stage_indices:
                 held_layers.extend(module.encoder.layer[start_idx:end_idx])
-            if stage_manager.is_last_stage(-1):
+            if stage_manager.is_last_stage(ignore_chunk=True):
                 held_layers.append(module.pooler)
 
         else:
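The `stage_indices` loop in this hunk is what lets one stage hold several layer ranges at once under interleaved pipelining. A hedged sketch of an even interleaved split (the helper name and the even-divisibility assumption are illustrative):

```python
def get_stage_indices(num_layers: int, stage: int, num_stages: int,
                      num_model_chunks: int) -> list[tuple[int, int]]:
    """Return one (start, end) layer range per model chunk held by this stage."""
    layers_per_chunk = num_layers // (num_stages * num_model_chunks)
    indices = []
    for chunk in range(num_model_chunks):
        # chunk 0 of every stage comes first, then chunk 1, and so on
        start = (chunk * num_stages + stage) * layers_per_chunk
        indices.append((start, start + layers_per_chunk))
    return indices


# e.g. 8 layers, 2 stages, 2 chunks:
print(get_stage_indices(8, 0, 2, 2))  # stage 0 holds [(0, 2), (4, 6)]
print(get_stage_indices(8, 1, 2, 2))  # stage 1 holds [(2, 4), (6, 8)]
```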
@@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy):
         """Get pipeline layers for current stage"""
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.cls)
 
         return held_layers
@@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.cls)
         return held_layers
 
@@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.cls)
         return held_layers
 
@@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.dropout)
             held_layers.append(self.model.classifier)
         return held_layers
@@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.dropout)
             held_layers.append(self.model.classifier)
         return held_layers
@@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.cls)
         return held_layers
 
@@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.dropout)
             held_layers.append(self.model.classifier)
         return held_layers
@@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.qa_outputs)
         return held_layers
 
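Finally, the commit message mentions setting NCCL_BUFFSIZE in HybridParallelPlugin. NCCL reads this environment variable when a communicator is created, so it must be set before process-group initialization; a hedged sketch (the 128 MB value is illustrative, not necessarily what the plugin uses):

```python
import os

# NCCL_BUFFSIZE sets NCCL's internal communication buffer size in bytes
# (default 4 MB). A larger buffer can help large p2p activation transfers
# between pipeline stages; set it before torch.distributed initializes NCCL.
os.environ.setdefault("NCCL_BUFFSIZE", str(128 * 1024 * 1024))  # 128 MB, illustrative
```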