mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)
* fix: add fallback order option and update 1f1b
* fix: fix deadlock comm in interleaved pp
* test: modify p2p test
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -33,7 +31,7 @@ def check_p2p_communication():
|
||||
for obj in data:
|
||||
p2p.send_forward(obj)
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.send_forward_recv_backward(data[i])
|
||||
recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
|
||||
assert recv_obj == data[-(i + 1)]
|
||||
elif rank == 1:
|
||||
for obj in data:
|
||||
@@ -48,7 +46,7 @@ def check_p2p_communication():
|
||||
for obj in data:
|
||||
p2p.send_backward(obj)
|
||||
for i in range(len(data)):
|
||||
recv_obj = p2p.send_backward_recv_forward(data[i])
|
||||
recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
|
||||
assert recv_obj == data[-(i + 1)]
|
||||
elif rank == 0:
|
||||
for obj in data:
|
||||
@@ -59,7 +57,6 @@ def check_p2p_communication():
|
||||
p2p.send_forward(data[-(i + 1)])
|
||||
assert recv_obj == data[i]
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
tensor_metadata = TensorMetadata(
|
||||
key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad
|
||||
)
|
||||
|
Reference in New Issue
Block a user