[hotfix]different overflow status lead to communication stuck. (#1175)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [hotfix]fix some bugs caused by refactored schedule.
* [hotfix]different overflow status lead to communication stuck.
@@ -57,10 +57,14 @@ def process_object_to_send(object_send, scatter_gather_tensors):
         if send_split:
             object_send = split_tensor_into_1d_equal_chunks(object_send)
         return object_send
 
+    object_send_list = []
     for tensor_send in object_send:
         send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
         if send_split:
-            tensor_send = split_tensor_into_1d_equal_chunks(tensor_send)
+            object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
+        else:
+            object_send_list.append(tensor_send)
+    object_send = tuple(object_send_list)
 
     return object_send
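The send-side bug: rebinding the loop variable (`tensor_send = split_tensor_into_1d_equal_chunks(tensor_send)`) never modified the tuple being iterated, so when multiple tensors were sent they went out unsplit while the receiver expected split chunks. The fix collects the results in a new list and rebuilds the tuple. A minimal sketch of the two patterns, using a hypothetical `split_chunks` stand-in rather than the real `split_tensor_into_1d_equal_chunks`:

import torch

def split_chunks(t):
    # stand-in for split_tensor_into_1d_equal_chunks: here it just flattens
    return t.reshape(-1)

tensors = (torch.ones(2, 2), torch.zeros(3, 3))

# Buggy pattern: rebinding the loop variable leaves `tensors` untouched.
for t in tensors:
    t = split_chunks(t)               # result is discarded on the next iteration
assert tensors[0].shape == (2, 2)     # still unsplit

# Fixed pattern: collect results and rebuild the container.
processed = []
for t in tensors:
    processed.append(split_chunks(t))
tensors = tuple(processed)
assert tensors[0].shape == (4,)       # every element was actually split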
@@ -161,15 +165,17 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
     if recv_prev and recv_prev_split:
         if isinstance(tensor_recv_prev, torch.Tensor):
             tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
         else:
-            for tensor_recv, tensor_shape in zip(tensor_recv_prev, recv_prev_shape):
-                tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
+            for index in range(len(tensor_recv_prev)):
+                tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view(
+                    recv_prev_shape[index]).requires_grad_()
 
     if recv_next and recv_next_split:
         if isinstance(tensor_recv_next, torch.Tensor):
             tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
         else:
-            for tensor_recv, tensor_shape in zip(tensor_recv_next, recv_next_shape):
-                tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
+            for index in range(len(tensor_recv_next)):
+                tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view(
+                    recv_next_shape[index]).requires_grad_()
 
     return tensor_recv_prev, tensor_recv_next
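The receive side had the mirror-image problem: the old `zip` loop rebound `tensor_recv`, so the gathered and reshaped tensors never made it back into `tensor_recv_prev` / `tensor_recv_next`, and the caller kept the flat 1-D buffers. The fix assigns by index, mutating the list in place. A minimal sketch with a hypothetical `gather_chunks` stand-in for `gather_split_1d_tensor(...).view(shape)`:

import torch

def gather_chunks(t, shape):
    # stand-in for gather_split_1d_tensor(t).view(shape)
    return t.view(shape)

received = [torch.zeros(4), torch.zeros(9)]
shapes = [(2, 2), (3, 3)]

# Buggy pattern: zip + rebinding never writes back into `received`.
for t, shape in zip(received, shapes):
    t = gather_chunks(t, shape)
assert received[0].shape == (4,)      # still the flat 1-D buffer

# Fixed pattern: index assignment mutates the list in place.
for i in range(len(received)):
    received[i] = gather_chunks(received[i], shapes[i]).requires_grad_()
assert received[0].shape == (2, 2) and received[0].requires_grad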