Mirror of https://github.com/hpcaitech/ColossalAI.git
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests
* fix: remove send_forward_recv_forward as p2p op list need to use the same group
* fix: make send and receive atomic
* feat: update P2PComm fn
* feat: add metadata cache in 1f1b
* feat: add metadata cache in interleaved pp
* feat: modify is_xx_stage fn
* revert: add _broadcast_object_list
* feat: add interleaved pp in llama policy
* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
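The "metadata cache" items above refer to avoiding a metadata exchange on every pipeline p2p send/recv: the receiver needs the incoming tensor's shape before it can allocate a buffer, and with fixed micro-batch shapes that information only has to be exchanged once. Below is a minimal, hypothetical sketch of the idea using plain torch.distributed calls; the names (`send_with_cached_meta`, `recv_with_cached_meta`, `_pack_meta`) are illustrative and are not ColossalAI's actual P2PComm API.

```python
import torch
import torch.distributed as dist

# Hypothetical sketch of a p2p metadata cache. It assumes every micro-batch
# exchanged between two adjacent pipeline stages has the same shape, so the
# shape header only has to be communicated on the first send/recv.
_META_LEN = 8                 # fixed-size header: [ndim, dim0, ..., dim6]
_sent_meta_to = set()         # destination ranks that already received metadata
_recv_meta_cache = {}         # source rank -> cached shape tuple


def _pack_meta(t: torch.Tensor) -> torch.Tensor:
    meta = torch.zeros(_META_LEN, dtype=torch.long, device=t.device)
    meta[0] = t.dim()
    meta[1:1 + t.dim()] = torch.tensor(t.shape, dtype=torch.long, device=t.device)
    return meta


def send_with_cached_meta(t: torch.Tensor, dst: int) -> None:
    if dst not in _sent_meta_to:
        dist.send(_pack_meta(t), dst=dst)   # shape header is sent only once
        _sent_meta_to.add(dst)
    dist.send(t.contiguous(), dst=dst)


def recv_with_cached_meta(src: int, device, dtype=torch.float32) -> torch.Tensor:
    if src not in _recv_meta_cache:
        meta = torch.zeros(_META_LEN, dtype=torch.long, device=device)
        dist.recv(meta, src=src)
        ndim = int(meta[0].item())
        _recv_meta_cache[src] = tuple(int(d) for d in meta[1:1 + ndim])
    buf = torch.empty(_recv_meta_cache[src], dtype=dtype, device=device)
    dist.recv(buf, src=src)
    return buf
```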
@@ -203,7 +203,7 @@ def check_output_hidden_state(
 ):
     org_hidden_state = org_output.last_hidden_state
 
-    if stage_manager and stage_manager.is_last_stage():
+    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
         sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
     else:
         sharded_hidden_state = sharded_output.last_hidden_state
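The `ignore_chunk=True` argument matters under interleaved pipeline parallelism, where each rank holds several model chunks: `is_last_stage()` can then mean either "the last chunk on the last rank" or simply "the last pipeline rank". The test only cares which process ends up with the final hidden states, so it checks the physical stage. The toy class below is a simplified, hypothetical illustration of that distinction, not ColossalAI's actual stage manager:

```python
class ToyStageManager:
    """Hypothetical stand-in for a pipeline stage manager with interleaved chunks."""

    def __init__(self, stage: int, num_stages: int, num_model_chunks: int = 1):
        self.stage = stage                        # pipeline rank of this process
        self.num_stages = num_stages              # total number of pipeline ranks
        self.num_model_chunks = num_model_chunks  # > 1 means interleaved schedule
        self.current_chunk = 0                    # chunk currently being executed

    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        on_last_rank = self.stage == self.num_stages - 1
        if ignore_chunk or self.num_model_chunks == 1:
            # Only the physical pipeline rank matters, e.g. when deciding which
            # process holds the model's final hidden states.
            return on_last_rank
        # Otherwise the "virtual" last stage is the last chunk on the last rank.
        return on_last_rank and self.current_chunk == self.num_model_chunks - 1
```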
@@ -229,6 +229,10 @@ def check_weight(
         org_weight = getattr_(org_model, suffix).weight
         sharded_weight = getattr_(sharded_model, suffix).weight
 
+        # skip if layer is not held by this process
+        if sharded_weight is None:
+            continue
+
         if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
             sharded_weight_list = [
                 torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
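The zero-filled buffers at the end of this hunk are the usual setup for an all_gather over the tensor-parallel group, after which the shards are concatenated and compared against the unsharded weight. A minimal sketch of that pattern follows; `gather_and_compare` is an illustrative name, not a helper from the repository:

```python
import torch
import torch.distributed as dist


def gather_and_compare(org_weight, sharded_weight, tp_group, dim=0, atol=1e-5, rtol=1e-3):
    """Illustrative helper: gather a tensor-parallel shard from every rank in
    tp_group, rebuild the full weight, and check it matches the original."""
    world_size = dist.get_world_size(tp_group)
    # One zero buffer per rank, mirroring the list comprehension in the diff above.
    buffers = [torch.zeros_like(sharded_weight) for _ in range(world_size)]
    dist.all_gather(buffers, sharded_weight.contiguous(), group=tp_group)
    # Re-assemble along the sharded dimension and compare with a tolerance.
    full_weight = torch.cat(buffers, dim=dim)
    assert torch.allclose(org_weight, full_weight, atol=atol, rtol=rtol)
```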