[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward, as all ops in a p2p op list need to use the same group

* fix: make send and receive atomic
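These two fixes share one constraint. As a minimal sketch (using plain `torch.distributed` primitives, not the PR's actual `P2PComm` code): `batch_isend_irecv` rejects op lists whose ops use different process groups, so a fused send-to-next/recv-from-prev helper cannot batch across the send group and the receive group; each direction instead becomes its own single-group batched call, which is also what makes a send or receive atomic.

```python
import torch.distributed as dist

def _batch_p2p(ops):
    # batch_isend_irecv issues the whole op list as one unit; it raises
    # if the ops do not all share the same process group.
    for req in dist.batch_isend_irecv(ops):
        req.wait()

def send_forward(tensor, next_rank, group):
    # atomic send to the next pipeline stage; a single op trivially
    # satisfies the same-group rule
    _batch_p2p([dist.P2POp(dist.isend, tensor, next_rank, group=group)])

def recv_forward(buffer, prev_rank, group):
    # atomic receive from the previous stage, issued in its own group
    _batch_p2p([dist.P2POp(dist.irecv, buffer, prev_rank, group=group)])
```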

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp
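Both caches exploit the same invariant. A hedged sketch of the idea (class and method names here are hypothetical, not the PR's API): activation shapes and dtypes do not change across microbatches, so the object-list handshake that announces them only needs to run for the first microbatch of a schedule instead of before every send.

```python
import torch

class TensorMetadataCache:
    """Skip the shape/dtype handshake once the peer has already seen it."""

    def __init__(self):
        self._last_sent = None

    @staticmethod
    def describe(tensor: torch.Tensor):
        return (tuple(tensor.shape), tensor.dtype)

    def need_send_meta(self, tensor: torch.Tensor) -> bool:
        meta = self.describe(tensor)
        if meta == self._last_sent:
            return False  # same as the previous microbatch: no handshake needed
        self._last_sent = meta
        return True
```

The receiving side would keep a mirror cache and reuse the last metadata to pre-allocate its buffer whenever no fresh handshake arrives.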

* feat: modify is_xx_stage fn
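The diff below shows only the call sites; here is a sketch of the reworked checks, with the `ignore_chunk` signature inferred from those call sites and the bodies assumed rather than copied from the PR. A boolean replaces the old model-chunk index argument (`-1`/`None`), because under interleaving a rank holds several model chunks, so "first/last stage" must say whether the chunk-level or the rank-level boundary is meant.

```python
class PipelineStageManager:
    # only the fields the two checks need; an assumed layout, not the real class
    def __init__(self, stage, num_stages, model_chunk_id=0, num_model_chunks=1):
        self.stage = stage
        self.num_stages = num_stages
        self.model_chunk_id = model_chunk_id
        self.num_model_chunks = num_model_chunks

    def is_first_stage(self, ignore_chunk: bool = False) -> bool:
        if ignore_chunk:
            return self.stage == 0  # rank-level: first pipeline rank, any chunk
        return self.stage == 0 and self.model_chunk_id == 0

    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        if ignore_chunk:
            return self.stage == self.num_stages - 1
        return (self.stage == self.num_stages - 1
                and self.model_chunk_id == self.num_model_chunks - 1)
```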

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy
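The Bert hunks below show the same pattern being enabled for llama: under interleaving, a rank holds several non-contiguous layer slices instead of one. A rough sketch of how such slices can be computed (the helper name and the chunk-to-stage mapping are assumptions, not the PR's exact get_stage_index):

```python
def stage_layer_ranges(layers_per_stage, stage, num_stages, num_model_chunks):
    # chunk k of this rank acts as virtual stage k * num_stages + stage
    ranges = []
    for chunk in range(num_model_chunks):
        virtual_stage = chunk * num_stages + stage
        start = sum(layers_per_stage[:virtual_stage])
        ranges.append((start, start + layers_per_stage[virtual_stage]))
    return ranges

# e.g. 8 layers, 2 stages, 2 chunks: rank 0 holds layers [0, 2) and [4, 6)
assert stage_layer_ranges([2, 2, 2, 2], 0, 2, 2) == [(0, 2), (4, 6)]
```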

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
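On the last item: NCCL_BUFFSIZE is a standard NCCL environment variable (default 4 MiB) that sizes the staging buffer used for transfers, and it only takes effect if set before the communicators are created. A sketch of the idea, with an illustrative value rather than whatever the plugin actually picks:

```python
import os

# Enlarge NCCL's staging buffer so the large activation tensors exchanged
# between pipeline stages move in fewer chunks. This must run before
# torch.distributed creates any NCCL communicator.
if "NCCL_BUFFSIZE" not in os.environ:
    os.environ["NCCL_BUFFSIZE"] = str(128 * 1024 * 1024)  # illustrative: 128 MiB
```
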
Author: Wenhao Chen
Date: 2023-12-22 10:44:00 +08:00
Committed by: GitHub
Parent: af952673f7
Commit: 4fa689fca1
15 changed files with 728 additions and 446 deletions


@@ -309,11 +309,11 @@ class BertPolicy(Policy):
                 num_model_chunks=stage_manager.num_model_chunks,
                 num_stages=stage_manager.num_stages,
             )
-            if stage_manager.is_first_stage(-1):
+            if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.embeddings)
             for start_idx, end_idx in stage_indices:
                 held_layers.extend(module.encoder.layer[start_idx:end_idx])
-            if stage_manager.is_last_stage(-1):
+            if stage_manager.is_last_stage(ignore_chunk=True):
                 held_layers.append(module.pooler)
         else:
@@ -370,7 +370,7 @@ class BertForPreTrainingPolicy(BertPolicy):
         """Get pipeline layers for current stage"""
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.cls)
         return held_layers
@@ -409,7 +409,7 @@ class BertLMHeadModelPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.cls)
         return held_layers
@@ -447,7 +447,7 @@ class BertForMaskedLMPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.cls)
         return held_layers
@@ -499,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.dropout)
             held_layers.append(self.model.classifier)
         return held_layers
@@ -543,7 +543,7 @@ class BertForTokenClassificationPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.dropout)
             held_layers.append(self.model.classifier)
         return held_layers
@@ -574,7 +574,7 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.cls)
         return held_layers
@@ -617,7 +617,7 @@ class BertForMultipleChoicePolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.dropout)
             held_layers.append(self.model.classifier)
         return held_layers
@@ -647,7 +647,7 @@ class BertForQuestionAnsweringPolicy(BertPolicy):
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_stage(ignore_chunk=True):
             held_layers.append(self.model.qa_outputs)
         return held_layers