mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-25 01:40:08 +00:00 
			
		
		
		
	improved allgather & reducescatter for 3d
This commit is contained in:
		| @@ -3,12 +3,17 @@ | ||||
|  | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| from torch.distributed import ReduceOp | ||||
| from torch import Tensor | ||||
| from torch.distributed import ReduceOp | ||||
|  | ||||
| from colossalai.context import ParallelMode | ||||
| from colossalai.core import global_context as gpc | ||||
|  | ||||
| _all_gather_func = dist._all_gather_base \ | ||||
|     if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor | ||||
| _reduce_scatter_func = dist._reduce_scatter_base \ | ||||
|     if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor | ||||
|  | ||||
|  | ||||
| def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: | ||||
|     r"""Gathers all tensors from the parallel group and concatenates them in a | ||||
| @@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: | ||||
|         out = tensor | ||||
|         work = None | ||||
|     else: | ||||
|         shape = list(tensor.shape) | ||||
|         shape[0], shape[dim] = shape[dim], shape[0] | ||||
|         shape[0] *= depth | ||||
|         out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device) | ||||
|         temp = list(torch.chunk(out, depth, dim=0)) | ||||
|         tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous() | ||||
|         out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:] | ||||
|         tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device) | ||||
|         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) | ||||
|         work = dist.all_gather(tensor_list=temp, | ||||
|                                tensor=tensor.transpose(0, dim).contiguous(), | ||||
|                                group=group, | ||||
|                                async_op=async_op) | ||||
|         out = torch.transpose(out, 0, dim) | ||||
|         work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op) | ||||
|         out = tensor_out if dim == 0 else tensor_out.transpose(0, dim) | ||||
|     if async_op: | ||||
|         return out, work | ||||
|     else: | ||||
| @@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor, | ||||
|         out = tensor | ||||
|         work = None | ||||
|     else: | ||||
|         temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim))) | ||||
|         out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device) | ||||
|         tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous() | ||||
|         out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:] | ||||
|         tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device) | ||||
|         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode) | ||||
|         work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op) | ||||
|         work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op) | ||||
|         out = tensor_out if dim == 0 else tensor_out.transpose(0, dim) | ||||
|     if async_op: | ||||
|         return out, work | ||||
|     else: | ||||
| @@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = | ||||
|  | ||||
|  | ||||
| def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None: | ||||
|     r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues | ||||
|     r"""Modified from `torch.distributed.scatter_object_list | ||||
|     <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues | ||||
|     """ | ||||
|     if dist.distributed_c10d._rank_not_in_group(group): | ||||
|         return | ||||
|   | ||||
| @@ -34,7 +34,7 @@ class _Linear3D(torch.autograd.Function): | ||||
|         ctx.output_parallel_mode = output_parallel_mode | ||||
|  | ||||
|         input_ = all_gather(input_, 0, input_parallel_mode) | ||||
|         weight = all_gather(weight, -1, weight_parallel_mode) | ||||
|         weight = all_gather(weight, 0, weight_parallel_mode) | ||||
|         ctx.save_for_backward(input_, weight) | ||||
|  | ||||
|         output = torch.matmul(input_, weight) | ||||
| @@ -53,7 +53,7 @@ class _Linear3D(torch.autograd.Function): | ||||
|  | ||||
|         weight_grad = torch.matmul( | ||||
|             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) | ||||
|         weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True) | ||||
|         weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True) | ||||
|         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) | ||||
|  | ||||
|         input_op.wait() | ||||
| @@ -205,7 +205,7 @@ class _VocabParallelClassifier3D(torch.autograd.Function): | ||||
|         ctx.weight_id = weight_id | ||||
|  | ||||
|         input_ = all_gather(input_, 0, input_parallel_mode) | ||||
|         weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode) | ||||
|         weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1) | ||||
|         ctx.save_for_backward(input_, weight) | ||||
|  | ||||
|         output = torch.matmul(input_, weight) | ||||
|   | ||||
| @@ -196,8 +196,8 @@ class Linear3D(ParallelLayer): | ||||
|         self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) | ||||
|         self.depth = get_depth_from_env() | ||||
|         self.skip_bias_add = skip_bias_add | ||||
|         self.in_features_per_partition = divide(in_features, self.depth) | ||||
|         self.out_features_per_partition = divide(out_features, self.depth**2) | ||||
|         self.in_features_per_partition = divide(in_features, self.depth**2) | ||||
|         self.out_features_per_partition = divide(out_features, self.depth) | ||||
|         self.bias_features_per_partition = divide(out_features, self.depth) | ||||
|  | ||||
|         self.weight = Parameter( | ||||
| @@ -287,7 +287,7 @@ class Linear3D(ParallelLayer): | ||||
|             local_state, | ||||
|             self.weight_parallel_mode, | ||||
|             dims={ | ||||
|                 weight_key: -1, | ||||
|                 weight_key: 0, | ||||
|                 bias_key: 0 | ||||
|             }, | ||||
|             partition_states={ | ||||
| @@ -310,7 +310,7 @@ class Linear3D(ParallelLayer): | ||||
|             local_state, | ||||
|             self.weight_parallel_mode, | ||||
|             dims={ | ||||
|                 weight_key: -1, | ||||
|                 weight_key: 0, | ||||
|                 bias_key: 0 | ||||
|             }, | ||||
|             partition_states={ | ||||
|   | ||||
| @@ -4,12 +4,23 @@ | ||||
| import time | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D | ||||
| from colossalai.core import global_context | ||||
| from colossalai.logging import get_dist_logger | ||||
| from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, | ||||
|                            VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D, | ||||
|                            VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D) | ||||
| from colossalai.nn import ( | ||||
|     Classifier3D, | ||||
|     CrossEntropyLoss3D, | ||||
|     Embedding3D, | ||||
|     LayerNorm3D, | ||||
|     Linear3D, | ||||
|     PatchEmbedding3D, | ||||
|     VanillaClassifier, | ||||
|     VanillaPatchEmbedding, | ||||
|     VocabParallelClassifier3D, | ||||
|     VocabParallelCrossEntropyLoss3D, | ||||
|     VocabParallelEmbedding3D, | ||||
| ) | ||||
| from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env | ||||
| from colossalai.utils import get_current_device, print_rank_0 | ||||
|  | ||||
| @@ -40,7 +51,7 @@ def check_linear(): | ||||
|     torch.distributed.broadcast(weight_master, src=0) | ||||
|     weight = torch.chunk(weight_master, DEPTH, dim=0)[k] | ||||
|     weight = torch.chunk(weight, DEPTH, dim=-1)[j] | ||||
|     weight = torch.chunk(weight, DEPTH, dim=-1)[i] | ||||
|     weight = torch.chunk(weight, DEPTH, dim=0)[i] | ||||
|     layer.weight.data.copy_(weight) | ||||
|     bias_master = layer_master.bias.data | ||||
|     torch.distributed.broadcast(bias_master, src=0) | ||||
| @@ -93,7 +104,7 @@ def check_linear(): | ||||
|     B_grad = layer_master.weight.grad.transpose(0, 1) | ||||
|     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] | ||||
|     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] | ||||
|     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] | ||||
|     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] | ||||
|     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) | ||||
|  | ||||
|     bias_grad = layer_master.bias.grad | ||||
| @@ -775,7 +786,7 @@ def check_loss(): | ||||
|  | ||||
|     out_shape = (BATCH_SIZE, NUM_CLASSES) | ||||
|     out_master = torch.randn(out_shape, device=device) | ||||
|     target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) | ||||
|     target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) | ||||
|     torch.distributed.broadcast(out_master, src=0) | ||||
|     torch.distributed.broadcast(target_master, src=0) | ||||
|     out = torch.chunk(out_master, DEPTH, dim=0)[i] | ||||
| @@ -828,7 +839,7 @@ def check_vocab_parallel_loss(): | ||||
|  | ||||
|     out_shape = (BATCH_SIZE, NUM_CLASSES) | ||||
|     out_master = torch.randn(out_shape, device=device) | ||||
|     target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) | ||||
|     target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) | ||||
|     torch.distributed.broadcast(out_master, src=0) | ||||
|     torch.distributed.broadcast(target_master, src=0) | ||||
|     out = torch.chunk(out_master, DEPTH, dim=0)[i] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user