diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py
index 3ac1bffc1..7554910e4 100644
--- a/colossalai/communication/utils.py
+++ b/colossalai/communication/utils.py
@@ -77,9 +77,7 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     end_index = start_index + partition_size
     if new_buffer:
-        data = torch.empty(partition_size, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
+        data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
         data.copy_(tensor.view(-1)[start_index:end_index])
     else:
         data = tensor.view(-1)[start_index:end_index]
@@ -97,9 +95,7 @@ def gather_split_1d_tensor(tensor):
     world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
     numel = torch.numel(tensor)
     numel_gathered = world_size * numel
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
     chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
     dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
     return gathered
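
For context, the two functions touched by this diff form a split/gather pair: each rank of the 1D parallel group keeps a contiguous `1/world_size` slice of the flattened tensor, and `dist.all_gather` reassembles the full flat buffer later. Below is a minimal, self-contained sketch of that round trip. The helper names `split_1d` and `gather_1d`, the single-process gloo setup, and the CPU tensors are illustrative assumptions for runnability only; the real code obtains rank, world size, and group from `gpc` with `ParallelMode.PARALLEL_1D` on CUDA devices.

```python
# Sketch only: mirrors the slicing and all_gather pattern from the diff,
# not ColossalAI's actual API. Assumes numel is divisible by world_size.
import os
import torch
import torch.distributed as dist


def split_1d(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Same slicing rule as split_tensor_into_1d_equal_chunks: flatten the
    # tensor, then take the rank-th contiguous partition.
    partition_size = tensor.numel() // world_size
    start = partition_size * rank
    return tensor.view(-1)[start:start + partition_size]


def gather_1d(chunk: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # Same reassembly rule as gather_split_1d_tensor: allocate one flat output
    # buffer and expose it as world_size equal views for all_gather to fill.
    numel = chunk.numel()
    gathered = torch.empty(world_size * numel, dtype=chunk.dtype, requires_grad=False)
    slices = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(slices, chunk, group=group)
    return gathered


if __name__ == "__main__":
    # Single-process gloo group, purely so the example runs end to end.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.arange(8, dtype=torch.float32)
    chunk = split_1d(x, rank=dist.get_rank(), world_size=dist.get_world_size())
    full = gather_1d(chunk, world_size=dist.get_world_size())
    assert torch.equal(full, x.view(-1))
    dist.destroy_process_group()
```

Because the gathered slices are views into one contiguous buffer, the caller gets the full tensor back without an extra concatenation step, which is the same design choice the original functions make.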