[pipeline] A more general _communicate in p2p (#5062)

* A more general _communicate

* feat: finish tree_flatten version p2p

* fix: update p2p api calls

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
Elsa Granger
2024-01-08 15:37:27 +08:00
committed by GitHub
parent 7bc6969ce6
commit d565df3821
4 changed files with 104 additions and 136 deletions

View File

@@ -4,7 +4,7 @@ import torch.distributed as dist
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
@@ -57,19 +57,15 @@ def check_p2p_communication():
p2p.send_forward(data[-(i + 1)])
assert recv_obj == data[i]
tensor_metadata = TensorMetadata(
key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
)
comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
if rank == 0:
recv_obj = p2p.send_forward_recv_backward(
tensor,
send_metadata=False,
metadata_recv=comm_metadata,
metadata_recv=create_send_metadata(tensor),
)
assert recv_obj == tensor
elif rank == 1:
recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
assert recv_obj == tensor
p2p.send_backward(tensor, send_metadata=False)