From 646b3c5a90ce904b2128cae467c2068f435a9df0 Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Thu, 10 Oct 2024 14:34:45 +0800
Subject: [PATCH] [shardformer] fix linear 1d row and support uneven splits for
 fused qkv linear (#6084)

* [tp] hotfix linear row

* [tp] support uneven split for fused linear

* [tp] support sp for fused linear

* [tp] fix gpt2 mlp policy

* [tp] fix gather fused and add fused linear row
---
 .../modeling/policy/nopadding_baichuan.py     |   4 +-
 colossalai/shardformer/layer/__init__.py      |   3 +-
 colossalai/shardformer/layer/_operation.py    |   4 +-
 colossalai/shardformer/layer/linear.py        |  11 +-
 .../shardformer/layer/qkv_fused_linear.py     | 364 +++++++++++++++---
 colossalai/shardformer/policies/blip2.py      |   2 +-
 colossalai/shardformer/policies/gpt2.py       |   4 +-
 colossalai/shardformer/policies/sam.py        |   2 +-
 .../test_gpt2_qkv_fused_linear_1d.py          |  21 +-
 .../test_layer/test_qkv_fused_linear_1d.py    | 141 ++++---
 10 files changed, 399 insertions(+), 157 deletions(-)

diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py
index 37b5062e8..8528de75c 100644
--- a/colossalai/inference/modeling/policy/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py
@@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
                         target_module=NopadBaichuanMLP,
                     ),
                     SubModuleReplacementDescription(
-                        suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
+                        suffix="self_attn.W_pack",
+                        target_module=FusedLinear1D_Col,
+                        kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 8882a33c1..684993de6 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHe
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
-from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 
 __all__ = [
     "Embedding1D",
@@ -34,4 +34,5 @@ __all__ = [
     "RingAttention",
     "get_pad_info",
     "all_to_all_comm",
+    "FusedLinear1D_Row",
 ]
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index aec823567..1d7a1f104 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -840,7 +840,7 @@ class _AllToAll(torch.autograd.Function):
         ctx.gather_dim = gather_dim
         ctx.fp8_communication = fp8_communication
         world_size = dist.get_world_size(process_group)
-        bsz, _, _ = input_.shape
+        bsz = input_.shape[0]
 
         # using all_to_all_single when batch size is 1
         if bsz == 1:
@@ -871,7 +871,7 @@ class _AllToAll(torch.autograd.Function):
         gather_dim = ctx.scatter_dim
         fp8_communication = ctx.fp8_communication
         world_size = dist.get_world_size(process_group)
-        bsz, _, _ = grad_output.shape
+        bsz = grad_output.shape[0]
 
         if bsz == 1:
             return_grad = _all_to_all_single(
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index d77dd4965..52b0e79c6 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -428,11 +428,8 @@ class Linear1D_Row(ParallelModule):
                     handle.wait()
                 output = torch.cat(output_parallel_list, dim=-1)
         else:
-            if self.seq_parallel_mode is None:
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
-            elif self.seq_parallel_mode == "split_gather":
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+            if self.seq_parallel_mode == "split_gather":
+                output_parallel = F.linear(input_, self.weight)
                 output = reducescatter_forward_gather_backward(
                     output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
                 )
@@ -445,8 +442,8 @@ class Linear1D_Row(ParallelModule):
                     ring=True,
                 )
             else:
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group)
+                output_parallel = F.linear(input_, self.weight)
+                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
 
         if not self.skip_bias_add:
             if self.bias is not None:
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 6fd689908..a1e25ff3a 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -24,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
 )
 
 from ._operation import (
-    gather_forward_split_backward,
+    gather_forward_reducescatter_backward,
+    linear_gather_forward_reducescatter_backward,
+    linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     matmul_gather_forward_reducescatter_backward,
     matmul_with_async_comm,
@@ -44,21 +47,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"
 
 
 def split_fused_qkv_in_gpt2_style(
-    qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
+    qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
 ):
     """
     The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
 
     Args:
         qkv (torch.Tensor): The fused qkv tensor.
-        n_fused (int): The number items fused together, defaults to 3 (query, key and value).
+        split_sizes (List[int]): The sizes of the split tensor.
         process_group (ProcessGroup): The process group for distributed communication.
         is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
     """
     # get the number of slice for the fused qkv
     rank = dist.get_rank(group=process_group)
     world_size = dist.get_world_size(group=process_group)
-    order = torch.arange(world_size * n_fused)
+    order = torch.arange(world_size * len(split_sizes))
+    new_split_sizes = []
+    for sz in split_sizes:
+        assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
+        new_split_sizes.extend([sz // world_size] * world_size)
 
     # split the fused qkv
     # from
@@ -66,9 +73,9 @@ def split_fused_qkv_in_gpt2_style(
     # to
     # [Q1, Q2, K1, K2, V1, V2]
     if is_transposed:
-        weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
+        weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
     else:
-        weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
+        weight_chunks = torch.split(qkv, new_split_sizes, dim=0)
 
     # rearrange the slice into the final order
     # from
@@ -85,18 +92,23 @@ def split_fused_qkv_in_gpt2_style(
 
 
 def gather_fused_qkv_in_gpt2_style(
-    qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
+    qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
 ):
     """
     The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
 
     Args:
         qkv (torch.Tensor): The fused qkv tensor.
-        n_fused (int): The number items fused together, defaults to 3 (query, key and value).
+        split_sizes (List[int]): The sizes of the split tensor.
         process_group (ProcessGroup): The process group for distributed communication.
         is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
     """
     world_size = dist.get_world_size(group=process_group)
+    new_split_sizes = []
+    for sz in split_sizes:
+        assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
+        new_split_sizes.append(sz // world_size)
+    new_split_sizes = new_split_sizes * world_size
 
     # gather the tensors
     # from
@@ -121,13 +133,13 @@ def gather_fused_qkv_in_gpt2_style(
     # to
     # [Q1, Q2, K1, K2, V1, V2]
     if is_transposed:
-        weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
+        weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
     else:
-        weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
+        weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)
 
     reordered_chunk_list = []
-    for i in range(n_fused):
-        reordered_chunk_list.extend(weight_chunks[i::n_fused])
+    for i in range(len(split_sizes)):
+        reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])
 
     if is_transposed:
         reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
@@ -136,6 +148,42 @@ def gather_fused_qkv_in_gpt2_style(
     return reordered_gather_weight
 
 
+class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+        ctx.split_sizes = split_sizes
+        ctx.process_group = process_group
+        return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = gather_fused_qkv_in_gpt2_style(
+            grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
+        )
+        return grad_output, None, None
+
+
+def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+    return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
+
+
+class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+        ctx.split_sizes = split_sizes
+        ctx.process_group = process_group
+        return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
+        return grad_output, None, None
+
+
+def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+    return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
+
+
 class GPT2FusedLinearConv1D_Col(ParallelModule):
     r"""Linear layer with column parallelism.
 
@@ -145,10 +193,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
     Args:
         in_features (int): size of each input sample.
         out_features (int): size of each output sample.
+        split_sizes (List[int]): The sizes of the split tensor.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         device (`torch.device`): The device of parameters, defaults to None.
-        n_fused (int): The number items fused, defaults to 3 (QKV).
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
         gather_output (bool, optional): If true, call all-gather on output and make Y available
@@ -169,6 +217,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self,
         in_features: int,
         out_features: int,
+        split_sizes: List[int],
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
@@ -178,7 +227,6 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         seq_parallel_mode: str = None,
         overlap: bool = False,
         skip_bias_add: bool = False,
-        n_fused: int = 3,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -195,11 +243,15 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.n_fused = n_fused
+        self.split_sizes = split_sizes
         self.process_group = process_group
         self.async_communication = async_communication
         self.fp8_communication = fp8_communication
 
+        assert (
+            sum(split_sizes) == out_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
+
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
 
@@ -223,10 +275,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             self.weight = weight
 
         def shard_fn(tensor):
-            return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
 
         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
 
         if not is_customized_distributed_tensor(self.weight):
             with torch.no_grad():
@@ -252,7 +304,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
 
     @staticmethod
     def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Module,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        split_sizes: List[int],
+        *args,
+        **kwargs,
     ) -> ParallelModule:
         r"""
         Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@@ -260,7 +316,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         Args:
             module (`nn.Linear`): The module to be converted.
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
-            n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
+            split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
         """
         LazyInitContext.materialize(module)
         # get the attributes
@@ -291,6 +347,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
+            split_sizes=split_sizes,
             *args,
             **kwargs,
         )
@@ -354,9 +411,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
 
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(
-                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
-            )
+            output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
         else:
             output = output_parallel
 
@@ -605,10 +660,10 @@ class FusedLinear1D_Col(ParallelModule):
     Args:
         in_features (int): size of each input sample.
         out_features (int): size of each output sample.
+        split_sizes (List[int]): The sizes of the split tensor.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         device (`torch.device`): The device of parameters, defaults to None.
-        n_fused (int): The number items fused, defaults to 3 (QKV).
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         gather_output (bool, optional): If true, call all-gather on output and make Y available
                     to all GPUs, otherwise, every GPU will have its output
@@ -628,14 +683,16 @@ class FusedLinear1D_Col(ParallelModule):
         self,
         in_features: int,
         out_features: int,
+        split_sizes: List[int],
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
         process_group: ProcessGroup = None,
-        async_communication: bool = False,
         gather_output: bool = False,
+        seq_parallel_mode: str = None,
+        seq_parallel_dim: int = 1,
+        overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
-        n_fused: int = 3,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -647,13 +704,19 @@ class FusedLinear1D_Col(ParallelModule):
         self.in_features = in_features
         self.out_features = out_features
         self.gather_output = gather_output
+        self.seq_parallel_mode = seq_parallel_mode
+        self.seq_parallel_dim = seq_parallel_dim
+        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.n_fused = n_fused
+        self.split_sizes = split_sizes
         self.process_group = process_group
-        self.async_communication = async_communication
         self.fp8_communication = fp8_communication
 
+        assert (
+            sum(split_sizes) == out_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
+
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
 
@@ -677,10 +740,10 @@ class FusedLinear1D_Col(ParallelModule):
             self.weight = weight
 
         def shard_fn(tensor):
-            return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
 
         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
 
         if not is_customized_distributed_tensor(self.weight):
             with torch.no_grad():
@@ -706,7 +769,11 @@ class FusedLinear1D_Col(ParallelModule):
 
     @staticmethod
     def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
+        module: nn.Module,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        split_sizes: List[int],
+        *args,
+        **kwargs,
     ) -> ParallelModule:
         r"""
         Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
@@ -714,7 +781,7 @@ class FusedLinear1D_Col(ParallelModule):
         Args:
             module (`nn.Linear`): The module to be converted.
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
-            n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
+            split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
         """
         LazyInitContext.materialize(module)
 
@@ -737,25 +804,11 @@ class FusedLinear1D_Col(ParallelModule):
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            n_fused=n_fused,
+            split_sizes=split_sizes,
             *args,
             **kwargs,
         )
 
-        # # TODO: copy the sharded weights
-        # with torch.no_grad():
-        #     sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
-        #                                                    n_fused=n_fused,
-        #                                                    process_group=process_group,
-        #                                                    is_transposed=False)
-        #     linear_1d.weight.data.copy_(sharded_weight.data)
-
-        #     if bias:
-        #         sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
-        #                                                      n_fused=n_fused,
-        #                                                      process_group=process_group,
-        #                                                      is_transposed=False)
-        #         linear_1d.bias.data.copy_(sharded_bias.data)
         return linear_1d
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -772,19 +825,30 @@ class FusedLinear1D_Col(ParallelModule):
             input_.shape, self.weight.shape, self.weight.shape[-1]
         )
         # Set up backprop all-reduce.
-        # input_parallel = reduce_backward(input_, self.process_group)
         input_parallel = input_
 
         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
 
-        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+        if self.seq_parallel_mode == "split_gather":
+            input_parallel = gather_forward_reducescatter_backward(
+                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
+            )
+            output_parallel = linear_with_async_comm(
+                input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
+            )
+        elif self.seq_parallel_mode == "ring":
+            output_parallel = linear_gather_forward_reducescatter_backward(
+                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+            )
+        else:
+            output_parallel = linear_with_async_comm(
+                input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
+            )
 
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(
-                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
-            )
+            output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
         else:
             output = output_parallel
 
@@ -792,3 +856,201 @@ class FusedLinear1D_Col(ParallelModule):
             return output, self.bias
         else:
             return output
+
+
+class FusedLinear1D_Row(ParallelModule):
+    r"""Linear layer with row parallelism
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
+        seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
+        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+            which is preserved for kernel fusion, defaults to False
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (:class:`typing.Callable`, optional):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    More details about ``initializer`` please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        split_sizes: List[int],
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        process_group: ProcessGroup = None,
+        seq_parallel_mode: str = None,
+        seq_parallel_dim: int = 1,
+        parallel_input: bool = True,
+        skip_bias_add: bool = False,
+        weight: Optional[Parameter] = None,
+        bias_: Optional[Parameter] = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
+    ):
+        super().__init__()
+        # Keep input parameters
+        self.in_features = in_features
+        self.out_features = out_features
+        self.split_sizes = split_sizes
+        self.parallel_input = parallel_input
+        self.skip_bias_add = skip_bias_add
+        self.process_group = process_group
+        self.seq_parallel_mode = seq_parallel_mode
+        self.seq_parallel_dim = seq_parallel_dim
+        self.num_partitions = dist.get_world_size(self.process_group)
+        self.fp8_communication = fp8_communication
+
+        assert (
+            sum(split_sizes) == in_features
+        ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
+
+        if skip_bias_add and not bias:
+            raise ValueError("cannot skip bias addition if bias is None")
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
+        else:
+            assert bias_ is None, "bias_ must be None if weight is None"
+
+        # Parameters.
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {"device": device, "dtype": dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+
+        def shard_fn(tensor):
+            return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
+
+        def gather_fn(tensor):
+            return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
+
+        if not is_customized_distributed_tensor(self.weight):
+            with torch.no_grad():
+                sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+            customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if bias:
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+        else:
+            self.bias = None
+
+        if weight is None:
+            with self.randomizer.fork_rng(enable_cpu=True):
+                self.reset_parameters(weight_initializer, bias_initializer)
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
+    ) -> ParallelModule:
+        r"""
+        Convert a native PyTorch linear layer to a parallelized linear layer.
+        """
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        linear_1d = FusedLinear1D_Row(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            process_group=process_group,
+            weight=module.weight,
+            bias_=module.bias,
+            split_sizes=split_sizes,
+            **kwargs,
+        )
+
+        return linear_1d
+
+    @torch.no_grad()
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        fan_in, fan_out = self.in_features, self.out_features
+        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=fan_in)
+            if self.process_group is None:
+                src_rank = 0
+            else:
+                src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+
+            origin_device = self.bias.device
+            bias = self.bias.cuda()
+            dist.broadcast(bias, src=src_rank, group=self.process_group)
+            bias = bias.to(origin_device)
+            self.bias.copy_(bias)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        # Set up backprop all-reduce.
+        if self.parallel_input:
+            assert (
+                input_.shape[-1] == self.weight.shape[-1]
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+                input_.shape, self.weight.shape, self.weight.shape[-1]
+            )
+            input_ = input_
+        else:
+            assert (
+                divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
+            ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+                input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
+            )
+            input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
+
+        if self.seq_parallel_mode == "split_gather":
+            output_parallel = F.linear(input_, self.weight)
+            output = reducescatter_forward_gather_backward(
+                output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
+            )
+        elif self.seq_parallel_mode == "ring":
+            output = linear_reducescatter_forward_gather_backward(
+                input_,
+                self.weight,
+                process_group=self.process_group,
+                dim=self.seq_parallel_dim,
+                ring=True,
+            )
+        else:
+            output_parallel = F.linear(input_, self.weight)
+            output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
+
+        if not self.skip_bias_add:
+            if self.bias is not None:
+                output = output + self.bias
+            return output
+        else:
+            return output, self.bias
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index da798f6a0..2e73d5c2a 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -71,7 +71,7 @@ class BlipPolicy(Policy):
                         suffix="self_attn.qkv",
                         target_module=col_nn.FusedLinear1D_Col,
                         kwargs={
-                            "n_fused": 3,
+                            "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
                             "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index d9233be9a..faacf91b2 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -92,7 +92,7 @@ class GPT2Policy(Policy):
                         suffix="attn.c_attn",
                         target_module=col_nn.GPT2FusedLinearConv1D_Col,
                         kwargs={
-                            "n_fused": 3,
+                            "split_sizes": [self.model.config.hidden_size] * 3,
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
                             "fp8_communication": self.shard_config.fp8_communication,
@@ -107,7 +107,7 @@ class GPT2Policy(Policy):
                         suffix="mlp.c_fc",
                         target_module=col_nn.GPT2FusedLinearConv1D_Col,
                         kwargs={
-                            "n_fused": 1,
+                            "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
                             "skip_bias_add": self.enable_bias_gelu_fused,
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index 674fe5e58..a94cc9119 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -42,7 +42,7 @@ class SamPolicy(Policy):
                         suffix="attn.qkv",
                         target_module=col_nn.FusedLinear1D_Col,
                         kwargs={
-                            "n_fused": 3,
+                            "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
                             "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
index 5aa8584a0..923075e0e 100644
--- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
@@ -41,21 +41,6 @@ class Conv1D(nn.Module):
         return x
 
 
-def rearrange(tensor: torch.Tensor, dim: int):
-    tensor = tensor.clone()
-    world_size = 2
-    order = torch.arange(world_size * 3)
-    new_order = []
-    for i in range(world_size):
-        new_order.append(order[i::world_size])
-    new_order = torch.cat(new_order)
-
-    tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
-    rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
-    rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
-    return rearanged_tensor
-
-
 def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
@@ -66,7 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
         process_group=None,
         gather_output=True,
         seq_parallel_mode=seq_parallel_mode,
-        n_fused=3,
+        split_sizes=[64] * 3,
         overlap=overlap,
     )
 
@@ -88,13 +73,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
         x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
     )
     gather_out = linear_conv_col(x_for_shard)
-    assert_close(rearrange(out, -1), gather_out)
+    assert_close(out, gather_out)
 
     # check backward correctness
     out.sum().backward()
     gather_out.sum().backward()
 
-    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
     assert_close(target_grad, linear_conv_col.weight.grad)
 
 
diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
index dc14fd591..fccba564f 100644
--- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -2,13 +2,12 @@ import os
 from contextlib import nullcontext
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
@@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 
 
-class Conv1D(nn.Module):
-    """
-    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-
-    Basically works like a linear layer but the weights are transposed.
-
-    Args:
-        nf (`int`): The number of output features.
-        nx (`int`): The number of input features.
-    """
-
-    def __init__(self, nf, nx):
-        super().__init__()
-        self.nf = nf
-        self.weight = nn.Parameter(torch.empty(nx, nf))
-        self.bias = nn.Parameter(torch.zeros(nf))
-        nn.init.normal_(self.weight, std=0.02)
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(size_out)
-        return x
-
-
-def rearrange(tensor: torch.Tensor, dim: int):
-    tensor = tensor.clone()
-    world_size = 2
-    order = torch.arange(world_size * 3)
-    new_order = []
-    for i in range(world_size):
-        new_order.append(order[i::world_size])
-    new_order = torch.cat(new_order)
-
-    tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
-    rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
-    rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
-    return rearanged_tensor
-
-
 @parameterize("lazy_init", [False, True])
-def check_linear_conv_1d_col(lazy_init: bool):
+def check_linear_1d_col(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
-    linear = Conv1D(192, 48).cuda()
+    linear = nn.Linear(8, 80).cuda()
     with ctx:
-        linear_copy = Conv1D(192, 48).cuda()
-    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, n_fused=3
+        linear_copy = nn.Linear(8, 80).cuda()
+    linear_col = FusedLinear1D_Col.from_native_module(
+        linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
     )
 
-    assert linear.weight.shape == torch.Size([48, 192])
-    assert linear.bias.shape == torch.Size([192])
-    assert linear_conv_col.weight.shape == torch.Size([48, 96])
-    assert linear_conv_col.bias.shape == torch.Size([96])
-    assert linear_copy.weight is linear_conv_col.weight
-    assert linear_copy.bias is linear_conv_col.bias
+    assert linear.weight.shape == torch.Size([80, 8])
+    assert linear.bias.shape == torch.Size([80])
+    assert linear_col.weight.shape == torch.Size([40, 8])
+    assert linear_col.bias.shape == torch.Size([40])
+    assert linear_copy.weight is linear_col.weight
+    assert linear_copy.bias is linear_col.bias
 
     # ensure weights are reversibly loadable
-    linear_conv_col.load_state_dict(linear.state_dict())
-    linear.load_state_dict(linear_conv_col.state_dict())
+    linear_col.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_col.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(4, 8).cuda()
     out = linear(x)
-    gather_out = linear_conv_col(x)
-    assert_close(rearrange(out, 1), gather_out)
+    gather_out = linear_col(x)
+    assert_close(out, gather_out)
 
     # check backward correctness
     out.sum().backward()
     gather_out.sum().backward()
 
-    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
-    assert_close(target_grad, linear_conv_col.weight.grad)
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
+    assert_close(target_grad, linear_col.weight.grad)
 
 
 @parameterize("lazy_init", [False, True])
-def check_linear_conv_1d_row(lazy_init: bool):
+def check_linear_1d_row(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
-    linear = Conv1D(192, 48).cuda()
+    linear = nn.Linear(80, 8).cuda()
     with ctx:
-        linear_copy = Conv1D(192, 48).cuda()
-    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+        linear_copy = nn.Linear(80, 8).cuda()
+    linear_row = FusedLinear1D_Row.from_native_module(
+        linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
+    )
 
-    assert linear.weight.shape == torch.Size([48, 192])
-    assert linear_row.weight.shape == torch.Size([24, 192])
-    assert linear_row.bias.shape == torch.Size([192])
+    assert linear.weight.shape == torch.Size([8, 80])
+    assert linear_row.weight.shape == torch.Size([8, 40])
+    assert linear_row.bias.shape == torch.Size([8])
     assert linear_copy.weight is linear_row.weight
     assert linear_copy.bias is linear_row.bias
 
@@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
     linear.load_state_dict(linear_row.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(4, 80).cuda()
     out = linear(x)
     gather_out = linear_row(x)
     assert_close(out, gather_out)
@@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
     out.sum().backward()
     gather_out.sum().backward()
 
-    rank = dist.get_rank()
-    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
+    target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
     assert_close(target_grad, linear_row.weight.grad)
 
 
+@parameterize("lazy_init", [False, True])
+def check_linear_1d_col_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    linear1 = nn.Linear(8, 80).cuda()
+    linear2 = nn.Linear(80, 8).cuda()
+    with ctx:
+        linear1_copy = nn.Linear(8, 80).cuda()
+        linear2_copy = nn.Linear(80, 8).cuda()
+    linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
+    linear_row = FusedLinear1D_Row.from_native_module(
+        linear2_copy,
+        process_group=None,
+        split_sizes=[32, 32, 16],
+    )
+    # ensure weights are reversibly loadable
+    linear_col.load_state_dict(linear1.state_dict())
+    linear_row.load_state_dict(linear2.state_dict())
+
+    # check computation correctness
+    x = torch.rand(4, 8).cuda()
+    target_out = linear2(linear1(x))
+    out = linear_row(linear_col(x))
+    assert_close(out, target_out)
+
+    # check backward correctness
+    target_out.sum().backward()
+    out.sum().backward()
+
+    target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
+    assert_close(target_grad1, linear_col.weight.grad)
+    target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
+    assert_close(target_grad2, linear_row.weight.grad)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
-    # test for linear conv
-    check_linear_conv_1d_col()
-    check_linear_conv_1d_row()
+    check_linear_1d_col()
+    check_linear_1d_row()
+    check_linear_1d_col_row()
 
 
 @rerun_if_address_is_in_use()