mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
Migrated project
This commit is contained in:
14
colossalai/communication/__init__.py
Normal file
14
colossalai/communication/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .collective import all_gather, reduce_scatter, scatter
|
||||
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
|
||||
send_backward, send_backward_recv_backward, send_forward_recv_backward,
|
||||
send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
|
||||
from .ring import ring_forward
|
||||
from .utils import send_tensor_meta, recv_tensor_meta
|
||||
|
||||
__all__ = [
|
||||
'all_gather', 'reduce_scatter', 'scatter',
|
||||
'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
|
||||
'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
|
||||
'send_forward_recv_backward', 'recv_backward', 'recv_forward',
|
||||
'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
|
||||
]
|
84
colossalai/communication/collective.py
Normal file
84
colossalai/communication/collective.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, dim: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
"""Gathers all tensors from the parallel group and concatenates them in a
|
||||
specific dimension.
|
||||
|
||||
:param tensor: Tensor to be gathered
|
||||
:param dim: The dimension concatenating in
|
||||
:param parallel_mode: Parallel group mode used in this communication
|
||||
:type tensor: Tensor
|
||||
:type dim: int
|
||||
:type parallel_mode: ParallelMode
|
||||
:return: The tensor generated by all-gather
|
||||
:rtype: Tensor
|
||||
"""
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
temp = tensor.clone()
|
||||
shape = list(temp.shape)
|
||||
shape[dim] *= depth
|
||||
out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
|
||||
out = list(torch.chunk(out, depth, dim=dim))
|
||||
out = [val.contiguous() for val in out]
|
||||
dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
|
||||
out = torch.cat(out, dim=dim)
|
||||
return out
|
||||
|
||||
|
||||
def reduce_scatter(tensor: Tensor, dim: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
"""Reduces all tensors then scatters it in a specific dimension to all
|
||||
members in the parallel group.
|
||||
|
||||
:param tensor: Tensor to be reduced and scattered
|
||||
:param dim: The dimension scattering in
|
||||
:param parallel_mode: Parallel group mode used in this communication
|
||||
:type tensor: Tensor
|
||||
:type dim: int
|
||||
:type parallel_mode: ParallelMode
|
||||
:return: The tensor generated by reduce-scatter
|
||||
:rtype: Tensor
|
||||
"""
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
temp = list(torch.chunk(tensor, depth, dim=dim))
|
||||
temp = [val.contiguous() for val in temp]
|
||||
out = torch.empty(temp[0].shape,
|
||||
dtype=temp[0].dtype,
|
||||
device=get_current_device())
|
||||
dist.reduce_scatter(output=out,
|
||||
input_list=temp,
|
||||
group=gpc.get_group(parallel_mode))
|
||||
return out
|
||||
|
||||
|
||||
def scatter(tensor: Tensor, src: int, dim: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
"""Scatters in a specific dimension from source rank to all ranks in
|
||||
the parallel group.
|
||||
|
||||
:param tensor: Tensor to be scattered
|
||||
:param dim: The dimension scattering in
|
||||
:param parallel_mode: Parallel group mode used in this communication
|
||||
:type tensor: Tensor
|
||||
:type dim: int
|
||||
:type parallel_mode: ParallelMode
|
||||
:return: The tensor generated by scatter
|
||||
:rtype: Tensor
|
||||
"""
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
temp = tensor.clone()
|
||||
dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
|
||||
return out
|
333
colossalai/communication/p2p.py
Normal file
333
colossalai/communication/p2p.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def _communicate(tensor_send_next=None,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=False,
|
||||
recv_next=False,
|
||||
recv_prev_shape=None,
|
||||
recv_next_shape=None,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
up_group=None,
|
||||
down_group=None,
|
||||
dtype=None):
|
||||
"""
|
||||
Adapted from megatron.p2p_communication.
|
||||
Communicate tensors between stages. Used as helper method in other
|
||||
communication methods that are used in pipeline schedule.
|
||||
Takes the following arguments:
|
||||
tensor_send_next: tensor to send to next rank (no tensor sent if
|
||||
set to None).
|
||||
tensor_send_prev: tensor to send to prev rank (no tensor sent if
|
||||
set to None).
|
||||
recv_prev: boolean for whether tensor should be received from
|
||||
previous rank.
|
||||
recv_next: boolean for whether tensor should be received from
|
||||
next rank.
|
||||
Returns:
|
||||
(tensor_recv_prev, tensor_recv_next)
|
||||
"""
|
||||
|
||||
# Create placeholder tensors for receive in forward and backward directions
|
||||
# if needed.
|
||||
tensor_recv_prev = None
|
||||
tensor_recv_next = None
|
||||
|
||||
if recv_prev:
|
||||
assert recv_prev_shape is not None
|
||||
tensor_recv_prev = torch.empty(recv_prev_shape,
|
||||
requires_grad=True,
|
||||
device=get_current_device(),
|
||||
dtype=dtype)
|
||||
if recv_next:
|
||||
assert recv_next_shape is not None
|
||||
tensor_recv_next = torch.empty(recv_next_shape,
|
||||
requires_grad=True,
|
||||
device=get_current_device(),
|
||||
dtype=dtype)
|
||||
|
||||
if tensor_send_prev is not None or recv_prev:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
if up_group is None:
|
||||
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
|
||||
|
||||
if tensor_send_next is not None or recv_next:
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
if down_group is None:
|
||||
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
|
||||
|
||||
# rank = dist.get_rank()
|
||||
rank = gpc.get_global_rank()
|
||||
|
||||
ops = []
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_op = dist.broadcast(tensor_send_prev,
|
||||
src=rank,
|
||||
group=up_group,
|
||||
async_op=True)
|
||||
ops.append(send_prev_op)
|
||||
if tensor_recv_prev is not None:
|
||||
recv_prev_op = dist.broadcast(tensor_recv_prev,
|
||||
src=prev_rank,
|
||||
group=up_group,
|
||||
async_op=True)
|
||||
ops.append(recv_prev_op)
|
||||
if tensor_recv_next is not None:
|
||||
recv_next_op = dist.broadcast(tensor_recv_next,
|
||||
src=next_rank,
|
||||
group=down_group,
|
||||
async_op=True)
|
||||
ops.append(recv_next_op)
|
||||
if tensor_send_next is not None:
|
||||
send_next_op = dist.broadcast(tensor_send_next,
|
||||
src=rank,
|
||||
group=down_group,
|
||||
async_op=True)
|
||||
ops.append(send_next_op)
|
||||
for req in ops:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
torch.cuda.synchronize()
|
||||
return tensor_recv_prev, tensor_recv_next
|
||||
|
||||
|
||||
def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
|
||||
"""Receives the input tensor from the previous member in pipeline.
|
||||
|
||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||
:param prev_rank: The rank of the source of the tensor
|
||||
:param up_group: Communication group including the previous member in pipeline parallel group
|
||||
:type input_tensor_shape: torch.Size
|
||||
:type prev_rank: int, optional
|
||||
:type up_group: ProcessGroup, optional
|
||||
:return: The input tensor in forward step
|
||||
:rtype: Tensor
|
||||
"""
|
||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor, _ = _communicate(recv_prev=True,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
up_group=up_group)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def recv_backward(output_grad_shape, next_rank=None, down_group=None):
|
||||
"""Receives the grad tensor from the next member in pipeline.
|
||||
|
||||
:param output_grad_shape: The shape of the tensor to be recieved
|
||||
:param next_rank: The rank of the source of the tensor
|
||||
:param down_group: Communication group including the next member in pipeline parallel group
|
||||
:type output_grad_shape: torch.Size
|
||||
:type next_rank: int, optional
|
||||
:type down_group: ProcessGroup, optional
|
||||
:return: The grad of output tensor in forward step
|
||||
:rtype: Tensor
|
||||
"""
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
_, output_tensor_grad = _communicate(recv_next=True,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
down_group=down_group)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward(output_tensor,
|
||||
next_rank=None,
|
||||
down_group=None):
|
||||
"""Sends the input tensor to the next member in pipeline.
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param next_rank: The rank of the recipient of the tensor
|
||||
:param down_group: Communication group including the next member in pipeline parallel group
|
||||
:type output_tensor: Tensor
|
||||
:type next_rank: int, optional
|
||||
:type down_group: ProcessGroup, optional
|
||||
"""
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
_communicate(tensor_send_next=output_tensor,
|
||||
next_rank=next_rank,
|
||||
down_group=down_group)
|
||||
|
||||
|
||||
def send_backward(input_tensor_grad,
|
||||
prev_rank=None,
|
||||
up_group=None):
|
||||
"""Sends the grad tensor to the previous member in pipeline.
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param prev_rank: The rank of the recipient of the tensor
|
||||
:param up_group: Communication group including the previous member in pipeline parallel group
|
||||
:type input_tensor_grad: Tensor
|
||||
:type prev_rank: int, optional
|
||||
:type up_group: ProcessGroup, optional
|
||||
"""
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
_communicate(tensor_send_prev=input_tensor_grad,
|
||||
prev_rank=prev_rank,
|
||||
up_group=up_group)
|
||||
|
||||
|
||||
def send_forward_recv_backward(output_tensor,
|
||||
output_grad_shape,
|
||||
recv_next=True,
|
||||
next_rank=None,
|
||||
down_group=None):
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next member in pipeline, while recieves the grad tensor from the
|
||||
next member in pipeline.
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param output_grad_shape: The shape of the tensor to be recieved
|
||||
:type output_tensor: Tensor
|
||||
:type output_grad_shape: torch.Size
|
||||
:return: The grad of output tensor in forward step
|
||||
:rtype: Tensor
|
||||
"""
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
|
||||
recv_next=recv_next,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
down_group=down_group)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_backward_recv_forward(input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
recv_prev=True,
|
||||
prev_rank=None,
|
||||
up_group=None):
|
||||
"""Batched communication operation. Sends the grad tensor to the
|
||||
previous member in pipeline, while recieves the input tensor from the
|
||||
previous member in pipeline.
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||
:type input_tensor_grad: Tensor
|
||||
:type input_tensor_shape: torch.Size
|
||||
:return: The input tensor in forward step
|
||||
:rtype: Tensor
|
||||
"""
|
||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
up_group=up_group)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_forward_recv_forward(output_tensor,
|
||||
input_tensor_shape,
|
||||
recv_prev=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
up_group=None,
|
||||
down_group=None):
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next member in pipeline, while recieves the input tensor from the
|
||||
previous member in pipeline.
|
||||
|
||||
:param output_tensor: Tensor to be sent
|
||||
:param input_tensor_shape: The shape of the tensor to be recieved
|
||||
:type output_tensor: Tensor
|
||||
:type input_tensor_shape: torch.Size
|
||||
:return: The input tensor in forward step
|
||||
:rtype: Tensor
|
||||
"""
|
||||
input_tensor, _ = _communicate(tensor_send_next=output_tensor,
|
||||
recv_prev=recv_prev,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
up_group=up_group,
|
||||
down_group=down_group)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_backward_recv_backward(input_tensor_grad,
|
||||
output_grad_shape,
|
||||
recv_next=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
up_group=None,
|
||||
down_group=None):
|
||||
"""Batched communication operation. Sends the grad tensor to the
|
||||
previous member in pipeline, while recieves the grad tensor from the
|
||||
next member in pipeline.
|
||||
|
||||
:param input_tensor_grad: Tensor to be sent
|
||||
:param output_grad_shape: The shape of the tensor to be recieved
|
||||
:type input_tensor_grad: Tensor
|
||||
:type output_grad_shape: torch.Size
|
||||
:return: The grad of output tensor in forward step
|
||||
:rtype: Tensor
|
||||
"""
|
||||
_, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
|
||||
recv_next=recv_next,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
up_group=up_group,
|
||||
down_group=down_group)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward_backward_recv_forward_backward(output_tensor,
|
||||
input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
output_grad_shape,
|
||||
recv_prev=True,
|
||||
recv_next=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
up_group=None,
|
||||
down_group=None):
|
||||
"""Batched communication operation. Sends the input tensor to the next and
|
||||
the grad tensor to the previous, while recieves the grad tensor from the
|
||||
next and the input tensor from the previous.
|
||||
|
||||
:param output_tensor: Tensor sent to the next
|
||||
:param input_tensor_grad: Tensor sent to the previous
|
||||
:param input_tensor_shape: The shape of the tensor recieved from the previous
|
||||
:param output_grad_shape: The shape of the tensor recieved from the next
|
||||
:type output_tensor: Tensor
|
||||
:type input_tensor_grad: Tensor
|
||||
:type input_tensor_shape: torch.Size
|
||||
:type output_grad_shape: torch.Size
|
||||
:return: (the input tensor in forward step, the grad of output tensor in forward step)
|
||||
:rtype: (Tensor, Tensor)
|
||||
"""
|
||||
input_tensor, output_tensor_grad = _communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
up_group=up_group,
|
||||
down_group=down_group)
|
||||
return input_tensor, output_tensor_grad
|
54
colossalai/communication/ring.py
Normal file
54
colossalai/communication/ring.py
Normal file
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device, synchronize
|
||||
|
||||
|
||||
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
|
||||
"""Sends a tensor to the next member and recieves a tensor from the previous member.
|
||||
This function returns the recieved tensor from the previous member.
|
||||
|
||||
:param tensor_send_next: Tensor sent to next member
|
||||
:param parallel_mode: Parallel group mode used in this communication
|
||||
:type tensor_send_next: Tensor
|
||||
:type parallel_mode: ParallelMode
|
||||
:return: The tensor recieved from the previous
|
||||
:rtype: Tensor
|
||||
"""
|
||||
buffer_shape = tensor_send_next.size()
|
||||
|
||||
ops = []
|
||||
current_rank = gpc.get_global_rank()
|
||||
|
||||
tensor_recv_prev = torch.empty(buffer_shape,
|
||||
requires_grad=True,
|
||||
device=get_current_device(),
|
||||
dtype=tensor_send_next.dtype)
|
||||
|
||||
# send to next rank
|
||||
send_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend, tensor_send_next,
|
||||
gpc.get_next_global_rank(parallel_mode))
|
||||
ops.append(send_next_op)
|
||||
|
||||
# receive from prev rank
|
||||
recv_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv, tensor_recv_prev,
|
||||
gpc.get_prev_global_rank(parallel_mode))
|
||||
ops.append(recv_prev_op)
|
||||
|
||||
if current_rank % 2 == 0:
|
||||
ops = ops[::-1]
|
||||
|
||||
reqs = torch.distributed.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
synchronize()
|
||||
|
||||
return tensor_recv_prev
|
73
colossalai/communication/utils.py
Normal file
73
colossalai/communication/utils.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def send_tensor_meta(tensor, need_meta=True, down_group=None):
|
||||
"""Sends tensor meta information before sending a specific tensor.
|
||||
Since the recipient must know the shape of the tensor in p2p communications,
|
||||
meta information of the tensor should be sent before communications. This function
|
||||
synchronizes with :func:`recv_tensor_meta`.
|
||||
|
||||
:param tensor: Tensor to be sent
|
||||
:param need_meta: If False, meta information won't be sent
|
||||
:param down_group: Communication group including the next member in pipeline parallel group
|
||||
:type tensor: Tensor
|
||||
:type need_meta: bool, optional
|
||||
:type down_group: ProcessGroup, optional
|
||||
:return: False
|
||||
:rtype: bool
|
||||
"""
|
||||
if need_meta:
|
||||
rank = gpc.get_global_rank()
|
||||
|
||||
if down_group is None:
|
||||
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
|
||||
|
||||
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
|
||||
|
||||
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
|
||||
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
|
||||
|
||||
dist.broadcast(send_ndims, src=rank, group=down_group)
|
||||
dist.broadcast(send_shape, src=rank, group=down_group)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
|
||||
"""Recieves tensor meta information before recieving a specific tensor.
|
||||
Since the recipient must know the shape of the tensor in p2p communications,
|
||||
meta information of the tensor should be recieved before communications. This function
|
||||
synchronizes with :func:`send_tensor_meta`.
|
||||
|
||||
:param tensor_shape: The shape of the tensor to be recieved
|
||||
:param prev_rank: The rank of the source of the tensor
|
||||
:param up_group: Communication group including the previous member in pipeline parallel group
|
||||
:type tensor_shape: torch.Size
|
||||
:type prev_rank: int, optional
|
||||
:type up_group: ProcessGroup, optional
|
||||
:return: The shape of the tensor to be recieved
|
||||
:rtype: torch.Size
|
||||
"""
|
||||
if tensor_shape is None:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
if up_group is None:
|
||||
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
|
||||
|
||||
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
|
||||
|
||||
recv_ndims = torch.empty((), **tensor_kwargs)
|
||||
dist.broadcast(recv_ndims, src=prev_rank, group=up_group)
|
||||
|
||||
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
|
||||
dist.broadcast(recv_shape, src=prev_rank, group=up_group)
|
||||
|
||||
tensor_shape = torch.Size(recv_shape)
|
||||
|
||||
return tensor_shape
|
Reference in New Issue
Block a user