mirror of https://github.com/hpcaitech/ColossalAI.git
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests
* fix: remove send_forward_recv_forward as p2p op list needs to use the same group
* fix: make send and receive atomic
* feat: update P2PComm fn
* feat: add metadata cache in 1f1b
* feat: add metadata cache in interleaved pp
* feat: modify is_xx_stage fn
* revert: add _broadcast_object_list
* feat: add interleaved pp in llama policy
* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
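Two of the items above concern how paired sends and receives are issued: send_forward_recv_forward was dropped because every operation in a single P2P op list must use the same process group, and a matching send/receive pair has to be posted atomically, in one batch, so two adjacent stages cannot both block in a send. A minimal sketch of that batched pattern using only torch.distributed APIs (the helper name exchange_with_peer is invented for illustration and is not ColossalAI's actual implementation):

import torch
import torch.distributed as dist


def exchange_with_peer(send_tensor: torch.Tensor, recv_buf: torch.Tensor, peer: int, group=None):
    # Post the send and the receive together: batch_isend_irecv launches all
    # ops before waiting on any of them, so neither rank can deadlock waiting
    # for the other to finish its send first. All P2POps in one batch must
    # target the same process group.
    ops = [
        dist.P2POp(dist.isend, send_tensor, peer, group=group),
        dist.P2POp(dist.irecv, recv_buf, peer, group=group),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_buf

The test diff below covers both this atomic exchange and the new metadata cache path.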
@@ -1,47 +1,80 @@
+import warnings
+
 import pytest
 import torch
 import torch.distributed as dist
 
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 
+WORLD_SIZE = 2
+
 
 def check_p2p_communication():
-    pg_mesh = ProcessGroupMesh(2)
+    pg_mesh = ProcessGroupMesh(WORLD_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, 0)
     p2p = PipelineP2PCommunication(stage_manager)
 
     rank = dist.get_rank()
 
     tensor = torch.ones(1, device=get_current_device())
+    data = [
+        "tensor",
+        tensor,
+        [tensor],
+        {"tensor": tensor},
+    ]
 
     if rank == 0:
-        p2p.send_forward(tensor)
-        p2p.send_forward([tensor])
-        p2p.send_forward({"tensor": tensor})
-    else:
-        obj = p2p.recv_forward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_forward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_forward_recv_backward(data[i])
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 1:
+        for obj in data:
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            p2p.send_backward(data[-(i + 1)])
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == data[i]
 
     if rank == 1:
-        p2p.send_backward(tensor)
-        p2p.send_backward([tensor])
-        p2p.send_backward({"tensor": tensor})
-    else:
-        obj = p2p.recv_backward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_backward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_backward_recv_forward(data[i])
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 0:
+        for obj in data:
+            recv_obj = p2p.recv_backward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            recv_obj = p2p.recv_backward()
+            p2p.send_forward(data[-(i + 1)])
+            assert recv_obj == data[i]
+
+    warnings.filterwarnings("error")
+    tensor_metadata = TensorMetadata(
+        key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
+    )
+    comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata)
+    if rank == 0:
+        recv_obj = p2p.send_forward_recv_backward(
+            tensor,
+            send_metadata=False,
+            metadata_recv=comm_metadata,
+        )
+        assert recv_obj == tensor
+    elif rank == 1:
+        recv_obj = p2p.recv_forward(metadata_recv=comm_metadata)
+        assert recv_obj == tensor
+        p2p.send_backward(tensor, send_metadata=False)
 
 
 def run_dist(rank, world_size, port):
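The final block of the hunk above exercises the new metadata cache path: warnings.filterwarnings("error") turns warnings into errors, so the test fails loudly if the P2P layer still performs a metadata exchange after the caller passed send_metadata=False and a pre-built metadata_recv. In a running schedule this is what lets every microbatch after the first skip the metadata round-trip. A rough sketch of that caching idea (P2PMetadata, TensorMetadata, P2PDataType, and the keyword arguments come from the diff above; the wrapper class is invented here and is not the schedule's real cache):

from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, TensorMetadata


class CachedRecvForward:
    # Receive forward activations, reusing cached metadata after microbatch 0.
    def __init__(self, p2p):
        self.p2p = p2p
        self.metadata = None

    def recv(self):
        if self.metadata is None:
            # First microbatch: metadata travels with the payload; cache it.
            obj = self.p2p.recv_forward()
            self.metadata = P2PMetadata(
                data_type=P2PDataType.Tensor,
                content=TensorMetadata(
                    key=None, shape=obj.shape, dtype=obj.dtype, requires_grad=obj.requires_grad
                ),
            )
            return obj
        # Later microbatches: only raw tensor data crosses the wire.
        return self.p2p.recv_forward(metadata_recv=self.metadata)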
@@ -52,7 +85,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_pipeline_p2p():
-    spawn(run_dist, 2)
+    spawn(run_dist, WORLD_SIZE)
 
 
 if __name__ == "__main__":
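One commit-message item, setting NCCL_BUFFSIZE in HybridParallelPlugin, does not appear in this test file. Per the message, the plugin enlarges NCCL's per-connection buffer for pipeline traffic; a sketch of how such a default could be applied (the 128 MiB value and the use of setdefault are assumptions, not taken from the commit):

import os

# NCCL reads NCCL_BUFFSIZE (in bytes) when a communicator is created, so it
# must be set before the first collective or p2p call initializes NCCL.
os.environ.setdefault("NCCL_BUFFSIZE", str(128 * 1024 * 1024))  # assumed size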