diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py
index 234791e32..67105c41b 100644
--- a/colossalai/communication/utils.py
+++ b/colossalai/communication/utils.py
@@ -7,7 +7,7 @@ from colossalai.utils import get_current_device
 
 
 def send_tensor_meta(tensor, need_meta=True, next_rank=None):
-    """Sends tensor meta information before sending a specific tensor. 
+    """Sends tensor meta information before sending a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be sent before communications. This function
     synchronizes with :func:`recv_tensor_meta`.
@@ -36,7 +36,7 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
 
 
 def recv_tensor_meta(tensor_shape, prev_rank=None):
-    """Recieves tensor meta information before recieving a specific tensor. 
+    """Recieves tensor meta information before recieving a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be recieved before communications. This function
     synchronizes with :func:`send_tensor_meta`.
@@ -104,6 +104,6 @@ def gather_split_1d_tensor(tensor):
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
-    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
     dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
     return gathered
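
The last hunk touches the chunked all_gather pattern in `gather_split_1d_tensor`: a contiguous output buffer is allocated once, sliced into per-rank views, and `dist.all_gather` writes each rank's shard directly into its view, avoiding a concatenation afterwards. Below is a minimal standalone sketch of that pattern, assuming `torch.distributed` is already initialized; the function name `gather_split_1d_tensor_sketch` is hypothetical, and the real code uses ColossalAI's 1D-parallel group via `gpc.get_group(ParallelMode.PARALLEL_1D)` rather than the default group.

import torch
import torch.distributed as dist


def gather_split_1d_tensor_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Gather equal-sized shards of a flat tensor from all ranks into one buffer.

    Sketch only: uses the default process group instead of the 1D-parallel group.
    """
    world_size = dist.get_world_size()
    numel = tensor.numel()
    # Pre-allocate one contiguous output buffer for all ranks' shards.
    gathered = torch.empty(numel * world_size, dtype=tensor.dtype, device=tensor.device)
    # Each chunk is a view into `gathered`, so all_gather fills the buffer in place.
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor)
    return gathered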