diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py
index d6fe58f1b..592924bd4 100644
--- a/colossalai/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/nn/layer/parallel_2d/_operation.py
@@ -58,6 +58,7 @@ def matmul_2d(
 
 
 class _Classifier2D(torch.autograd.Function):
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(
@@ -76,7 +77,7 @@ class _Classifier2D(torch.autograd.Function):
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
     ) -> Tensor:
-
+        A = A.clone().detach()
         A_shape = A.shape
         A = A.reshape((-1, A_shape[-1]))
         B_shape = B.shape
@@ -181,6 +182,7 @@ class Matmul_AB_2D(torch.autograd.Function):
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode `_
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(
@@ -308,6 +310,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode `_.
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(
@@ -440,6 +443,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode `_.
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(
@@ -552,6 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
 
 
 class _Add_Bias_2D(torch.autograd.Function):
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(
@@ -633,6 +638,7 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro
 
 
 class _Layernorm_2D(torch.autograd.Function):
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
@@ -689,6 +695,7 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r
 
 
 class _AllGatherTensor2D(torch.autograd.Function):
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
@@ -742,6 +749,7 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
 
 
 class _ReduceTensor2D(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, input_, parallel_mode):
         return all_reduce(input_, parallel_mode)
@@ -766,6 +774,7 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
 
 
 class _ReduceScatterTensor2D(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, input_, dim, parallel_mode):
         ctx.dim = dim
@@ -793,11 +802,12 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
     world_size = gpc.get_world_size(parallel_mode)
     assert dim_size % world_size == 0, \
         f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).'
-
+
     return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode)
 
 
 class _ReduceByBatch2D(torch.autograd.Function):
+
     @staticmethod
     def symbolic(graph, input_, reduce_mean: bool = False):
         output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
@@ -834,4 +844,4 @@ def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor:
         reduce_mean (bool, optional): If set to ``True``, it will divide the output by column parallel size,
             default to False.
     """
-    return _ReduceByBatch2D.apply(input_, reduce_mean)
\ No newline at end of file
+    return _ReduceByBatch2D.apply(input_, reduce_mean)
diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py
index 38f6bba72..0bcc8ecee 100644
--- a/colossalai/nn/layer/parallel_2p5d/_operation.py
+++ b/colossalai/nn/layer/parallel_2p5d/_operation.py
@@ -23,25 +23,26 @@ def get_parallel_rank(parallel_mode: ParallelMode):
 
 
 class _Classifier2p5D(torch.autograd.Function):
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(
-            ctx: Any,
-            A: Tensor,
-            B: Tensor,
-            bias,
-            tesseract_dim: int,
-            out_shape: Tuple[int, ...],
-            row_rank: int,
-            col_rank: int,
-            row_parallel_mode: ParallelMode,
-            col_parallel_mode: ParallelMode,
-            data_parallel_rank: int,
-            pipeline_parallel_rank: int,
-            pipeline_parallel_size: int,
-            tensor_parallel_size: int,
+        ctx: Any,
+        A: Tensor,
+        B: Tensor,
+        bias,
+        tesseract_dim: int,
+        out_shape: Tuple[int, ...],
+        row_rank: int,
+        col_rank: int,
+        row_parallel_mode: ParallelMode,
+        col_parallel_mode: ParallelMode,
+        data_parallel_rank: int,
+        pipeline_parallel_rank: int,
+        pipeline_parallel_size: int,
+        tensor_parallel_size: int,
     ) -> Tensor:
-
+        A = A.clone().detach()
         A_shape = A.shape
         A = A.reshape((-1, A_shape[-1]))
         B_shape = B.shape
@@ -509,6 +510,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
 
 
 class _Add_Bias_2p5D(torch.autograd.Function):
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
@@ -689,6 +691,7 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
 
 
 class _AllGatherTensor2p5D(torch.autograd.Function):
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
@@ -777,6 +780,7 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
 
 
 class _ReduceTensor2p5D(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, input_, parallel_mode):
         return all_reduce(input_, parallel_mode)
@@ -801,6 +805,7 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
 
 
 class _ReduceScatterTensor2p5D(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, input_, dim, parallel_mode):
         ctx.dim = dim
@@ -833,6 +838,7 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel
 
 
 class _RreduceByBatch2p5D(torch.autograd.Function):
+
     @staticmethod
     def symbolic(graph, input_, reduce_mean: bool = False):
         output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)