[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward as p2p op list needs to use the same group

* fix: make send and receive atomic
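
Background on the two fixes above: `torch.distributed.batch_isend_irecv` requires every `P2POp` in one call to use the same process group, and posting the matching send/recv pair in a single batch keeps the exchange atomic (neither rank blocks waiting for the other op to be posted). A minimal sketch of the pattern, with illustrative names rather than the actual `P2PComm` code:

```python
import torch.distributed as dist

def send_recv_atomic(send_tensor, recv_tensor, peer, group):
    # Both ops must target the SAME group: mixing groups inside one
    # batch_isend_irecv call is not supported and can deadlock.
    ops = [
        dist.P2POp(dist.isend, send_tensor, peer, group=group),
        dist.P2POp(dist.irecv, recv_tensor, peer, group=group),
    ]
    # One batch posts the send and the receive together, making the
    # pair atomic with respect to other pipeline traffic.
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
```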

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp
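
The metadata cache in both schedules rests on the observation that every micro-batch in a pipeline step has the same tensor layout, so shape/dtype information only needs to cross the wire once. A rough sketch of the idea (`send_object_list`/`recv_object_list` are real torch.distributed APIs; the class and its caching fields are invented for illustration):

```python
import torch
import torch.distributed as dist

class CachedP2P:
    """Illustrative only: send tensor metadata once, then raw tensors."""

    def __init__(self, group):
        self.group = group
        self.metadata_sent = False  # becomes True after the first micro-batch
        self.recv_meta = None       # (shape, dtype) cached from the first recv

    def send(self, tensor, dst):
        if not self.metadata_sent:
            # First micro-batch: ship shape/dtype so the peer can
            # pre-allocate matching buffers for all later micro-batches.
            dist.send_object_list([(tuple(tensor.shape), tensor.dtype)],
                                  dst=dst, group=self.group)
            self.metadata_sent = True
        dist.send(tensor, dst=dst, group=self.group)

    def recv(self, src, device="cuda"):
        if self.recv_meta is None:
            meta = [None]
            dist.recv_object_list(meta, src=src, group=self.group)
            self.recv_meta = meta[0]
        shape, dtype = self.recv_meta
        buf = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(buf, src=src, group=self.group)
        return buf
```

In an interleaved schedule the cache would presumably need to be keyed per model chunk and direction, since different chunks can carry differently shaped activations.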

* feat: modify is_xx_stage fn
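
With interleaved pipeline parallelism each rank owns several model chunks, so "first/last stage" becomes ambiguous: it can mean this rank's local chunk boundary or the global pipeline boundary. The `is_last_stage(ignore_chunk=True)` call in the diff below selects the global meaning. A hedged sketch of how such a check might be structured (the field names are assumptions, not the actual stage manager):

```python
class StageManagerSketch:
    """Illustrative stand-in for a pipeline stage manager."""

    def __init__(self, stage, num_stages, model_chunk_id=0, num_model_chunks=1):
        self.stage = stage                        # this rank's pipeline index
        self.num_stages = num_stages              # pipeline parallel size
        self.model_chunk_id = model_chunk_id      # current chunk (interleaved pp)
        self.num_model_chunks = num_model_chunks  # chunks per rank

    def is_last_stage(self, ignore_chunk=False):
        if ignore_chunk or self.num_model_chunks == 1:
            # Global view: only the final pipeline rank is "last".
            return self.stage == self.num_stages - 1
        # Chunk-aware view: final rank AND final chunk on that rank.
        return (self.stage == self.num_stages - 1
                and self.model_chunk_id == self.num_model_chunks - 1)
```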

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy
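
Supporting interleaved pp in the llama policy mostly comes down to dealing decoder layers out to (stage, chunk) pairs instead of contiguous per-stage blocks. A toy sketch of round-robin layer assignment, under the stated assumption of even divisibility (not the actual policy code):

```python
def layers_for_stage(num_layers, num_stages, num_chunks, stage):
    """Return, per chunk, the layer indices owned by one pipeline rank.

    Block i of consecutive layers goes to rank i % num_stages as
    chunk i // num_stages, so execution interleaves across ranks.
    """
    assert num_layers % (num_stages * num_chunks) == 0
    block = num_layers // (num_stages * num_chunks)
    return [
        list(range((chunk * num_stages + stage) * block,
                   (chunk * num_stages + stage) * block + block))
        for chunk in range(num_chunks)
    ]

# 8 layers, 2 stages, 2 chunks:
#   stage 0 -> [[0, 1], [4, 5]]
#   stage 1 -> [[2, 3], [6, 7]]
```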

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
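
NCCL reads `NCCL_BUFFSIZE` once, when its first communicator is created, so a plugin has to set it before any process group or collective is initialized; the default of 4 MiB can throttle the large point-to-point transfers between pipeline stages. A hedged sketch of the pattern (the function name and the 128 MiB value are illustrative, not necessarily what the plugin uses):

```python
import os

def _set_nccl_buffsize(num_bytes=128 * 1024 * 1024):
    # Must run before dist.init_process_group / the first NCCL call;
    # respect an explicit value the user already exported.
    if "NCCL_BUFFSIZE" not in os.environ:
        os.environ["NCCL_BUFFSIZE"] = str(num_bytes)
```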

Author:       Wenhao Chen
Authored:     2023-12-22 10:44:00 +08:00
Committed by: GitHub
Parent:       af952673f7
Commit:       4fa689fca1

15 changed files with 728 additions and 446 deletions


@@ -203,7 +203,7 @@ def check_output_hidden_state(
 ):
     org_hidden_state = org_output.last_hidden_state

-    if stage_manager and stage_manager.is_last_stage():
+    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
         sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
     else:
         sharded_hidden_state = sharded_output.last_hidden_state
@@ -229,6 +229,10 @@ def check_weight(
         org_weight = getattr_(org_model, suffix).weight
         sharded_weight = getattr_(sharded_model, suffix).weight

+        # skip if layer is not held by this process
+        if sharded_weight is None:
+            continue
+
         if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
             sharded_weight_list = [
                 torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))