From 5f0924361de4e87f05cbf8aadf9fbd698873a53d Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 31 Oct 2024 08:18:28 +0000 Subject: [PATCH] [fix] fix linear (no tp) ops func name; --- colossalai/shardformer/layer/__init__.py | 4 ++-- colossalai/shardformer/layer/_operation.py | 10 ++++----- colossalai/shardformer/layer/linear.py | 21 +++++-------------- colossalai/shardformer/policies/mixtral.py | 15 +++---------- examples/language/llama/benchmark.py | 4 ++-- .../test_layer/test_linear_1d.py | 6 +++--- 6 files changed, 19 insertions(+), 41 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 613ce73c3..4fc714e57 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -2,7 +2,7 @@ from ._operation import all_to_all_comm from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D -from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D +from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule @@ -11,7 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2 __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", - "Linear1D", + "LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row", "GPT2FusedLinearConv1D_Col", diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 46f50ef02..8a068b78c 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -235,17 +235,16 @@ class LinearWithAsyncCommunication(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None, None -class LinearBase(torch.autograd.Function): +class LinearWithGradAccum(torch.autograd.Function): """ Linear layer baseline (no tensor parallel version). """ @staticmethod - def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False): + def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.async_grad_allreduce = async_grad_allreduce - ctx.fp8_communication = fp8_communication ctx.use_zbv = use_zbv if bias is not None: output = F.linear(input_, weight, bias) @@ -258,7 +257,6 @@ class LinearBase(torch.autograd.Function): def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias - ctx.fp8_communication use_zbv = ctx.use_zbv def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): @@ -1201,8 +1199,8 @@ def linear_with_async_comm( ) -def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False): - return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv) +def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False): + return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv) def linear_gather_forward_reducescatter_backward( diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index cb1496a0b..040a93e5a 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -25,10 +25,10 @@ from colossalai.tensor.d_tensor.api import ( from ._operation import ( gather_forward_reducescatter_backward, gather_forward_split_backward, - linear_base, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, + linear_with_grad_accum, reduce_forward, reducescatter_forward_gather_backward, split_forward_gather_backward, @@ -36,10 +36,10 @@ from ._operation import ( from .parallel_module import PaddingParallelModule, ParallelModule from .utils import create_randomizer_with_offset -__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"] +__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"] -class Linear1D(ParallelModule): +class LinearWithGradAccum(ParallelModule): r"""Linear layer with no parallelism. Args: @@ -69,16 +69,11 @@ class Linear1D(ParallelModule): bias: bool = True, dtype: torch.dtype = None, device: torch.device = None, - gather_output: bool = False, - seq_parallel_mode: str = None, - seq_parallel_dim: int = 1, - overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - fp8_communication: bool = False, use_zbv: bool = False, **kwargs, ): @@ -87,13 +82,8 @@ class Linear1D(ParallelModule): # Keep input parameters self.in_features = in_features self.out_features = out_features - self.gather_output = gather_output - self.seq_parallel_mode = seq_parallel_mode - self.seq_parallel_dim = seq_parallel_dim - self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device - self.fp8_communication = fp8_communication self.use_zbv = use_zbv if skip_bias_add and not bias: @@ -143,7 +133,7 @@ class Linear1D(ParallelModule): bias = module.bias is not None device = module.weight.device - linear_1d = Linear1D( + linear_1d = LinearWithGradAccum( in_features=in_features, out_features=out_features, bias=bias, @@ -174,12 +164,11 @@ class Linear1D(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_base( + output_parallel = linear_with_grad_accum( input_parallel, self.weight, bias, False, - fp8_communication=self.fp8_communication, use_zbv=self.use_zbv, ) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 11291169a..ece72d929 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -52,10 +52,7 @@ class MixtralPolicy(Policy): sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] tp_size = self.shard_config.tensor_parallel_size - if self.pipeline_stage_manager: - use_zbv = self.pipeline_stage_manager.use_zbv - else: - use_zbv = False + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # modified for both SP and TP num_q_heads = self.model.config.num_attention_heads @@ -334,10 +331,7 @@ class MixtralModelPolicy(MixtralPolicy): class MixtralForCausalLMPolicy(MixtralPolicy): def module_policy(self): policy = super().module_policy() - if self.pipeline_stage_manager: - use_zbv = self.pipeline_stage_manager.use_zbv - else: - use_zbv = False + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: # add a new item for causal lm @@ -400,10 +394,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): from transformers import MixtralForSequenceClassification policy = super().module_policy() - if self.pipeline_stage_manager: - use_zbv = self.pipeline_stage_manager.use_zbv - else: - use_zbv = False + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 68ceb9ac1..4976f0c37 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -366,10 +366,10 @@ def main(): ) loss = outputs["loss"] if args.pp_style == "zbv": - if dist.get_rank() == 0: + if coordinator.is_master(): print(f"Step {step} loss: {loss}") else: - if dist.get_rank() == dist.get_world_size() - 1: + if coordinator.is_last_process(): print(f"Step {step} loss: {loss}") optimizer.step() optimizer.zero_grad() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 0556bc986..773799c1c 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -9,7 +9,7 @@ from torch.testing import assert_close import colossalai from colossalai.lazy import LazyInitContext from colossalai.pipeline.weight_grad_store import WeightGradStore -from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -124,7 +124,7 @@ def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: b linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_base = Linear1D.from_native_module( + linear_base = LinearWithGradAccum.from_native_module( linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False ) assert linear_base.weight.shape == torch.Size([128, 32]) @@ -164,7 +164,7 @@ def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_base = Linear1D.from_native_module( + linear_base = LinearWithGradAccum.from_native_module( linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True ) assert linear_base.weight.shape == torch.Size([128, 32])