[hotfix]different overflow status lead to communication stuck. (#1175)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [hotfix]fix some bugs caused by refactored schedule.

* [hotfix]different overflow statu llead to communication stuck.
This commit is contained in:
YuliangLiu0306
2022-06-27 09:53:57 +08:00
committed by GitHub
parent aa7bef73d4
commit e27645376d
3 changed files with 35 additions and 16 deletions

View File

@@ -57,10 +57,14 @@ def process_object_to_send(object_send, scatter_gather_tensors):
if send_split:
object_send = split_tensor_into_1d_equal_chunks(object_send)
return object_send
object_send_list = []
for tensor_send in object_send:
send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
if send_split:
tensor_send = split_tensor_into_1d_equal_chunks(tensor_send)
object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
object_send = tuple(object_send_list)
return object_send
@@ -161,15 +165,17 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
if isinstance(tensor_recv_prev, torch.Tensor):
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
else:
for tensor_recv, tensor_shape in zip(tensor_recv_prev, recv_prev_shape):
tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
for index in range(len(tensor_recv_prev)):
tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view(
recv_prev_shape[index]).requires_grad_()
if recv_next and recv_next_split:
if isinstance(tensor_recv_next, torch.Tensor):
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
else:
for tensor_recv, tensor_shape in zip(tensor_recv_next, recv_next_shape):
tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
for index in range(len(tensor_recv_next)):
tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view(
recv_next_shape[index]).requires_grad_()
return tensor_recv_prev, tensor_recv_next