diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index 0a9b5293d..76813a4a3 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -18,7 +18,7 @@
 
 
 ## 📌 Introduction
-ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
+ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
 
 <p align="center">
 <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/colossal-inference-v1-1.png" width=1000/>
@@ -310,4 +310,14 @@ If you wish to cite relevant research papars, you can find the reference below.
   journal={arXiv},
   year={2023}
 }
+
+# Distrifusion
+@InProceedings{Li_2024_CVPR,
+    author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
+    title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
+    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+    month={June},
+    year={2024},
+    pages={7183-7193}
+}
 ```
diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 1beb86874..072ddbcfd 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
         enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
         start_token_size(int): The size of the start tokens, when using StreamingLLM.
         generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
+        patched_parallelism_size(int): Patched Parallelism Size, When using Distrifusion
     """
 
     # NOTE: arrange configs according to their importance and frequency of usage
@@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
     start_token_size: int = 4
     generated_token_size: int = 512
 
+    # Acceleration for Diffusion Model(PipeFusion or Distrifusion)
+    patched_parallelism_size: int = 1  # for distrifusion
+    # pipeFusion_m_size: int = 1  # for pipefusion
+    # pipeFusion_n_size: int = 1  # for pipefusion
+
     def __post_init__(self):
         self.max_context_len_to_capture = self.max_input_len + self.max_output_len
         self._verify_config()
@@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
         # Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
         self.start_token_size = self.block_size
 
+        # check Distrifusion
+        # TODO(@lry89757) need more detailed check
+        if self.patched_parallelism_size > 1:
+            # self.use_patched_parallelism = True
+            self.tp_size = (
+                self.patched_parallelism_size
+            )  # this is not a real tp, because some annoying check, so we have to set this to patched_parallelism_size
+
         # check prompt template
         if self.prompt_template is None:
             return
@@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
             use_cuda_kernel=self.use_cuda_kernel,
             use_spec_dec=self.use_spec_dec,
             use_flash_attn=use_flash_attn,
+            patched_parallelism_size=self.patched_parallelism_size,
         )
         return model_inference_config
 
@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
     use_cuda_kernel: bool = False
     use_spec_dec: bool = False
     use_flash_attn: bool = False
+    patched_parallelism_size: int = 1  # for diffusion model, Distrifusion Technique
 
 
 @dataclass
diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py
index 75b9889bf..8bed508cb 100644
--- a/colossalai/inference/core/diffusion_engine.py
+++ b/colossalai/inference/core/diffusion_engine.py
@@ -11,7 +11,7 @@ from torch import distributed as dist
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.struct import DiffusionSequence
 from colossalai.inference.utils import get_model_size, get_model_type
diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/layers/diffusion.py
similarity index 100%
rename from colossalai/inference/modeling/models/diffusion.py
rename to colossalai/inference/modeling/layers/diffusion.py
diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py
new file mode 100644
index 000000000..ea97cceef
--- /dev/null
+++ b/colossalai/inference/modeling/layers/distrifusion.py
@@ -0,0 +1,626 @@
+# Code refer and adapted from:
+# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
+# https://github.com/PipeFusion/PipeFusion
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from diffusers.models import attention_processor
+from diffusers.models.attention import Attention
+from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
+from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
+from torch import nn
+from torch.distributed import ProcessGroup
+
+from colossalai.inference.config import ModelShardInferenceConfig
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from colossalai.utils import get_current_device
+
+try:
+    from flash_attn import flash_attn_func
+
+    HAS_FLASH_ATTN = True
+except ImportError:
+    HAS_FLASH_ATTN = False
+
+
+logger = get_dist_logger(__name__)
+
+
+# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
+def PixArtAlphaTransformer2DModel_forward(
+    self: PixArtTransformer2DModel,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    timestep: Optional[torch.LongTensor] = None,
+    added_cond_kwargs: Dict[str, torch.Tensor] = None,
+    class_labels: Optional[torch.LongTensor] = None,
+    cross_attention_kwargs: Dict[str, Any] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.Tensor] = None,
+    return_dict: bool = True,
+):
+    assert hasattr(
+        self, "patched_parallel_size"
+    ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
+
+    if cross_attention_kwargs is not None:
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+    #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+    #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+    # expects mask of shape:
+    #   [batch, key_tokens]
+    # adds singleton query_tokens dimension:
+    #   [batch,                    1, key_tokens]
+    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+    if attention_mask is not None and attention_mask.ndim == 2:
+        # assume that mask is expressed as:
+        #   (1 = keep,      0 = discard)
+        # convert mask into a bias that can be added to attention scores:
+        #       (keep = +0,     discard = -10000.0)
+        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+        attention_mask = attention_mask.unsqueeze(1)
+
+    # convert encoder_attention_mask to a bias the same way we do for attention_mask
+    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+    # 1. Input
+    batch_size = hidden_states.shape[0]
+    height, width = (
+        hidden_states.shape[-2] // self.config.patch_size,
+        hidden_states.shape[-1] // self.config.patch_size,
+    )
+    hidden_states = self.pos_embed(hidden_states)
+
+    timestep, embedded_timestep = self.adaln_single(
+        timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+    )
+
+    if self.caption_projection is not None:
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+    # 2. Blocks
+    for block in self.transformer_blocks:
+        hidden_states = block(
+            hidden_states,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            timestep=timestep,
+            cross_attention_kwargs=cross_attention_kwargs,
+            class_labels=class_labels,
+        )
+
+    # 3. Output
+    shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
+        2, dim=1
+    )
+    hidden_states = self.norm_out(hidden_states)
+    # Modulation
+    hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+    hidden_states = self.proj_out(hidden_states)
+    hidden_states = hidden_states.squeeze(1)
+
+    # unpatchify
+    hidden_states = hidden_states.reshape(
+        shape=(
+            -1,
+            height // self.patched_parallel_size,
+            width,
+            self.config.patch_size,
+            self.config.patch_size,
+            self.out_channels,
+        )
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(
+            -1,
+            self.out_channels,
+            height // self.patched_parallel_size * self.config.patch_size,
+            width * self.config.patch_size,
+        )
+    )
+
+    # enable Distrifusion Optimization
+    if hasattr(self, "patched_parallel_size"):
+        from torch import distributed as dist
+
+        if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
+            self.output_buffer = torch.empty_like(output)
+        if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
+            self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
+        output = output.contiguous()
+        dist.all_gather(self.buffer_list, output, async_op=False)
+        torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
+        output = self.output_buffer
+
+    return (output,)
+
+
+# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
+def SD3Transformer2DModel_forward(
+    self: SD3Transformer2DModel,
+    hidden_states: torch.FloatTensor,
+    encoder_hidden_states: torch.FloatTensor = None,
+    pooled_projections: torch.FloatTensor = None,
+    timestep: torch.LongTensor = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    return_dict: bool = True,
+) -> Union[torch.FloatTensor]:
+
+    assert hasattr(
+        self, "patched_parallel_size"
+    ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
+
+    height, width = hidden_states.shape[-2:]
+
+    hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+    temb = self.time_text_embed(timestep, pooled_projections)
+    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+    for block in self.transformer_blocks:
+        encoder_hidden_states, hidden_states = block(
+            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+        )
+
+    hidden_states = self.norm_out(hidden_states, temb)
+    hidden_states = self.proj_out(hidden_states)
+
+    # unpatchify
+    patch_size = self.config.patch_size
+    height = height // patch_size // self.patched_parallel_size
+    width = width // patch_size
+
+    hidden_states = hidden_states.reshape(
+        shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+    )
+
+    # enable Distrifusion Optimization
+    if hasattr(self, "patched_parallel_size"):
+        from torch import distributed as dist
+
+        if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
+            self.output_buffer = torch.empty_like(output)
+        if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
+            self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
+        output = output.contiguous()
+        dist.all_gather(self.buffer_list, output, async_op=False)
+        torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
+        output = self.output_buffer
+
+    return (output,)
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
+class DistrifusionPatchEmbed(ParallelModule):
+    def __init__(
+        self,
+        module: PatchEmbed,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        model_shard_infer_config: ModelShardInferenceConfig = None,
+    ):
+        super().__init__()
+        self.module = module
+        self.rank = dist.get_rank(group=process_group)
+        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+
+    @staticmethod
+    def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
+        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+        distrifusion_embed = DistrifusionPatchEmbed(
+            module, process_group, model_shard_infer_config=model_shard_infer_config
+        )
+        return distrifusion_embed
+
+    def forward(self, latent):
+        module = self.module
+        if module.pos_embed_max_size is not None:
+            height, width = latent.shape[-2:]
+        else:
+            height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size
+
+        latent = module.proj(latent)
+        if module.flatten:
+            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        if module.layer_norm:
+            latent = module.norm(latent)
+        if module.pos_embed is None:
+            return latent.to(latent.dtype)
+        # Interpolate or crop positional embeddings as needed
+        if module.pos_embed_max_size:
+            pos_embed = module.cropped_pos_embed(height, width)
+        else:
+            if module.height != height or module.width != width:
+                pos_embed = get_2d_sincos_pos_embed(
+                    embed_dim=module.pos_embed.shape[-1],
+                    grid_size=(height, width),
+                    base_size=module.base_size,
+                    interpolation_scale=module.interpolation_scale,
+                )
+                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+            else:
+                pos_embed = module.pos_embed
+
+        b, c, h = pos_embed.shape
+        pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]
+
+        return (latent + pos_embed).to(latent.dtype)
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
+class DistrifusionConv2D(ParallelModule):
+
+    def __init__(
+        self,
+        module: nn.Conv2d,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        model_shard_infer_config: ModelShardInferenceConfig = None,
+    ):
+        super().__init__()
+        self.module = module
+        self.rank = dist.get_rank(group=process_group)
+        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+
+    @staticmethod
+    def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
+        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+        distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
+        return distrifusion_conv
+
+    def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        b, c, h, w = x.shape
+
+        stride = self.module.stride[0]
+        padding = self.module.padding[0]
+
+        output_h = x.shape[2] // stride // self.patched_parallelism_size
+        idx = dist.get_rank()
+        h_begin = output_h * idx * stride - padding
+        h_end = output_h * (idx + 1) * stride + padding
+        final_padding = [padding, padding, 0, 0]
+        if h_begin < 0:
+            h_begin = 0
+            final_padding[2] = padding
+        if h_end > h:
+            h_end = h
+            final_padding[3] = padding
+        sliced_input = x[:, :, h_begin:h_end, :]
+        padded_input = F.pad(sliced_input, final_padding, mode="constant")
+        return F.conv2d(
+            padded_input,
+            self.module.weight,
+            self.module.bias,
+            stride=stride,
+            padding="valid",
+        )
+
+    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        output = self.sliced_forward(input)
+        return output
+
+
+# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
+class DistrifusionFusedAttention(ParallelModule):
+
+    def __init__(
+        self,
+        module: attention_processor.Attention,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        model_shard_infer_config: ModelShardInferenceConfig = None,
+    ):
+        super().__init__()
+        self.counter = 0
+        self.module = module
+        self.buffer_list = None
+        self.kv_buffer_idx = dist.get_rank(group=process_group)
+        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+        self.handle = None
+        self.process_group = process_group
+        self.warm_step = 5  # for warmup
+
+    @staticmethod
+    def from_native_module(
+        module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+        return DistrifusionFusedAttention(
+            module=module,
+            process_group=process_group,
+            model_shard_infer_config=model_shard_infer_config,
+        )
+
+    def _forward(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        kv = torch.cat([key, value], dim=-1)  # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
+
+        if self.patched_parallelism_size == 1:
+            full_kv = kv
+        else:
+            if self.buffer_list is None:  # buffer not created
+                full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
+            elif self.counter <= self.warm_step:
+                # logger.info(f"warmup: {self.counter}")
+                dist.all_gather(
+                    self.buffer_list,
+                    kv,
+                    group=self.process_group,
+                    async_op=False,
+                )
+                full_kv = torch.cat(self.buffer_list, dim=1)
+            else:
+                # logger.info(f"use old kv to infer: {self.counter}")
+                self.buffer_list[self.kv_buffer_idx].copy_(kv)
+                full_kv = torch.cat(self.buffer_list, dim=1)
+                assert self.handle is None, "we should maintain the kv of last step"
+                self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
+
+        key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+
+        # `context` projections.
+        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # attention
+        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states = hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )  # NOTE(@lry89757) for torch >= 2.2, flash attn has been already integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, : residual.shape[1]],
+            hidden_states[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if not attn.context_pre_only:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+
+        if self.handle is not None:
+            self.handle.wait()
+            self.handle = None
+
+        b, l, c = hidden_states.shape
+        kv_shape = (b, l, self.module.to_k.out_features * 2)
+        if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
+
+            self.buffer_list = [
+                torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
+                for _ in range(self.patched_parallelism_size)
+            ]
+
+            self.counter = 0
+
+        attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks"}
+        unused_kwargs = [
+            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+        ]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
+            )
+        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+        output = self._forward(
+            self.module,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+        self.counter += 1
+
+        return output
+
+
+# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
+class DistriSelfAttention(ParallelModule):
+    def __init__(
+        self,
+        module: Attention,
+        process_group: Union[ProcessGroup, List[ProcessGroup]],
+        model_shard_infer_config: ModelShardInferenceConfig = None,
+    ):
+        super().__init__()
+        self.counter = 0
+        self.module = module
+        self.buffer_list = None
+        self.kv_buffer_idx = dist.get_rank(group=process_group)
+        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
+        self.handle = None
+        self.process_group = process_group
+        self.warm_step = 3  # for warmup
+
+    @staticmethod
+    def from_native_module(
+        module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
+        return DistriSelfAttention(
+            module=module,
+            process_group=process_group,
+            model_shard_infer_config=model_shard_infer_config,
+        )
+
+    def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
+        attn = self.module
+        assert isinstance(attn, Attention)
+
+        residual = hidden_states
+
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        query = attn.to_q(hidden_states)
+
+        encoder_hidden_states = hidden_states
+        k = self.module.to_k(encoder_hidden_states)
+        v = self.module.to_v(encoder_hidden_states)
+        kv = torch.cat([k, v], dim=-1)  # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
+
+        if self.patched_parallelism_size == 1:
+            full_kv = kv
+        else:
+            if self.buffer_list is None:  # buffer not created
+                full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
+            elif self.counter <= self.warm_step:
+                # logger.info(f"warmup: {self.counter}")
+                dist.all_gather(
+                    self.buffer_list,
+                    kv,
+                    group=self.process_group,
+                    async_op=False,
+                )
+                full_kv = torch.cat(self.buffer_list, dim=1)
+            else:
+                # logger.info(f"use old kv to infer: {self.counter}")
+                self.buffer_list[self.kv_buffer_idx].copy_(kv)
+                full_kv = torch.cat(self.buffer_list, dim=1)
+                assert self.handle is None, "we should maintain the kv of last step"
+                self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
+
+        if HAS_FLASH_ATTN:
+            # flash attn
+            key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+
+            query = query.view(batch_size, -1, attn.heads, head_dim)
+            key = key.view(batch_size, -1, attn.heads, head_dim)
+            value = value.view(batch_size, -1, attn.heads, head_dim)
+
+            hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
+        else:
+            # naive attn
+            key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
+
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+
+        # async preallocates memo buffer
+        if self.handle is not None:
+            self.handle.wait()
+            self.handle = None
+
+        b, l, c = hidden_states.shape
+        kv_shape = (b, l, self.module.to_k.out_features * 2)
+        if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
+
+            self.buffer_list = [
+                torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
+                for _ in range(self.patched_parallelism_size)
+            ]
+
+            self.counter = 0
+
+        output = self._forward(hidden_states, scale=scale)
+
+        self.counter += 1
+        return output
diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py
index d5774946e..cc2bee5ef 100644
--- a/colossalai/inference/modeling/models/pixart_alpha.py
+++ b/colossalai/inference/modeling/models/pixart_alpha.py
@@ -14,7 +14,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retri
 
 from colossalai.logging import get_dist_logger
 
-from .diffusion import DiffusionPipe
+from ..layers.diffusion import DiffusionPipe
 
 logger = get_dist_logger(__name__)
 
diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py
index d1c63a6dc..b12316403 100644
--- a/colossalai/inference/modeling/models/stablediffusion3.py
+++ b/colossalai/inference/modeling/models/stablediffusion3.py
@@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
 
-from .diffusion import DiffusionPipe
+from ..layers.diffusion import DiffusionPipe
 
 
 # TODO(@lry89757) temporarily image, please support more return output
diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py
index 356056ba7..1150b2432 100644
--- a/colossalai/inference/modeling/policy/pixart_alpha.py
+++ b/colossalai/inference/modeling/policy/pixart_alpha.py
@@ -1,9 +1,17 @@
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
 from torch import nn
 
 from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.distrifusion import (
+    DistrifusionConv2D,
+    DistrifusionPatchEmbed,
+    DistriSelfAttention,
+    PixArtAlphaTransformer2DModel_forward,
+)
 from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
-from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
 class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
@@ -12,9 +20,46 @@ class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
 
     def module_policy(self):
         policy = {}
+
+        if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
+
+            policy[PixArtTransformer2DModel] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="pos_embed.proj",
+                        target_module=DistrifusionConv2D,
+                        kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="pos_embed",
+                        target_module=DistrifusionPatchEmbed,
+                        kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+                    ),
+                ],
+                attribute_replacement={
+                    "patched_parallel_size": self.shard_config.extra_kwargs[
+                        "model_shard_infer_config"
+                    ].patched_parallelism_size
+                },
+                method_replacement={"forward": PixArtAlphaTransformer2DModel_forward},
+            )
+
+            policy[BasicTransformerBlock] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="attn1",
+                        target_module=DistriSelfAttention,
+                        kwargs={
+                            "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+                        },
+                    )
+                ]
+            )
+
         self.append_or_create_method_replacement(
             description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
         )
+
         return policy
 
     def preprocess(self) -> nn.Module:
diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py
index c9877f7dc..39b764b92 100644
--- a/colossalai/inference/modeling/policy/stablediffusion3.py
+++ b/colossalai/inference/modeling/policy/stablediffusion3.py
@@ -1,9 +1,17 @@
+from diffusers.models.attention import JointTransformerBlock
+from diffusers.models.transformers import SD3Transformer2DModel
 from torch import nn
 
 from colossalai.inference.config import RPC_PARAM
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.distrifusion import (
+    DistrifusionConv2D,
+    DistrifusionFusedAttention,
+    DistrifusionPatchEmbed,
+    SD3Transformer2DModel_forward,
+)
 from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
-from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
 class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
@@ -12,6 +20,42 @@ class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
 
     def module_policy(self):
         policy = {}
+
+        if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
+
+            policy[SD3Transformer2DModel] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="pos_embed.proj",
+                        target_module=DistrifusionConv2D,
+                        kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="pos_embed",
+                        target_module=DistrifusionPatchEmbed,
+                        kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
+                    ),
+                ],
+                attribute_replacement={
+                    "patched_parallel_size": self.shard_config.extra_kwargs[
+                        "model_shard_infer_config"
+                    ].patched_parallelism_size
+                },
+                method_replacement={"forward": SD3Transformer2DModel_forward},
+            )
+
+        policy[JointTransformerBlock] = ModulePolicyDescription(
+            sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="attn",
+                    target_module=DistrifusionFusedAttention,
+                    kwargs={
+                        "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
+                    },
+                )
+            ]
+        )
+
         self.append_or_create_method_replacement(
             description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
         )
diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md
new file mode 100644
index 000000000..c11b98043
--- /dev/null
+++ b/examples/inference/stable_diffusion/README.md
@@ -0,0 +1,22 @@
+## File Structure
+```
+|- sd3_generation.py: an example of how to use Colossalai Inference Engine to generate result by loading Diffusion Model.
+|- compute_metric.py: compare the quality of images w/o some acceleration method like Distrifusion
+|- benchmark_sd3.py: benchmark the performance of our InferenceEngine
+|- run_benchmark.sh: run benchmark command
+```
+Note: compute_metric.py need some dependencies which need `pip install -r requirements.txt`, `requirements.txt` is in `examples/inference/stable_diffusion/`
+
+## Run Inference
+
+The provided example `sd3_generation.py` is an example to configure, initialize the engine, and run inference on provided model. We've added `DiffusionPipeline` as model class, and the script is good to run inference with StableDiffusion 3.
+
+For a basic setting, you could run the example by:
+```bash
+colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
+```
+
+Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs:
+```bash
+colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
+```
diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py
new file mode 100644
index 000000000..19db57c33
--- /dev/null
+++ b/examples/inference/stable_diffusion/benchmark_sd3.py
@@ -0,0 +1,179 @@
+import argparse
+import json
+import time
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from diffusers import DiffusionPipeline
+
+import colossalai
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+GIGABYTE = 1024**3
+MEGABYTE = 1024 * 1024
+
+_DTYPE_MAPPING = {
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+    "fp32": torch.float32,
+}
+
+
+def log_generation_time(log_data, log_file):
+    with open(log_file, "a") as f:
+        json.dump(log_data, f, indent=2)
+        f.write("\n")
+
+
+def warmup(engine, args):
+    for _ in range(args.n_warm_up_steps):
+        engine.generate(
+            prompts=["hello world"],
+            generation_config=DiffusionGenerationConfig(
+                num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
+            ),
+        )
+
+
+def profile_context(args):
+    return (
+        torch.profiler.profile(
+            record_shapes=True,
+            with_stack=True,
+            with_modules=True,
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+        )
+        if args.profile
+        else nullcontext()
+    )
+
+
+def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
+    log_data = {
+        "mode": mode,
+        "model": model_name,
+        "batch_size": args.batch_size,
+        "patched_parallel_size": args.patched_parallel_size,
+        "num_inference_steps": args.num_inference_steps,
+        "height": h,
+        "width": w,
+        "dtype": args.dtype,
+        "profile": args.profile,
+        "n_warm_up_steps": args.n_warm_up_steps,
+        "n_repeat_times": args.n_repeat_times,
+        "avg_generation_time": avg_time,
+        "log_message": log_msg,
+    }
+
+    if args.log:
+        log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
+        log_generation_time(log_data=log_data, log_file=log_file)
+
+    if args.profile:
+        file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
+        prof.export_chrome_trace(file)
+
+
+def benchmark_colossalai(rank, world_size, port, args):
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    from colossalai.cluster.dist_coordinator import DistCoordinator
+
+    coordinator = DistCoordinator()
+
+    inference_config = InferenceConfig(
+        dtype=args.dtype,
+        patched_parallelism_size=args.patched_parallel_size,
+    )
+    engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)
+
+    warmup(engine, args)
+
+    for h, w in zip(args.height, args.width):
+        with profile_context(args) as prof:
+            start = time.perf_counter()
+            for _ in range(args.n_repeat_times):
+                engine.generate(
+                    prompts=["hello world"],
+                    generation_config=DiffusionGenerationConfig(
+                        num_inference_steps=args.num_inference_steps, height=h, width=w
+                    ),
+                )
+            end = time.perf_counter()
+
+        avg_time = (end - start) / args.n_repeat_times
+        log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
+        coordinator.print_on_master(log_msg)
+
+        if dist.get_rank() == 0:
+            log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)
+
+
+def benchmark_diffusers(args):
+    model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")
+
+    for _ in range(args.n_warm_up_steps):
+        model(
+            prompt="hello world",
+            num_inference_steps=args.num_inference_steps,
+            height=args.height[0],
+            width=args.width[0],
+        )
+
+    for h, w in zip(args.height, args.width):
+        with profile_context(args) as prof:
+            start = time.perf_counter()
+            for _ in range(args.n_repeat_times):
+                model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
+            end = time.perf_counter()
+
+        avg_time = (end - start) / args.n_repeat_times
+        log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
+        print(log_msg)
+
+        log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def benchmark(args):
+    if args.mode == "colossalai":
+        spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
+    elif args.mode == "diffusers":
+        benchmark_diffusers(args)
+
+
+"""
+# enable log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log
+
+# enable profiler
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
+"""
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
+    parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
+    parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
+    parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
+    parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
+    parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
+    parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps")
+    parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
+    parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
+    parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
+    parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
+    parser.add_argument(
+        "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
+    )
+    args = parser.parse_args()
+    benchmark(args)
diff --git a/examples/inference/stable_diffusion/compute_metric.py b/examples/inference/stable_diffusion/compute_metric.py
new file mode 100644
index 000000000..14c92501b
--- /dev/null
+++ b/examples/inference/stable_diffusion/compute_metric.py
@@ -0,0 +1,80 @@
+# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
+import argparse
+import os
+
+import numpy as np
+import torch
+from cleanfid import fid
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
+from torchvision.transforms import Resize
+from tqdm import tqdm
+
+
+def read_image(path: str):
+    """
+    input: path
+    output: tensor (C, H, W)
+    """
+    img = np.asarray(Image.open(path))
+    if len(img.shape) == 2:
+        img = np.repeat(img[:, :, None], 3, axis=2)
+    img = torch.from_numpy(img).permute(2, 0, 1)
+    return img
+
+
+class MultiImageDataset(Dataset):
+    def __init__(self, root0, root1, is_gt=False):
+        super().__init__()
+        self.root0 = root0
+        self.root1 = root1
+        file_names0 = os.listdir(root0)
+        file_names1 = os.listdir(root1)
+
+        self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
+        self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
+        self.is_gt = is_gt
+        assert len(self.image_names0) == len(self.image_names1)
+
+    def __len__(self):
+        return len(self.image_names0)
+
+    def __getitem__(self, idx):
+        img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
+        if self.is_gt:
+            # resize to 1024 x 1024
+            img0 = Resize((1024, 1024))(img0)
+        img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
+
+        batch_list = [img0, img1]
+        return batch_list
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch_size", type=int, default=64)
+    parser.add_argument("--num_workers", type=int, default=8)
+    parser.add_argument("--is_gt", action="store_true")
+    parser.add_argument("--input_root0", type=str, required=True)
+    parser.add_argument("--input_root1", type=str, required=True)
+    args = parser.parse_args()
+
+    psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
+    lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
+
+    dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+
+    progress_bar = tqdm(dataloader)
+    with torch.inference_mode():
+        for i, batch in enumerate(progress_bar):
+            batch = [img.to("cuda") / 255 for img in batch]
+            batch_size = batch[0].shape[0]
+            psnr.update(batch[0], batch[1])
+            lpips.update(batch[0], batch[1])
+    fid_score = fid.compute_fid(args.input_root0, args.input_root1)
+
+    print("PSNR:", psnr.compute().item())
+    print("LPIPS:", lpips.compute().item())
+    print("FID:", fid_score)
diff --git a/examples/inference/stable_diffusion/requirements.txt b/examples/inference/stable_diffusion/requirements.txt
new file mode 100644
index 000000000..c4e74162d
--- /dev/null
+++ b/examples/inference/stable_diffusion/requirements.txt
@@ -0,0 +1,3 @@
+torchvision
+torchmetrics
+cleanfid
diff --git a/examples/inference/stable_diffusion/run_benchmark.sh b/examples/inference/stable_diffusion/run_benchmark.sh
new file mode 100644
index 000000000..f3e45a335
--- /dev/null
+++ b/examples/inference/stable_diffusion/run_benchmark.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
+parallelism=(1 2 4 8)
+resolutions=(1024 2048 3840)
+modes=("colossalai" "diffusers")
+
+CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
+    local n=${1:-"9999"}
+    echo "GPU Memory Usage:"
+    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+        | tail -n +2 \
+        | nl -v 0 \
+        | tee /dev/tty \
+        | sort -g -k 2 \
+        | awk '{print $1}' \
+        | head -n $n)
+    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+    echo "Now CUDA_VISIBLE_DEVICES is set to:"
+    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+for model in "${models[@]}"; do
+    for p in "${parallelism[@]}"; do
+        for resolution in "${resolutions[@]}"; do
+            for mode in "${modes[@]}"; do
+                if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
+                    continue
+                fi
+                if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
+                    continue
+                fi
+                CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p
+
+                cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"
+
+                echo "Executing: $cmd"
+                eval $cmd
+            done
+        done
+    done
+done
diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py
index fe989eed7..9e146c34b 100644
--- a/examples/inference/stable_diffusion/sd3_generation.py
+++ b/examples/inference/stable_diffusion/sd3_generation.py
@@ -1,18 +1,17 @@
 import argparse
 
-from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
-from torch import bfloat16, float16, float32
+from diffusers import DiffusionPipeline
+from torch import bfloat16
+from torch import distributed as dist
+from torch import float16, float32
 
 import colossalai
 from colossalai.cluster import DistCoordinator
 from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
-from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
-from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
 
 # For Stable Diffusion 3, we'll use the following configuration
-MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
-POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
+MODEL_CLS = DiffusionPipeline
 
 TORCH_DTYPE_MAP = {
     "fp16": float16,
@@ -43,20 +42,27 @@ def infer(args):
         max_batch_size=args.max_batch_size,
         tp_size=args.tp_size,
         use_cuda_kernel=args.use_cuda_kernel,
+        patched_parallelism_size=dist.get_world_size(),
     )
-    engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
+    engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
 
     # ==============================
     # Generation
     # ==============================
     coordinator.print_on_master(f"Generating...")
     out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
-    out.save("cat.jpg")
+    if dist.get_rank() == 0:
+        out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
     coordinator.print_on_master(out)
 
 
 # colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
+
 # colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
+# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
 
 
 if __name__ == "__main__":