[pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)

* fix: add fallback order option and update 1f1b

* fix: fix communication deadlock in interleaved pp

* test: modify p2p test
Author:    Wenhao Chen
Date:      2024-01-03 11:34:49 +08:00
Committer: GitHub
Parent:    3c0d82b19b
Commit:    d799a3088f

5 changed files with 269 additions and 136 deletions
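
The core of the change is the new send_prior_fallback argument on the fused send/recv helpers: when a combined send-and-receive cannot be coalesced into one operation, the flag decides whether this rank sends or receives first. That ordering is what avoids the deadlock between adjacent pipeline ranks, and the test diff below gives the two sides complementary orders (rank 0 passes send_prior_fallback=False, rank 1 passes True). The following is a minimal sketch of that idea only; exchange and its send_prior parameter are hypothetical names, not ColossalAI's API, and it assumes an already initialized torch.distributed process group.

import torch
import torch.distributed as dist

def exchange(send_tensor: torch.Tensor, recv_buf: torch.Tensor, peer: int, send_prior: bool) -> None:
    # Combined send/recv with an explicit ordering. With blocking point-to-point
    # ops, two ranks that both send first may never post their recvs and deadlock;
    # callers on opposite ends of the link should pass opposite send_prior values
    # so that exactly one side receives first.
    if send_prior:
        dist.send(send_tensor, dst=peer)  # send first, then receive
        dist.recv(recv_buf, src=peer)
    else:
        dist.recv(recv_buf, src=peer)     # receive first, then send
        dist.send(send_tensor, dst=peer)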


@@ -1,5 +1,3 @@
-import warnings
-
 import pytest
 import torch
 import torch.distributed as dist
@@ -33,7 +31,7 @@ def check_p2p_communication():
         for obj in data:
             p2p.send_forward(obj)
         for i in range(len(data)):
-            recv_obj = p2p.send_forward_recv_backward(data[i])
+            recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
             assert recv_obj == data[-(i + 1)]
     elif rank == 1:
         for obj in data:
@@ -48,7 +46,7 @@ def check_p2p_communication():
         for obj in data:
             p2p.send_backward(obj)
         for i in range(len(data)):
-            recv_obj = p2p.send_backward_recv_forward(data[i])
+            recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
             assert recv_obj == data[-(i + 1)]
     elif rank == 0:
         for obj in data:
@@ -59,7 +57,6 @@ def check_p2p_communication():
             p2p.send_forward(data[-(i + 1)])
             assert recv_obj == data[i]
-    warnings.filterwarnings("error")
     tensor_metadata = TensorMetadata(
         key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
     )
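
The trailing hunk also touches the metadata-driven path the test exercises: a TensorMetadata record carrying key, shape, dtype and requires_grad describes a tensor before its payload is transferred, so the receiver can allocate a matching buffer. Below is a minimal sketch of that pattern, assuming only the field names visible above; recv_described_tensor and the plain dist.recv wire format are illustrative, not the library's implementation.

from dataclasses import dataclass

import torch
import torch.distributed as dist

@dataclass
class TensorMetadata:
    key: object
    shape: torch.Size
    dtype: torch.dtype
    requires_grad: bool

def recv_described_tensor(meta: TensorMetadata, src: int) -> torch.Tensor:
    # Allocate an empty buffer from the advertised shape/dtype, then receive
    # the payload into it and restore the autograd flag.
    buf = torch.empty(meta.shape, dtype=meta.dtype)
    dist.recv(buf, src=src)
    buf.requires_grad_(meta.requires_grad)
    return buf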