Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -14,21 +14,21 @@ from .ring import ring_forward
 from .utils import recv_obj_meta, send_obj_meta

 __all__ = [
-    'all_gather',
-    'reduce_scatter',
-    'all_reduce',
-    'broadcast',
-    'reduce',
-    'send_forward',
-    'send_forward_recv_forward',
-    'send_forward_backward_recv_forward_backward',
-    'send_backward',
-    'send_backward_recv_backward',
-    'send_backward_recv_forward',
-    'send_forward_recv_backward',
-    'recv_backward',
-    'recv_forward',
-    'ring_forward',
-    'send_obj_meta',
-    'recv_obj_meta',
+    "all_gather",
+    "reduce_scatter",
+    "all_reduce",
+    "broadcast",
+    "reduce",
+    "send_forward",
+    "send_forward_recv_forward",
+    "send_forward_backward_recv_forward_backward",
+    "send_backward",
+    "send_backward_recv_backward",
+    "send_backward_recv_forward",
+    "send_forward_recv_backward",
+    "recv_backward",
+    "recv_forward",
+    "ring_forward",
+    "send_obj_meta",
+    "recv_obj_meta",
 ]
@@ -9,10 +9,10 @@ from torch.distributed import ReduceOp
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc

-_all_gather_func = dist._all_gather_base \
-    if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
-_reduce_scatter_func = dist._reduce_scatter_base \
-    if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
+_all_gather_func = dist._all_gather_base if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
+_reduce_scatter_func = (
+    dist._reduce_scatter_base if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
+)


 def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
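Note on the compatibility shim above: it selects the public `all_gather_into_tensor` / `reduce_scatter_tensor` entry points when the installed torch exposes them and falls back to the older private `_all_gather_base` / `_reduce_scatter_base` otherwise. A minimal sketch of the same pattern, assuming an already-initialized process group (the helper name `flat_all_gather` is illustrative, not part of ColossalAI):

import torch
import torch.distributed as dist

# Prefer the public API (newer torch releases); fall back to the older private symbol.
_all_gather_func = dist.all_gather_into_tensor if "all_gather_into_tensor" in dir(dist) else dist._all_gather_base


def flat_all_gather(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # Gather `tensor` from every rank into one flat buffer, then view it as (world_size, *shape).
    # Assumes dist.init_process_group(...) has already been called.
    world_size = dist.get_world_size(group)
    out = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
    _all_gather_func(out, tensor.contiguous(), group=group)
    return out.reshape(world_size, *tensor.shape)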
@@ -50,11 +50,9 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
     return out


-def reduce_scatter(tensor: Tensor,
-                   dim: int,
-                   parallel_mode: ParallelMode,
-                   op: ReduceOp = ReduceOp.SUM,
-                   async_op: bool = False) -> Tensor:
+def reduce_scatter(
+    tensor: Tensor, dim: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
+) -> Tensor:
     r"""Reduces all tensors then scatters it in a specific dimension to all
     members in the parallel group.

@@ -93,10 +91,9 @@ def reduce_scatter(tensor: Tensor,
     return out


-def all_reduce(tensor: Tensor,
-               parallel_mode: ParallelMode,
-               op: ReduceOp = ReduceOp.SUM,
-               async_op: bool = False) -> Tensor:
+def all_reduce(
+    tensor: Tensor, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
+) -> Tensor:
     r"""Reduces the tensor data across whole parallel group in such a way that all get the final result.

     Note:
@@ -201,16 +198,17 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s
     if dist.distributed_c10d._rank_not_in_group(group):
         return

-    if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1):
+    if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1:
         raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.")

     # set tensor device to cuda if backend is nccl
-    device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu")
+    device = torch.cuda.current_device() if dist.get_backend(group) == "nccl" else torch.device("cpu")

-    my_rank = dist.get_rank() # use global rank
+    my_rank = dist.get_rank()  # use global rank
     if my_rank == src:
         tensor_list, tensor_sizes = zip(
-            *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list])
+            *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]
+        )
         tensor_list = list(map(lambda x: x.to(device), tensor_list))
         tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes))

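The `scatter_object_list` hunk above leans on `dist.distributed_c10d._object_to_tensor`, which turns an arbitrary picklable object into a uint8 tensor plus a length so it can travel over tensor collectives. A hedged sketch of that conversion using only the public `pickle` API (the helper names below are illustrative, not the private torch functions):

import pickle
from typing import Any, Tuple

import torch


def object_to_tensor(obj: Any) -> Tuple[torch.Tensor, torch.Tensor]:
    # Serialize the object and wrap the bytes in a uint8 tensor plus its length.
    payload = pickle.dumps(obj)
    byte_tensor = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
    size = torch.tensor(byte_tensor.numel(), dtype=torch.long)
    return byte_tensor, size


def tensor_to_object(tensor: torch.Tensor, size: int) -> Any:
    # Truncate to the real payload length (the buffer may be padded) and unpickle.
    payload = tensor.cpu().numpy().tobytes()[:size]
    return pickle.loads(payload)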
@@ -82,16 +82,18 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
             ops_queue.append(op_to_add)


-def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
-                 object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
-                 recv_prev: bool = False,
-                 recv_next: bool = False,
-                 recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
-                 recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
-                 prev_rank: int = None,
-                 next_rank: int = None,
-                 dtype: torch.dtype = None,
-                 scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
+def _communicate(
+    object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
+    object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
+    recv_prev: bool = False,
+    recv_next: bool = False,
+    recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
+    recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
+    prev_rank: int = None,
+    next_rank: int = None,
+    dtype: torch.dtype = None,
+    scatter_gather_tensors: bool = False,
+) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
@@ -123,13 +125,15 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non

     if recv_prev:
         assert recv_prev_shape is not None
-        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype,
-                                                                           scatter_gather_tensors)
+        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
+            recv_prev_shape, dtype, scatter_gather_tensors
+        )

     if recv_next:
         assert recv_next_shape is not None
-        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype,
-                                                                           scatter_gather_tensors)
+        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
+            recv_next_shape, dtype, scatter_gather_tensors
+        )

     if object_send_prev is not None or recv_prev:
         if prev_rank is None:
@@ -170,24 +174,25 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
             tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
         else:
             for index in range(len(tensor_recv_prev)):
-                tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view(
-                    recv_prev_shape[index]).requires_grad_()
+                tensor_recv_prev[index] = (
+                    gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
+                )

     if recv_next and recv_next_split:
         if isinstance(tensor_recv_next, torch.Tensor):
             tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
         else:
             for index in range(len(tensor_recv_next)):
-                tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view(
-                    recv_next_shape[index]).requires_grad_()
+                tensor_recv_next[index] = (
+                    gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
+                )

     return tensor_recv_prev, tensor_recv_next


-def recv_forward(input_tensor_shape,
-                 prev_rank=None,
-                 dtype=torch.float,
-                 scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
+def recv_forward(
+    input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
+) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

     Args:
@@ -200,18 +205,19 @@ def recv_forward(input_tensor_shape,
     if gpc.is_pipeline_first_stage():
         input_tensor = None
     else:
-        input_tensor, _ = _communicate(recv_prev=True,
-                                       recv_prev_shape=input_tensor_shape,
-                                       prev_rank=prev_rank,
-                                       dtype=dtype,
-                                       scatter_gather_tensors=scatter_gather_tensors)
+        input_tensor, _ = _communicate(
+            recv_prev=True,
+            recv_prev_shape=input_tensor_shape,
+            prev_rank=prev_rank,
+            dtype=dtype,
+            scatter_gather_tensors=scatter_gather_tensors,
+        )
     return input_tensor


-def recv_backward(output_grad_shape,
-                  next_rank=None,
-                  dtype=torch.float,
-                  scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
+def recv_backward(
+    output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
+) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

     Args:
@@ -224,11 +230,13 @@ def recv_backward(output_grad_shape,
     if gpc.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
-        _, output_tensor_grad = _communicate(recv_next=True,
-                                             recv_next_shape=output_grad_shape,
-                                             next_rank=next_rank,
-                                             dtype=dtype,
-                                             scatter_gather_tensors=scatter_gather_tensors)
+        _, output_tensor_grad = _communicate(
+            recv_next=True,
+            recv_next_shape=output_grad_shape,
+            next_rank=next_rank,
+            dtype=dtype,
+            scatter_gather_tensors=scatter_gather_tensors,
+        )
     return output_tensor_grad

@@ -251,17 +259,14 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals
         prev_rank (int, optional): The rank of the recipient of the tensor
     """
     if not gpc.is_pipeline_first_stage():
-        _communicate(object_send_prev=input_tensor_grad,
-                     prev_rank=prev_rank,
-                     scatter_gather_tensors=scatter_gather_tensors)
+        _communicate(
+            object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors
+        )


-def send_forward_recv_backward(output_tensor,
-                               output_grad_shape,
-                               recv_next=True,
-                               next_rank=None,
-                               dtype=torch.float,
-                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
+def send_forward_recv_backward(
+    output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
+) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
@@ -276,21 +281,25 @@ def send_forward_recv_backward(output_tensor,
     if gpc.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
-        _, output_tensor_grad = _communicate(object_send_next=output_tensor,
-                                             recv_next=recv_next,
-                                             recv_next_shape=output_grad_shape,
-                                             next_rank=next_rank,
-                                             dtype=dtype,
-                                             scatter_gather_tensors=scatter_gather_tensors)
+        _, output_tensor_grad = _communicate(
+            object_send_next=output_tensor,
+            recv_next=recv_next,
+            recv_next_shape=output_grad_shape,
+            next_rank=next_rank,
+            dtype=dtype,
+            scatter_gather_tensors=scatter_gather_tensors,
+        )
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad,
-                               input_tensor_shape,
-                               recv_prev=True,
-                               prev_rank=None,
-                               dtype=torch.float,
-                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
+def send_backward_recv_forward(
+    input_tensor_grad,
+    input_tensor_shape,
+    recv_prev=True,
+    prev_rank=None,
+    dtype=torch.float,
+    scatter_gather_tensors=False,
+) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
@@ -305,22 +314,26 @@ def send_backward_recv_forward(input_tensor_grad,
     if gpc.is_pipeline_first_stage():
         input_tensor = None
     else:
-        input_tensor, _ = _communicate(object_send_prev=input_tensor_grad,
-                                       recv_prev=recv_prev,
-                                       recv_prev_shape=input_tensor_shape,
-                                       prev_rank=prev_rank,
-                                       dtype=dtype,
-                                       scatter_gather_tensors=scatter_gather_tensors)
+        input_tensor, _ = _communicate(
+            object_send_prev=input_tensor_grad,
+            recv_prev=recv_prev,
+            recv_prev_shape=input_tensor_shape,
+            prev_rank=prev_rank,
+            dtype=dtype,
+            scatter_gather_tensors=scatter_gather_tensors,
+        )
     return input_tensor


-def send_forward_recv_forward(output_tensor,
-                              input_tensor_shape,
-                              recv_prev=True,
-                              prev_rank=None,
-                              next_rank=None,
-                              dtype=torch.float,
-                              scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
+def send_forward_recv_forward(
+    output_tensor,
+    input_tensor_shape,
+    recv_prev=True,
+    prev_rank=None,
+    next_rank=None,
+    dtype=torch.float,
+    scatter_gather_tensors=False,
+) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
@@ -332,23 +345,27 @@ def send_forward_recv_forward(output_tensor,
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
     """
-    input_tensor, _ = _communicate(object_send_next=output_tensor,
-                                   recv_prev=recv_prev,
-                                   recv_prev_shape=input_tensor_shape,
-                                   prev_rank=prev_rank,
-                                   next_rank=next_rank,
-                                   dtype=dtype,
-                                   scatter_gather_tensors=scatter_gather_tensors)
+    input_tensor, _ = _communicate(
+        object_send_next=output_tensor,
+        recv_prev=recv_prev,
+        recv_prev_shape=input_tensor_shape,
+        prev_rank=prev_rank,
+        next_rank=next_rank,
+        dtype=dtype,
+        scatter_gather_tensors=scatter_gather_tensors,
+    )
     return input_tensor


-def send_backward_recv_backward(input_tensor_grad,
-                                output_grad_shape,
-                                recv_next=True,
-                                prev_rank=None,
-                                next_rank=None,
-                                dtype=torch.float,
-                                scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
+def send_backward_recv_backward(
+    input_tensor_grad,
+    output_grad_shape,
+    recv_next=True,
+    prev_rank=None,
+    next_rank=None,
+    dtype=torch.float,
+    scatter_gather_tensors=False,
+) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the gradient tensor from the
     next member in pipeline as the input of this stage.
@@ -360,27 +377,30 @@ def send_backward_recv_backward(input_tensor_grad,
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
     """
-    _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad,
-                                         recv_next=recv_next,
-                                         recv_next_shape=output_grad_shape,
-                                         prev_rank=prev_rank,
-                                         next_rank=next_rank,
-                                         dtype=dtype,
-                                         scatter_gather_tensors=scatter_gather_tensors)
+    _, output_tensor_grad = _communicate(
+        object_send_prev=input_tensor_grad,
+        recv_next=recv_next,
+        recv_next_shape=output_grad_shape,
+        prev_rank=prev_rank,
+        next_rank=next_rank,
+        dtype=dtype,
+        scatter_gather_tensors=scatter_gather_tensors,
+    )
     return output_tensor_grad


 def send_forward_backward_recv_forward_backward(
-        output_tensor,
-        input_tensor_grad,
-        input_tensor_shape,
-        output_grad_shape,
-        recv_prev=True,
-        recv_next=True,
-        prev_rank=None,
-        next_rank=None,
-        dtype=torch.float,
-        scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
+    output_tensor,
+    input_tensor_grad,
+    input_tensor_shape,
+    output_grad_shape,
+    recv_prev=True,
+    recv_next=True,
+    prev_rank=None,
+    next_rank=None,
+    dtype=torch.float,
+    scatter_gather_tensors=False,
+) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
     """Batched communication operation. Sends the input tensor to the next stage in pipeline and
     the gradient tensor to the previous stage, while receives the input gradient tensor from the
     next stage and the input tensor from the previous stage.
@@ -394,14 +414,16 @@ def send_forward_backward_recv_forward_backward(
     Returns:
         Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
     """
-    input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor,
-                                                    object_send_prev=input_tensor_grad,
-                                                    recv_prev=recv_prev,
-                                                    recv_next=recv_next,
-                                                    recv_prev_shape=input_tensor_shape,
-                                                    recv_next_shape=output_grad_shape,
-                                                    prev_rank=prev_rank,
-                                                    next_rank=next_rank,
-                                                    dtype=dtype,
-                                                    scatter_gather_tensors=scatter_gather_tensors)
+    input_tensor, output_tensor_grad = _communicate(
+        object_send_next=output_tensor,
+        object_send_prev=input_tensor_grad,
+        recv_prev=recv_prev,
+        recv_next=recv_next,
+        recv_prev_shape=input_tensor_shape,
+        recv_next_shape=output_grad_shape,
+        prev_rank=prev_rank,
+        next_rank=next_rank,
+        dtype=dtype,
+        scatter_gather_tensors=scatter_gather_tensors,
+    )
     return input_tensor, output_tensor_grad
@@ -62,10 +62,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
         Any: object after unpickled
     """
     buf = tensor.numpy().tobytes()[:tensor_size]
-    if b'cuda' in buf:
+    if b"cuda" in buf:
         buf_array = bytearray(buf)
         device_index = torch.cuda.current_device()
-        buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
+        buf_array[buf_array.find(b"cuda") + 5] = 48 + device_index
         buf = bytes(buf_array)

     io_bytes = io.BytesIO(buf)
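`_cuda_safe_tensor_to_object` above works around pickled payloads that embed a concrete device string such as `cuda:0`: it locates the literal bytes `cuda` and overwrites the digit five bytes later (48 is `ord("0")`) with the local device index, which only covers single-digit indices. A standalone illustration of the byte patch on a plain buffer, with no GPU or pickling involved:

# Illustrative only: retarget the device digit inside a serialized payload.
buf = b"...device='cuda:0'..."  # stand-in for bytes produced by pickling a CUDA tensor
device_index = 3  # stand-in for torch.cuda.current_device()

buf_array = bytearray(buf)
pos = buf_array.find(b"cuda")
buf_array[pos + 5] = 48 + device_index  # byte layout is c u d a : 0, so pos + 5 is the digit
patched = bytes(buf_array)

assert b"cuda:3" in patched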
@@ -123,8 +123,8 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No
     if local_rank == src:
         object_tensor = torch.cat(tensor_list)
     else:
-        object_tensor = torch.empty( # type: ignore[call-overload]
-            torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
+        object_tensor = torch.empty(  # type: ignore[call-overload]
+            torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
         )

@@ -138,7 +138,7 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No

     if local_rank != src:
         for i, obj_size in enumerate(object_sizes_tensor):
-            obj_view = object_tensor[offset:offset + obj_size]
+            obj_view = object_tensor[offset : offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
             if obj_view.device != torch.device("cpu"):
                 obj_view = obj_view.cpu()
@@ -147,8 +147,10 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No
             unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)

             # unconsistence in device
-            if isinstance(unpickle_object,
-                          torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
+            if (
+                isinstance(unpickle_object, torch.Tensor)
+                and unpickle_object.device.index != torch.cuda.current_device()
+            ):
                 unpickle_object = unpickle_object.cuda()

             object_list[i] = unpickle_object
@@ -28,19 +28,20 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
     ops = []
     current_rank = gpc.get_global_rank()

-    tensor_recv_prev = torch.empty(buffer_shape,
-                                   requires_grad=True,
-                                   device=get_current_device(),
-                                   dtype=tensor_send_next.dtype)
+    tensor_recv_prev = torch.empty(
+        buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
+    )

     # send to next rank
-    send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
-                                           gpc.get_next_global_rank(parallel_mode))
+    send_next_op = torch.distributed.P2POp(
+        torch.distributed.isend, tensor_send_next, gpc.get_next_global_rank(parallel_mode)
+    )
     ops.append(send_next_op)

     # receive from prev rank
-    recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
-                                           gpc.get_prev_global_rank(parallel_mode))
+    recv_prev_op = torch.distributed.P2POp(
+        torch.distributed.irecv, tensor_recv_prev, gpc.get_prev_global_rank(parallel_mode)
+    )
     ops.append(recv_prev_op)

     if current_rank % 2 == 0:
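`ring_forward` above pairs an `isend` to the next rank with an `irecv` from the previous rank and launches both through `torch.distributed.batch_isend_irecv`; the trailing `if current_rank % 2 == 0:` branch additionally reorders the two ops by rank parity. A minimal sketch of the ring step without ColossalAI's `gpc` context, assuming an initialized process group and using plain rank arithmetic in place of `get_next_global_rank` / `get_prev_global_rank`:

import torch
import torch.distributed as dist


def ring_exchange(tensor_send_next: torch.Tensor) -> torch.Tensor:
    # Send our tensor to rank + 1 and receive the matching tensor from rank - 1.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    next_rank = (rank + 1) % world_size
    prev_rank = (rank - 1) % world_size

    tensor_recv_prev = torch.empty_like(tensor_send_next, requires_grad=True)

    ops = [
        dist.P2POp(dist.isend, tensor_send_next, next_rank),
        dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank),
    ]
    # batch_isend_irecv submits all point-to-point ops together and returns requests to wait on.
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # NCCL p2p ops complete asynchronously on the device side
    return tensor_recv_prev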
@@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
         if next_rank is None:
             next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

-        tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
+        tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
         if isinstance(obj, torch.Tensor):
             send_obj_nums = torch.tensor(1, **tensor_kwargs)
             dist.send(send_obj_nums, next_rank)
@@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
         if prev_rank is None:
             prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

-        tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
+        tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
         recv_obj_nums = torch.empty((), **tensor_kwargs)
         dist.recv(recv_obj_nums, prev_rank)
         if recv_obj_nums.item() == 1:
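`send_obj_meta` / `recv_obj_meta` above exchange tensor shapes ahead of the actual payload by encoding the metadata as `torch.long` tensors: first how many tensors follow, then, per tensor, its sizes. A hedged sketch of that handshake for a single tensor's shape, assuming an initialized process group and a known peer rank (the helper names `send_shape` / `recv_shape` are illustrative, not the ColossalAI functions):

import torch
import torch.distributed as dist


def send_shape(tensor: torch.Tensor, dst: int) -> None:
    # Tell the peer how many dimensions to expect, then the sizes themselves.
    meta_kwargs = {"dtype": torch.long, "device": tensor.device}
    dist.send(torch.tensor(tensor.dim(), **meta_kwargs), dst)
    dist.send(torch.tensor(tensor.size(), **meta_kwargs), dst)


def recv_shape(src: int, device: torch.device) -> torch.Size:
    # Mirror image of send_shape: read the dimension count, then the sizes.
    ndim = torch.empty((), dtype=torch.long, device=device)
    dist.recv(ndim, src)
    shape = torch.empty(int(ndim.item()), dtype=torch.long, device=device)
    dist.recv(shape, src)
    return torch.Size(shape.tolist())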
@@ -122,6 +122,6 @@ def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
     numel = torch.numel(tensor)
     numel_gathered = world_size * numel
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
-    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
+    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
     dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
     return gathered
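`gather_split_1d_tensor` above skips a separate concatenation by allocating one flat output buffer and handing `dist.all_gather` a list of views into that buffer, so every rank's shard lands directly in its slot. A self-contained sketch of the chunk-view pattern, assuming an initialized process group (the default group stands in for ColossalAI's `ParallelMode.PARALLEL_1D` group, and `gather_flat` is an illustrative name):

import torch
import torch.distributed as dist


def gather_flat(tensor: torch.Tensor) -> torch.Tensor:
    # All-gather a 1-D shard into one contiguous buffer via per-rank views.
    world_size = dist.get_world_size()
    numel = tensor.numel()
    gathered = torch.empty(world_size * numel, dtype=tensor.dtype, device=tensor.device)
    # Each chunk is a view into `gathered`, so all_gather writes the result in place.
    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor)
    return gathered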