diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index e114e8a61..1beb86874 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from transformers.generation import GenerationConfig
@@ -396,3 +396,49 @@ class ModelShardInferenceConfig:
     use_cuda_kernel: bool = False
     use_spec_dec: bool = False
     use_flash_attn: bool = False
+
+
+@dataclass
+class DiffusionGenerationConfig:
+    """
+    Parameters for the diffusion model forward pass.
+    """
+
+    prompt_2: Optional[Union[str, List[str]]] = None
+    prompt_3: Optional[Union[str, List[str]]] = None
+    height: Optional[int] = None
+    width: Optional[int] = None
+    num_inference_steps: Optional[int] = None
+    timesteps: Optional[List[int]] = None
+    guidance_scale: Optional[float] = None
+    negative_prompt: Optional[Union[str, List[str]]] = (
+        None  # NOTE(@lry89757) in pixart default to "", in sd3 default to None
+    )
+    negative_prompt_2: Optional[Union[str, List[str]]] = None
+    negative_prompt_3: Optional[Union[str, List[str]]] = None
+    num_images_per_prompt: Optional[int] = None
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
+    latents: Optional[torch.FloatTensor] = None
+    prompt_embeds: Optional[torch.FloatTensor] = None
+    negative_prompt_embeds: Optional[torch.FloatTensor] = None
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None
+    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
+    output_type: Optional[str] = None  # "pil"
+    return_dict: Optional[bool] = None
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None
+    clip_skip: Optional[int] = None
+    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
+    callback_on_step_end_tensor_inputs: Optional[List[str]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        # NOTE(@lry89757) Only return fields that are not the default value None
+        result = {}
+        for field in fields(self):
+            value = getattr(self, field.name)
+            if value is not None:
+                result[field.name] = value
+        return result
+
+    @classmethod
+    def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
+        return cls(**kwargs)
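For context, a minimal sketch (not part of the patch) of how `DiffusionGenerationConfig` is meant to be consumed: only fields the caller explicitly sets survive `to_dict()`, so everything left at `None` falls back to the underlying diffusers pipeline defaults.

```python
# Sketch only; not part of the patch.
from colossalai.inference.config import DiffusionGenerationConfig

gen_config = DiffusionGenerationConfig.from_kwargs(
    height=1024, width=1024, num_inference_steps=20
)
# Unset fields (guidance_scale, negative_prompt, ...) are dropped by to_dict(),
# so the diffusers pipeline keeps its own defaults for them.
assert gen_config.to_dict() == {"height": 1024, "width": 1024, "num_inference_steps": 20}
```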
diff --git a/colossalai/inference/core/base_engine.py b/colossalai/inference/core/base_engine.py
new file mode 100644
index 000000000..392dd2990
--- /dev/null
+++ b/colossalai/inference/core/base_engine.py
@@ -0,0 +1,90 @@
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.config import ModelShardInferenceConfig
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+
+
+class BaseEngine(ABC):
+    @abstractmethod
+    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
+        pass
+
+    @abstractmethod
+    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
+        """
+        Initialize the model for the engine.
+        """
+
+    @abstractmethod
+    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
+        """
+        Generate output for incoming requests
+        """
+
+    @abstractmethod
+    def add_request(self, prompts, request_ids=None, **kwargs):
+        """
+        Add a new request to the engine
+        """
+
+    @abstractmethod
+    def step(self):
+        """
+        Perform one forward step
+        """
+
+    @abstractmethod
+    def _verify_args(self):
+        """
+        Verify the parameters and members of the class
+        """
+
+    @torch.inference_mode()
+    def capture_model(self):
+        """
+        Use CUDA graphs to capture the model
+        """
+        raise NotImplementedError("This method should be implemented by subclasses")
+
+    def _shardformer(
+        self,
+        model: nn.Module,
+        model_policy: Policy,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
+        stage_manager: PipelineStageManager = None,
+        tp_group: ProcessGroupMesh = None,
+        **kwargs,
+    ) -> nn.Module:
+        """
+        Initialize ShardConfig and replace the model with shardformer.
+
+        Args:
+            model (nn.Module): The model to be sharded.
+            model_policy (Policy): The policy to shardformer model which is determined by the model type.
+            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
+            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
+
+        Returns:
+            nn.Module: The model optimized by Shardformer.
+        """
+
+        shardconfig = ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            pipeline_stage_manager=stage_manager,
+            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
+            enable_fused_normalization=False,
+            enable_all_optimization=False,
+            enable_flash_attention=False,
+            enable_jit_fused=False,
+            enable_sequence_parallelism=False,
+            extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
+        )
+        shardformer = ShardFormer(shard_config=shardconfig)
+        shard_model, _ = shardformer.optimize(model, model_policy)
+        return shard_model
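To make the contract explicit, here is a hypothetical toy subclass (not part of the patch) that satisfies the abstract interface; any concrete engine must implement `init_model`, `generate`, `add_request`, `step`, and `_verify_args`, while `capture_model` and `_shardformer` are inherited helpers.

```python
# Hypothetical toy engine illustrating the BaseEngine contract; not part of the patch.
from colossalai.inference.core.base_engine import BaseEngine


class EchoEngine(BaseEngine):
    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        self.requests = []
        self._verify_args()

    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        # A real engine would build the model here and shard it via self._shardformer(...).
        pass

    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        if prompts is not None:
            self.add_request(prompts)
        return self.step()

    def add_request(self, prompts, request_ids=None, **kwargs):
        self.requests.extend(prompts if isinstance(prompts, list) else [prompts])

    def step(self):
        return [f"echo: {p}" for p in self.requests]

    def _verify_args(self):
        assert isinstance(self.requests, list)
```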
diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py
new file mode 100644
index 000000000..75b9889bf
--- /dev/null
+++ b/colossalai/inference/core/diffusion_engine.py
@@ -0,0 +1,200 @@
+from itertools import count
+from typing import List, Tuple, Type, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from torch import distributed as dist
+
+from colossalai.accelerator import get_accelerator
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
+from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.struct import DiffusionSequence
+from colossalai.inference.utils import get_model_size, get_model_type
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .base_engine import BaseEngine
+from .request_handler import NaiveRequestHandler
+
+PP_AXIS, TP_AXIS = 0, 1
+
+
+class DiffusionEngine(BaseEngine):
+    def __init__(
+        self,
+        model_or_path: DiffusionPipeline | str,
+        inference_config: InferenceConfig = None,
+        verbose: bool = False,
+        model_policy: Policy | type[Policy] = None,
+    ) -> None:
+        self.inference_config = inference_config
+        self.dtype = inference_config.dtype
+        self.high_precision = inference_config.high_precision
+
+        self.verbose = verbose
+        self.logger = get_dist_logger(__name__)
+        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
+
+        self.model_type = get_model_type(model_or_path=model_or_path)
+
+        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
+
+        self.request_handler = NaiveRequestHandler()
+
+        self.counter = count()
+
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"
+
+    def init_model(
+        self,
+        model_or_path: Union[str, nn.Module, DiffusionPipeline],
+        model_policy: Union[Policy, Type[Policy]] = None,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
+    ):
+        """
+        Shard the model and/or load its weights.
+
+        Args:
+            model_or_path (Union[str, nn.Module, DiffusionPipeline]): path to the checkpoint or the diffusers pipeline itself.
+            model_policy (Policy): the policy used to shard/replace the model.
+            model_shard_infer_config (ModelShardInferenceConfig): the configuration for module initialization during inference.
+        """
+        if isinstance(model_or_path, str):
+            model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
+            policy_map_key = model.__class__.__name__
+            model = DiffusionPipe(model)
+        elif isinstance(model_or_path, DiffusionPipeline):
+            policy_map_key = model_or_path.__class__.__name__
+            model = DiffusionPipe(model_or_path)
+        else:
+            self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")
+
+        torch.cuda.empty_cache()
+        init_gpu_memory = torch.cuda.mem_get_info()[0]
+
+        self.device = get_accelerator().get_current_device()
+        if self.verbose:
+            self.logger.info(f"the device is {self.device}")
+
+        if self.verbose:
+            self.logger.info(
+                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
+            )
+
+        if model_policy is None:
+            model_policy = model_policy_map.get(policy_map_key)
+
+        if not isinstance(model_policy, Policy):
+            try:
+                model_policy = model_policy()
+            except Exception as e:
+                raise ValueError(f"Unable to instantiate model policy: {e}")
+
+        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
+        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
+        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
+
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            model_shard_infer_config,
+            None,
+            tp_group=tp_group,
+        )
+
+        self.model = model.to(self.device)
+
+        if self.verbose:
+            self.logger.info(
+                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
+            )
+
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        peak_memory = init_gpu_memory - free_gpu_memory
+        if self.verbose:
+            self.logger.info(
+                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
+            )
+
+    def generate(
+        self,
+        request_ids: Union[List[int], int] = None,
+        prompts: Union[List[str], str] = None,
+        generation_config: DiffusionGenerationConfig = None,
+        **kwargs,
+    ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
+        """ """
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        prompts = [prompts] if isinstance(prompts, str) else prompts
+        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
+        with torch.inference_mode():
+            if prompts is not None:
+                self.add_request(
+                    request_ids=request_ids,
+                    prompts=prompts,
+                    **gen_config_dict,
+                    **kwargs,
+                )
+
+            output_reqs_list = []
+
+            # intuition: If user provide a generation config, we should replace the existing one.
+            if generation_config is not None:
+                self.generation_config = generation_config
+                self.generation_config_dict = gen_config_dict
+
+            while self.request_handler.check_unfinished_reqs():
+                output_reqs_list += self.step()
+
+            return output_reqs_list
+
+    def add_request(
+        self,
+        prompts: Union[List[str], str],
+        request_ids: Union[List[int], int] = None,
+        **kwargs,
+    ):
+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
+        if not isinstance(prompts, list):
+            prompts = [prompts]
+
+        generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
+        prompts_num = len(prompts)
+        for i in range(prompts_num):
+            if request_ids:
+                assert isinstance(
+                    request_ids[0], int
+                ), f"The request_id type must be int, but got {type(request_ids[0])}"
+                assert len(request_ids) == prompts_num
+                request_id = request_ids[i]
+            else:
+                request_id = next(self.counter)
+
+            seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)
+
+            self.request_handler.add_sequence(seq)
+
+    def step(self) -> List[PIL.Image.Image]:
+        """
+        In each step, do the following:
+            1. Run RequestHandler.schedule() to get the batch used for inference.
+            2. Run the forward pass to get a List[PIL.Image.Image].
+        Returns:
+            List[PIL.Image.Image]: Images generated in one step.
+        """
+
+        input_seq = self.request_handler.schedule()
+        ret = self.model(prompt=input_seq.prompt, **input_seq.generation_config.to_dict())
+        return ret
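A usage sketch of the new `DiffusionEngine` (not part of the patch); the checkpoint path is a placeholder, the default `InferenceConfig()` values are assumed to be acceptable, and the distributed environment is assumed to be initialized already.

```python
# Usage sketch only; path and config values are assumptions.
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.diffusion_engine import DiffusionEngine

engine = DiffusionEngine("path/to/diffusion-pipeline", inference_config=InferenceConfig())
images = engine.generate(
    prompts="an astronaut riding a horse",
    generation_config=DiffusionGenerationConfig(num_inference_steps=20),
)
```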
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 8f8aef65e..5c9bdc321 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -1,57 +1,24 @@
-import time
-from itertools import count
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import List, Tuple, Type, Union
 
 import numpy as np
-import torch
+import PIL.Image
 import torch.nn as nn
-from torch import distributed as dist
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    GenerationConfig,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerFast,
-)
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from diffusers import DiffusionPipeline
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
-from colossalai.accelerator import get_accelerator
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
-from colossalai.inference.graph_runner import CUDAGraphRunner
-from colossalai.inference.modeling.policy import model_policy_map
-from colossalai.inference.sampler import search_tokens
-from colossalai.inference.spec import Drafter, GlideInput
-from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size, has_index_file
-from colossalai.interface import ModelWrapper
-from colossalai.lazy import LazyInitContext
-from colossalai.logging import get_dist_logger
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.inference.config import InferenceConfig
+from colossalai.inference.utils import ModelType, get_model_type
 from colossalai.shardformer.policies.base_policy import Policy
 
-from .request_handler import RequestHandler
-
 __all__ = ["InferenceEngine"]
 
-PP_AXIS, TP_AXIS = 0, 1
-
-_supported_models = {
-    "LlamaForCausalLM": LlamaForCausalLM,
-    "BaichuanForCausalLM": AutoModelForCausalLM,
-}
-
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
-
 
 class InferenceEngine:
     """
     InferenceEngine which manages the inference process..
 
     Args:
-        model_or_path (nn.Module or str): Path or nn.Module of this model.
+        model_or_path (nn.Module, DiffusionPipeline, or str): The model, a DiffusionPipeline, or a path to the model checkpoint.
         tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
         inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
@@ -60,567 +27,68 @@ class InferenceEngine:
 
     def __init__(
         self,
-        model_or_path: Union[nn.Module, str],
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        inference_config: InferenceConfig,
+        model_or_path: Union[nn.Module, str, DiffusionPipeline],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+        inference_config: InferenceConfig = None,
         verbose: bool = False,
         model_policy: Union[Policy, Type[Policy]] = None,
     ) -> None:
-        self.inference_config = inference_config
-        self.dtype = inference_config.dtype
-        self.high_precision = inference_config.high_precision
+        self.__dict__["_initialized"] = False  # use __dict__ directly to avoid calling __setattr__
+        self.model_type = get_model_type(model_or_path=model_or_path)
+        self.engine = None
+        if self.model_type == ModelType.LLM:
+            from .llm_engine import LLMEngine
 
-        self.verbose = verbose
-        self.logger = get_dist_logger(__name__)
-        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
+            self.engine = LLMEngine(
+                model_or_path=model_or_path,
+                tokenizer=tokenizer,
+                inference_config=inference_config,
+                verbose=verbose,
+                model_policy=model_policy,
+            )
+        elif self.model_type == ModelType.DIFFUSION_MODEL:
+            from .diffusion_engine import DiffusionEngine
 
-        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
-
-        self.generation_config = inference_config.to_generation_config(self.model_config)
-        self.generation_config_dict = self.generation_config.to_dict()
-
-        self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        self.request_handler = RequestHandler(self.inference_config, self.model_config)
-        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
-        # DISCUSS maybe move this into batch info?
-
-        self.counter = count()
-
-        self.use_cuda_graph = self.inference_config.use_cuda_graph
-        if self.use_cuda_graph:
-            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
-            self.graph_memory_pool = None  # Set during graph capture.
-            if verbose:
-                self.logger.info("Colossal AI CUDA Graph Capture on")
-
-            self.capture_model(self.k_cache, self.v_cache)
-
-        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
-        self.use_spec_dec = self.inference_config.use_spec_dec
-
-        self.drafter_model = None
-        self.drafter = None
-        self.use_glide = False
-        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+            self.engine = DiffusionEngine(
+                model_or_path=model_or_path,
+                inference_config=inference_config,
+                verbose=verbose,
+                model_policy=model_policy,
+            )
+        elif self.model_type == ModelType.UNKNOWN:
+            self.logger.error(f"Model Type either Difffusion or LLM!")
 
+        self._initialized = True
         self._verify_args()
 
-    def init_model(
-        self,
-        model_or_path: Union[nn.Module, str],
-        model_policy: Union[Policy, Type[Policy]] = None,
-        model_shard_infer_config: ModelShardInferenceConfig = None,
-    ):
-        """
-        Shard model or/and Load weight
-
-        Args:
-            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
-            model_policy (Policy): the policy to replace the model.
-            model_inference_config: the configuration for modeling initialization when inference.
-            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
-        """
-        pretrained_path = None
-        if isinstance(model_or_path, str):
-            import colossalai.interface.pretrained as pretrained_utils
-
-            try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
-                arch = getattr(hf_config, "architectures")[0]
-                if arch in _supported_models.keys():
-                    if arch is "BaichuanForCausalLM":
-                        self.logger.warning(
-                            "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
-                        )
-                    ctx = LazyInitContext(default_device="cuda")
-                    with ctx:
-                        model = _supported_models[arch].from_pretrained(
-                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
-                        )
-                    pretrained_path = pretrained_utils.get_pretrained_path(model)
-                else:
-                    # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
-                    raise ValueError(f"Model {arch} is not supported.")
-
-            except Exception as e:
-                self.logger.error(
-                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
-                )
-        else:
-            model = model_or_path
-
-        self.model_config = model.config
-
-        torch.cuda.empty_cache()
-        init_gpu_memory = torch.cuda.mem_get_info()[0]
-
-        self.device = get_accelerator().get_current_device()
-        if self.verbose:
-            self.logger.info(f"the device is {self.device}")
-
-        model = model.to(self.dtype).eval()
-
-        if self.verbose:
-            self.logger.info(
-                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
-            )
-
-        if model_policy is None:
-            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
-            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
-            model_policy = model_policy_map.get(model_policy_key)
-
-        if not isinstance(model_policy, Policy):
-            try:
-                model_policy = model_policy()
-            except Exception as e:
-                raise ValueError(f"Unable to instantiate model policy: {e}")
-
-        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
-        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
-        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
-
-        self.model = self._shardformer(
-            model,
-            model_policy,
-            model_shard_infer_config,
-            None,
-            tp_group=tp_group,
-        )
-
-        self.model = ModelWrapper(model).to(self.device)
-
-        if self.verbose:
-            self.logger.info(
-                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
-            )
-
-        if pretrained_path:
-            from colossalai.inference.core.plugin import InferCheckpoint_io
-
-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(pretrained_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
-
-        free_gpu_memory, _ = torch.cuda.mem_get_info()
-        peak_memory = init_gpu_memory - free_gpu_memory
-        if self.verbose:
-            self.logger.info(
-                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
-            )
-
-    @torch.inference_mode()
-    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
-        assert self.use_cuda_graph, "please turn on the cuda graph"
-
-        if self.verbose:
-            self.logger.info("Colossal AI CUDA Graph Capture begin")
-
-        t_capture_begin = time.perf_counter()
-
-        block_size = self.inference_config.block_size
-        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
-
-        # Prepare dummy inputs. These will be reused for all batch sizes.
-        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
-        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
-        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
-        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
-        self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
-        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
-        self.graph_block_tables[0, :] = np.arange(
-            0, max_num_blocks
-        )  # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
-        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
-        output_tensor = torch.zeros(
-            (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
-        )
-        fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
-
-        max_num_seqs = self.inference_config.max_batch_size
-        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
-        sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
-        # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
-        sequence_lengths[0] = torch.tensor(
-            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
-        ).cuda()
-
-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
-        for batch_size in reversed(batch_size_capture_list):
-            if self.verbose:
-                self.logger.info(f"batch size {batch_size} graph capturing")
-
-            input_meta_data = InputMetaData(
-                block_tables=block_tables[:batch_size],
-                sequence_lengths=sequence_lengths[:batch_size],
-                fd_inter_tensor=fd_inter_tensor,
-                batch_size=batch_size,
-                is_prompts=False,
-                use_cuda_graph=True,
-                high_precision=False,
-                kv_seq_len=sequence_lengths[:batch_size].max().item(),
-                head_dim=head_dim,
-                dtype=self.dtype,
-            )
-
-            graph_runner = CUDAGraphRunner(self.model)
-            graph_runner.capture(
-                input_tokens_ids[:batch_size],
-                output_tensor[:batch_size],
-                input_meta_data,
-                k_caches=k_cache,
-                v_caches=v_cache,
-                memory_pool=self.graph_memory_pool,
-            )
-            self.graph_memory_pool = graph_runner.graph.pool()
-            self.graph_runners[batch_size] = graph_runner
-
-        t_capture_end = time.perf_counter()
-
-        if self.verbose:
-            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
-
     def _verify_args(self) -> None:
         """Verify the input args"""
-        if not isinstance(self.inference_config, InferenceConfig):
-            raise TypeError("Invalid type of inference config provided.")
-        if not isinstance(self.model, nn.Module):
-            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
-        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
-            raise TypeError(
-                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
-            )
-        if isinstance(self.model, ModelWrapper):
-            model = self.model.module
-        assert (
-            model.__class__.__name__ in _supported_models.keys()
-        ), f"Model {self.model.__class__.__name__} is not supported."
-
-    def _shardformer(
-        self,
-        model: nn.Module,
-        model_policy: Policy,
-        model_shard_infer_config: ModelShardInferenceConfig = None,
-        stage_manager: PipelineStageManager = None,
-        tp_group: ProcessGroupMesh = None,
-    ) -> nn.Module:
-        """
-        Initialize ShardConfig and replace the model with shardformer.
-
-        Args:
-            model (nn.Module): Path or nn.Module of this model.
-            model_policy (Policy): The policy to shardformer model which is determined by the model type.
-            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
-            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
-
-        Returns:
-            nn.Module: The model optimized by Shardformer.
-        """
-
-        shardconfig = ShardConfig(
-            tensor_parallel_process_group=tp_group,
-            pipeline_stage_manager=stage_manager,
-            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
-            enable_fused_normalization=False,
-            enable_all_optimization=False,
-            enable_flash_attention=False,
-            enable_jit_fused=False,
-            enable_sequence_parallelism=False,
-            extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
-        )
-        shardformer = ShardFormer(shard_config=shardconfig)
-        shard_model, _ = shardformer.optimize(model, model_policy)
-        return shard_model
-
-    def enable_spec_dec(
-        self,
-        drafter_model: nn.Module = None,
-        n_spec_tokens: int = None,
-        use_glide_drafter: bool = False,
-    ) -> None:
-        """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
-
-        Args:
-            drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
-                If provided, the previous drafter and drafter model, if exist, will be overwritten.
-            n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
-                If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
-            use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
-                If True, the drafter model will be replaced by a glide model.
-
-        ```python
-        ...
-        engine = InferenceEngine(model, tokenizer, inference_config)
-
-        engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
-        engine.generate(...)  # Speculative Decoding
-
-        engine.disable_spec_dec()
-        engine.generate(...)  # Normal generation
-
-        engine.enable_spec_dec()
-        engine.generate(...)  # Speculative-Decoding using previously set drafter model and number of spec tokens
-        engine.clear_spec_dec()
-        ```
-        """
-
-        if drafter_model is None and self.drafter is None:
-            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
-        if n_spec_tokens is not None:
-            assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
-            self.n_spec_tokens = n_spec_tokens
-        if drafter_model is not None:
-            assert isinstance(drafter_model, nn.Module)
-            # overwrite the drafter, if exists
-            self.clear_spec_dec()
-            self.drafter_model = drafter_model
-            self.drafter = Drafter(
-                self.drafter_model,
-                self.tokenizer,
-                device=self.device,
-                dtype=self.dtype,
-            )
-
-            # check if the provided drafter model is compatible with GLIDE structure
-            # when `use_glide_drafter` is set to True
-            if (
-                use_glide_drafter
-                and hasattr(drafter_model, "model")
-                and hasattr(drafter_model.model, "layers")
-                and hasattr(drafter_model.model.layers[0], "cross_attn")
-            ):
-                self.use_glide = use_glide_drafter
-            elif use_glide_drafter:
-                self.logger.warning(
-                    f"`use_glide_drafter` is provided as {use_glide_drafter}, "
-                    f"but the provided drafter model is not compatible with GLIDE structure."
-                    f"Falling back to use the default drafter model (non-GLIDE)."
-                )
-        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
-        # using speculative decoding for subsequent generations
-        self.use_spec_dec = True
-
-    def disable_spec_dec(self) -> None:
-        """Disable using speculative decoding for subsequent generations."""
-        self.request_handler.unset_spec_dec_mode()
-        # set back to the maximum number of tokens to speculate
-        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
-        self.use_glide = False
-        self.use_spec_dec = False
-
-    def clear_spec_dec(self) -> None:
-        """Clear relatable structures of speculative decoding, if exist."""
-        if self.use_spec_dec:
-            self.disable_spec_dec()
-        if self.drafter_model or self.drafter:
-            self.drafter_model = None
-            self.drafter = None
-            torch.cuda.empty_cache()
-        self.use_glide = False
-        self.use_spec_dec = False
-
-    def steps_spec_dec(self) -> List[Sequence]:
-        """
-        Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
-        with many steps of speculating by a drafter model as well as verifying by a main model.
-
-        Returns:
-            List[Sequence]: finished sequences generated by one step.
-        """
-        batch = self.request_handler.schedule()  # prefill batch
-        assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
-
-        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
-
-        if input_meta_data.use_cuda_graph:
-            model_executable = self.graph_runners[input_meta_data.batch_size]
-        else:
-            model_executable = self.model
-
-        # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
-        # NOTE For glide drafter models, we won't actually apply glide during prefill stage
-        drafter_out = self.drafter.speculate(input_token_ids, 1, None)
-        next_token_ids_spec = drafter_out.next_tokens
-        drafter_past_key_values = drafter_out.past_key_values
-
-        # 2. Prefill main model (Verifier) - fill past kv cache for main model
-        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
-        # append new inputs to the batch, temporarily
-        batch.append_batch_tokens(next_tokens)
-        self.request_handler.allocate_batch_spec_dec(batch, 1)
-        already_allocated_kv_len = batch.seq_lengths[0].item()
-        input_token_ids = batch.get_1D_inputs_spec_dec(1)
-
-        finished_sequences = self.request_handler.update()
-
-        while True:
-            # HACK Retrieve the running batch
-            #      Using RequestHandler.schedule here will re-allocate same kv cache for the batch
-            batch = self.request_handler.running_bb  # running batch
-            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
-
-            # 3. Decoding - Drafter model speculates `n` tokens
-            glide_input = None
-            if self.use_glide:
-                glide_input = GlideInput(
-                    batch.get_block_table_tensor(),
-                    self.k_cache[-1],  # use kv cahces of the last layer
-                    self.v_cache[-1],
-                    batch.get_sequence_lengths(),
-                    n_spec_tokens=self.n_spec_tokens,
-                )
-
-            drafter_out = self.drafter.speculate(
-                input_token_ids,
-                self.n_spec_tokens,
-                drafter_past_key_values,
-                glide_input=glide_input,
-            )
-            next_token_ids_spec = drafter_out.next_tokens
-            drafter_past_key_values = drafter_out.past_key_values
-            drafter_spec_length = drafter_out.speculated_length
-
-            for next_token_id_spec in next_token_ids_spec:
-                self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
-            cur_length = batch.seq_lengths[0].item()
-            if already_allocated_kv_len < cur_length:
-                self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
-                already_allocated_kv_len = cur_length
-
-            # 4. Decoding - Main model verifies `n` tokens in parallel
-            if drafter_spec_length < batch.num_tokens_to_verify:
-                batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
-            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
-            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-
-            next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
-
-            # 5. Compare and process the results
-            diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
-            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
-
-            # revoke appended tokens for each Sequence in the current batch
-            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
-
-            # append the last correct token generated by the main model
-            self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
-
-            # trim past key values of the drafter model
-            drafter_past_key_values = Drafter.trim_kv_cache(
-                drafter_past_key_values, drafter_spec_length - n_matches - 1
-            )
-
-            # prepare inputs for the next round of speculation
-            n = 1 if n_matches < drafter_spec_length else 2
-            input_token_ids = batch.get_1D_inputs_spec_dec(n)
-
-            self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
-            finished_sequences = self.request_handler.update()
-            if len(finished_sequences) > 0:
-                break
-
-        # Reset back the number of speculated tokens of the batch,
-        # this is used to handle the last round of speculation, in which case the number of speculated tokens
-        # by the drafter is less than the number of speculated tokens set to the engine.
-        batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
-
-        return finished_sequences
+        assert self.engine is not None, "Please init Engine first"
+        assert self._initialized, "Engine must be initialized"
 
     def generate(
         self,
         request_ids: Union[List[int], int] = None,
         prompts: Union[List[str], str] = None,
-        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        return_token_ids: bool = False,
-        generation_config: Optional[GenerationConfig] = None,
-    ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
+        *args,
+        **kwargs,
+    ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
         """
         Executing the inference step.
 
         Args:
             request_ids (List[int], optional): The request ID. Defaults to None.
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
-            prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
-            return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
-            generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
-
-        Returns:
-            Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
         """
 
-        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
-        prompts = [prompts] if isinstance(prompts, str) else prompts
-        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
-
-        with torch.inference_mode():
-            if prompts is not None or prompts_token_ids is not None:
-                self.add_request(
-                    request_ids=request_ids,
-                    prompts=prompts,
-                    prompts_token_ids=prompts_token_ids,
-                    **gen_config_dict,
-                )
-
-            output_seqs_list = []
-            total_tokens_list = []
-
-            # intuition: If user provide a generation config, we should replace the existing one.
-            if generation_config is not None:
-                self.generation_config = generation_config
-                self.generation_config_dict = gen_config_dict
-
-            if self.use_spec_dec:
-                assert self.drafter is not None, "Drafter Model is not initialized."
-                while self.request_handler.check_unfinished_seqs():
-                    output_seqs_list += self.steps_spec_dec()
-            else:
-                while self.request_handler.check_unfinished_seqs():
-                    output_seqs_list += self.step()
-
-            output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
-
-            for seq in output_seqs_list:
-                total_tokens_list.append(seq.input_token_id + seq.output_token_id)
-
-            output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
-
-            if return_token_ids:
-                output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
-                return output_str, output_tokens_list
-            else:
-                return output_str
-
-    @property
-    def has_prompt_template(self) -> bool:
-        """ """
-        return self.inference_config.prompt_template is not None
-
-    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
-        """
-        This method will format the input prompt according to the prompt template given to the InferenceConfig.
-        """
-        assert (
-            self.has_prompt_template
-        ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
-
-        if isinstance(prompts, (list, tuple)):
-            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
-        elif isinstance(prompts, str):
-            return self.inference_config.prompt_template.format(input_text=prompts)
-        else:
-            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
+        assert self.engine is not None, "Please init Engine first"
+        return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
 
     def add_request(
         self,
         request_ids: Union[List[int], int] = None,
         prompts: Union[List[str], str] = None,
-        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        *args,
         **kwargs,
     ) -> None:
         """
@@ -630,168 +98,36 @@ class InferenceEngine:
             request_ids (List[int], optional): The request ID. Defaults to None.
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+            kwargs: For LLMs, these may include max_length, max_new_tokens, etc.;
+                    for diffusion models, these may include prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, and clip_skip, aligned with diffusers.
         """
+        assert self.engine is not None, "Please init Engine first"
+        self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
 
-        # apply the prompt template to the input prompts
+    def step(self):
+        assert self.engine is not None, "Please init Engine first"
+        return self.engine.step()
 
-        if self.has_prompt_template and prompts is not None:
-            prompts = self.format_prompt(prompts)
-
-        block_size = self.inference_config.block_size
-
-        if request_ids is not None and not isinstance(request_ids, list):
-            request_ids = [request_ids]
-
-        if prompts is not None and not isinstance(prompts, list):
-            prompts = [prompts]
-
-        if prompts_token_ids is None:
-            assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
-                "input_ids"
-            ]
-
-        # list of torch Tensor
-        if isinstance(prompts_token_ids, list):
-            if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
-        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
-            prompts_token_ids = prompts_token_ids.tolist()
-        else:
-            raise TypeError(
-                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
-            )
-
-        assert (
-            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
-        ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
-
-        prompts_num = len(prompts_token_ids)
-
-        for i in range(prompts_num):
-            if request_ids:
-                assert isinstance(
-                    request_ids[0], int
-                ), f"The request_id type must be int, but got {type(request_ids[0])}"
-                assert len(request_ids) == prompts_num
-                request_id = request_ids[i]
+    def __getattr__(self, name):
+        """
+        Design logic of __getattr__ and __setattr__:
+        1. Since InferenceEngine is a wrapper around DiffusionEngine/LLMEngine, we want to access members of DiffusionEngine/LLMEngine as if they were members of InferenceEngine itself.
+        2. Inside InferenceEngine.__init__ we still want to set attributes the usual way (self.xxx = xxx) rather than via self.__dict__["xxx"] = xxx.
+        So we keep an `_initialized` flag: once initialization is done, any member not found on InferenceEngine is looked up on self.engine (DiffusionEngine/LLMEngine).
+        """
+        if self.__dict__.get("_initialized", False):
+            if name in self.__dict__:
+                return self.__dict__[name]
             else:
-                request_id = next(self.counter)
-            if prompts == None:
-                prompt = None
+                return getattr(self.engine, name)
+        else:
+            return self.__dict__[name]
+
+    def __setattr__(self, name, value):
+        if self.__dict__.get("_initialized", False):
+            if name in self.__dict__:
+                self.__dict__[name] = value
             else:
-                prompt = prompts[i]
-
-            max_length = kwargs.get("max_length", None)
-            max_new_tokens = kwargs.get("max_new_tokens", None)
-            if max_length is None and max_new_tokens is None:
-                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
-            elif max_length is not None:
-                max_new_tokens = max_length - len(prompts_token_ids[i])
-
-            if not self.inference_config.enable_streamingllm:
-                assert (
-                    self.inference_config.max_output_len >= max_new_tokens
-                ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
-
-            sequence = Sequence(
-                request_id,
-                prompt,
-                prompts_token_ids[i],
-                block_size,
-                None,
-                self.tokenizer.eos_token_id,
-                self.tokenizer.pad_token_id,
-                max_output_len=max_new_tokens,
-                ignore_eos=self.inference_config.ignore_eos,
-            )
-            self.request_handler.add_sequence(sequence)
-
-    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
-        input_ids = batch.get_1D_inputs()
-        sequence_lengths = batch.get_sequence_lengths()
-
-        if batch.is_prompts:
-            n_tokens = sequence_lengths.sum().item()
+                setattr(self.engine, name, value)
         else:
-            n_tokens = batch.current_batch_size
-            if batch.use_spec_dec:
-                n_tokens = batch.num_tokens_to_verify + 1
-                assert n_tokens == input_ids.size(0)
-                n_tokens = n_tokens * batch.current_batch_size
-        output_tensor = torch.zeros(
-            (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-
-        batch_token_ids = None
-        if (
-            self.generation_config.repetition_penalty != 1.0
-            or self.generation_config.no_repeat_ngram_size > 0
-            or self.generation_config.forced_eos_token_id is not None
-        ):
-            batch_token_ids = batch.batch_token_ids
-
-        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
-        use_cuda_graph = False
-        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
-            use_cuda_graph = True
-
-        input_meta_data = InputMetaData(
-            block_tables=batch.get_block_table_tensor(),
-            sequence_lengths=sequence_lengths,
-            fd_inter_tensor=batch.fd_inter_tensor,
-            batch_size=batch.current_batch_size,
-            is_prompts=batch.is_prompts,
-            use_cuda_kernel=self.inference_config.use_cuda_kernel,
-            use_cuda_graph=use_cuda_graph,
-            high_precision=self.high_precision,
-            kv_seq_len=sequence_lengths.max().item(),
-            head_dim=batch.head_dim,
-            dtype=batch.dtype,
-            use_spec_dec=batch.use_spec_dec,
-            num_tokens_to_verify=batch.num_tokens_to_verify,
-            batch_token_ids=batch_token_ids,
-        )
-
-        return input_ids, output_tensor, input_meta_data
-
-    def step(self) -> List[str]:
-        """
-        In each step, do the follows:
-            1. Run RequestHandler.schedule() and get the batch used for inference.
-            2. Get the input, inputinfo and output placeholder from the batchbucket
-            3. Run model to generate the next token
-            4. Update waiting list and running list in RequestHandler and get finished sequences.
-            5. Decode and return finished sequences.
-
-        Returns:
-            List[str]: Decoded finished sequences generated by one step.
-        """
-
-        batch = self.request_handler.schedule()
-
-        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
-
-        if input_meta_data.use_cuda_graph:
-            model_executable = self.graph_runners[input_meta_data.batch_size]
-        else:
-            model_executable = self.model
-
-        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        if self.inference_config.pad_input:
-            logits = logits[:, -1, :]
-
-        if self.inference_config.enable_streamingllm:
-            updated_block_ids = batch.streamingllm_update_batch(
-                self.inference_config.start_token_size, self.inference_config.generated_token_size
-            )
-            self.request_handler.streamingllm_free_block_tables(updated_block_ids)
-
-        next_tokens = search_tokens(
-            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
-        )
-        self.request_handler.append_next_tokens(next_tokens)
-        finished_sequences = self.request_handler.update()
-
-        return finished_sequences
+            self.__dict__[name] = value
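The wrapper's dispatch and attribute delegation can be exercised as in the following sketch (not part of the patch); paths, tokenizer, and config values are placeholders, and the distributed environment is assumed to be initialized already.

```python
# Dispatch sketch only; checkpoint paths are placeholders.
from transformers import AutoTokenizer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# A causal-LM checkpoint routes to LLMEngine under the hood ...
tokenizer = AutoTokenizer.from_pretrained("path/to/llama-checkpoint")
llm = InferenceEngine("path/to/llama-checkpoint", tokenizer=tokenizer, inference_config=InferenceConfig())
texts = llm.generate(prompts=["Hello, my name is"])

# ... attributes missing on the wrapper are forwarded to the wrapped engine via __getattr__.
print(llm.use_spec_dec)  # actually read from llm.engine (an LLMEngine)

# ... while a diffusers pipeline (or its path) routes to DiffusionEngine.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("path/to/diffusion-pipeline")
diffusion = InferenceEngine(pipe, inference_config=InferenceConfig())
images = diffusion.generate(prompts="a watercolor landscape")
```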
diff --git a/colossalai/inference/core/llm_engine.py b/colossalai/inference/core/llm_engine.py
new file mode 100644
index 000000000..b973d371d
--- /dev/null
+++ b/colossalai/inference/core/llm_engine.py
@@ -0,0 +1,758 @@
+import time
+from itertools import count
+from typing import Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import distributed as dist
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    GenerationConfig,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+
+from colossalai.accelerator import get_accelerator
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
+from colossalai.inference.graph_runner import CUDAGraphRunner
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.sampler import search_tokens
+from colossalai.inference.spec import Drafter, GlideInput
+from colossalai.inference.struct import Sequence
+from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .base_engine import BaseEngine
+from .request_handler import RequestHandler
+
+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = {
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "BaichuanForCausalLM": AutoModelForCausalLM,
+}
+
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
+
+
+class LLMEngine(BaseEngine):
+    """
+    LLMEngine which manages the LLM inference process.
+
+    Args:
+        model_or_path (nn.Module or str): The model to run, or the path to its checkpoint.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast], optional): The tokenizer to use.
+        inference_config (InferenceConfig, optional): The configuration options related to inference.
+        verbose (bool): Whether to log the generation process.
+        model_policy (Union[Policy, Type[Policy]], optional): The policy used to shard the model with ShardFormer. Determined by the model type if not provided.
+    """
+
+    def __init__(
+        self,
+        model_or_path: nn.Module | str,
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
+        inference_config: InferenceConfig = None,
+        verbose: bool = False,
+        model_policy: Policy | type[Policy] = None,
+    ) -> None:
+        self.inference_config = inference_config
+        self.dtype = inference_config.dtype
+        self.high_precision = inference_config.high_precision
+
+        self.verbose = verbose
+        self.logger = get_dist_logger(__name__)
+        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
+
+        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
+
+        self.generation_config = inference_config.to_generation_config(self.model_config)
+        self.generation_config_dict = self.generation_config.to_dict()
+
+        self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
+        # DISCUSS maybe move this into batch info?
+
+        self.counter = count()
+
+        self.use_cuda_graph = self.inference_config.use_cuda_graph
+        if self.use_cuda_graph:
+            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+            self.graph_memory_pool = None  # Set during graph capture.
+            if verbose:
+                self.logger.info("Colossal AI CUDA Graph Capture on")
+
+            self.capture_model(self.k_cache, self.v_cache)
+
+        # Model and related attrs of speculative decoding will be set by `enable_spec_dec`
+        self.use_spec_dec = self.inference_config.use_spec_dec
+
+        self.drafter_model = None
+        self.drafter = None
+        self.use_glide = False
+        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+
+        self._verify_args()
+
+    def init_model(
+        self,
+        model_or_path: Union[nn.Module, str],
+        model_policy: Union[Policy, Type[Policy]] = None,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
+    ):
+        """
+        Shard model or/and Load weight
+
+        Args:
+            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
+            model_policy (Policy): the policy to replace the model.
+            model_inference_config: the configuration for modeling initialization when inference.
+            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
+        """
+        pretrained_path = None
+        if isinstance(model_or_path, str):
+            import colossalai.interface.pretrained as pretrained_utils
+
+            try:
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
+                arch = getattr(hf_config, "architectures")[0]
+                if arch in _supported_models.keys():
+                    if arch == "BaichuanForCausalLM":
+                        self.logger.warning(
+                            "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
+                        )
+                    ctx = LazyInitContext(default_device="cuda")
+                    with ctx:
+                        model = _supported_models[arch].from_pretrained(
+                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                        )
+                    pretrained_path = pretrained_utils.get_pretrained_path(model)
+                else:
+                    # TODO(char-1ee): if the model is not supported, use transformers APIs to load and generate
+                    raise ValueError(f"Model {arch} is not supported.")
+
+            except Exception as e:
+                self.logger.error(
+                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
+                )
+        else:
+            model = model_or_path
+
+        self.model_config = model.config
+
+        torch.cuda.empty_cache()
+        init_gpu_memory = torch.cuda.mem_get_info()[0]
+
+        self.device = get_accelerator().get_current_device()
+        if self.verbose:
+            self.logger.info(f"the device is {self.device}")
+
+        model = model.to(self.dtype).eval()
+
+        if self.verbose:
+            self.logger.info(
+                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
+            )
+
+        if model_policy is None:
+            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
+            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
+            model_policy = model_policy_map.get(model_policy_key)
+
+        if not isinstance(model_policy, Policy):
+            try:
+                model_policy = model_policy()
+            except Exception as e:
+                raise ValueError(f"Unable to instantiate model policy: {e}")
+
+        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
+        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
+        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
+
+        model = self._shardformer(
+            model,
+            model_policy,
+            model_shard_infer_config,
+            None,
+            tp_group=tp_group,
+        )
+
+        self.model = ModelWrapper(model).to(self.device)
+
+        if self.verbose:
+            self.logger.info(
+                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
+            )
+
+        if pretrained_path:
+            from colossalai.inference.core.plugin import InferCheckpoint_io
+
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(pretrained_path)
+            assert if_has_index_file, "The checkpoint path is invalid: no model index file was found."
+            cpt_io.load_model(self.model, model_index_file)
+
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        peak_memory = init_gpu_memory - free_gpu_memory
+        if self.verbose:
+            self.logger.info(
+                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
+            )
+
+    @torch.inference_mode()
+    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
+        assert self.use_cuda_graph, "CUDA graph capture requires use_cuda_graph=True in the inference config"
+
+        if self.verbose:
+            self.logger.info("Colossal AI CUDA Graph Capture begin")
+
+        t_capture_begin = time.perf_counter()
+
+        block_size = self.inference_config.block_size
+        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
+
+        # Prepare dummy inputs. These will be reused for all batch sizes.
+        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
+        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
+        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
+        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
+        self.graph_block_tables[0, :] = np.arange(
+            0, max_num_blocks
+        )  # NOTE this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by making the first sequence length the max_seq_len
+        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        output_tensor = torch.zeros(
+            (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
+        )
+        fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
+
+        max_num_seqs = self.inference_config.max_batch_size
+        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
+        sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
+        # NOTE this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by making the first sequence length the max_seq_len
+        sequence_lengths[0] = torch.tensor(
+            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
+        ).cuda()
+
+        # NOTE: Capturing the largest batch size first may help reduce the
+        # memory usage of CUDA graph.
+        for batch_size in reversed(batch_size_capture_list):
+            if self.verbose:
+                self.logger.info(f"batch size {batch_size} graph capturing")
+
+            input_meta_data = InputMetaData(
+                block_tables=block_tables[:batch_size],
+                sequence_lengths=sequence_lengths[:batch_size],
+                fd_inter_tensor=fd_inter_tensor,
+                batch_size=batch_size,
+                is_prompts=False,
+                use_cuda_graph=True,
+                high_precision=False,
+                kv_seq_len=sequence_lengths[:batch_size].max().item(),
+                head_dim=head_dim,
+                dtype=self.dtype,
+            )
+
+            graph_runner = CUDAGraphRunner(self.model)
+            graph_runner.capture(
+                input_tokens_ids[:batch_size],
+                output_tensor[:batch_size],
+                input_meta_data,
+                k_caches=k_cache,
+                v_caches=v_cache,
+                memory_pool=self.graph_memory_pool,
+            )
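+            # Reuse the memory pool of the previously captured graph so later captures
+            # share workspace memory instead of allocating a new pool per batch size.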
+            self.graph_memory_pool = graph_runner.graph.pool()
+            self.graph_runners[batch_size] = graph_runner
+
+        t_capture_end = time.perf_counter()
+
+        if self.verbose:
+            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
+
+    def _verify_args(self) -> None:
+        """Verify the input args"""
+        if not isinstance(self.inference_config, InferenceConfig):
+            raise TypeError("Invalid type of inference config provided.")
+        if not isinstance(self.model, nn.Module):
+            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
+        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
+            raise TypeError(
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
+            )
+        model = self.model.module if isinstance(self.model, ModelWrapper) else self.model
+        assert (
+            model.__class__.__name__ in _supported_models.keys()
+        ), f"Model {model.__class__.__name__} is not supported."
+
+    def enable_spec_dec(
+        self,
+        drafter_model: nn.Module = None,
+        n_spec_tokens: int = None,
+        use_glide_drafter: bool = False,
+    ) -> None:
+        """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
+
+        Args:
+            drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
+                If provided, the existing drafter and drafter model (if any) will be overwritten.
+            n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
+                If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
+            use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
+                If True, the drafter model will be replaced by a glide model.
+
+        ```python
+        ...
+        engine = LLMEngine(model, tokenizer, inference_config)
+
+        engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
+        engine.generate(...)  # Speculative Decoding
+
+        engine.disable_spec_dec()
+        engine.generate(...)  # Normal generation
+
+        engine.enable_spec_dec()
+        engine.generate(...)  # Speculative-Decoding using previously set drafter model and number of spec tokens
+        engine.clear_spec_dec()
+        ```
+        """
+
+        if drafter_model is None and self.drafter is None:
+            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
+        if n_spec_tokens is not None:
+            assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
+            self.n_spec_tokens = n_spec_tokens
+        if drafter_model is not None:
+            assert isinstance(drafter_model, nn.Module)
+            # overwrite the drafter, if exists
+            self.clear_spec_dec()
+            self.drafter_model = drafter_model
+            self.drafter = Drafter(
+                self.drafter_model,
+                self.tokenizer,
+                device=self.device,
+                dtype=self.dtype,
+            )
+
+            # check if the provided drafter model is compatible with GLIDE structure
+            # when `use_glide_drafter` is set to True
+            if (
+                use_glide_drafter
+                and hasattr(drafter_model, "model")
+                and hasattr(drafter_model.model, "layers")
+                and hasattr(drafter_model.model.layers[0], "cross_attn")
+            ):
+                self.use_glide = use_glide_drafter
+            elif use_glide_drafter:
+                self.logger.warning(
+                    f"`use_glide_drafter` is provided as {use_glide_drafter}, "
+                    f"but the provided drafter model is not compatible with GLIDE structure."
+                    f"Falling back to use the default drafter model (non-GLIDE)."
+                )
+        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
+        # using speculative decoding for subsequent generations
+        self.use_spec_dec = True
+
+    def disable_spec_dec(self) -> None:
+        """Disable using speculative decoding for subsequent generations."""
+        self.request_handler.unset_spec_dec_mode()
+        # set back to the maximum number of tokens to speculate
+        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+        self.use_glide = False
+        self.use_spec_dec = False
+
+    def clear_spec_dec(self) -> None:
+        """Clear relatable structures of speculative decoding, if exist."""
+        if self.use_spec_dec:
+            self.disable_spec_dec()
+        if self.drafter_model or self.drafter:
+            self.drafter_model = None
+            self.drafter = None
+            torch.cuda.empty_cache()
+        self.use_glide = False
+        self.use_spec_dec = False
+
+    def steps_spec_dec(self) -> List[Sequence]:
+        """
+        Run speculative decoding steps. This retrieves a single batch and runs inference
+        with multiple rounds of speculation by the drafter model and verification by the main model.
+
+        Returns:
+            List[Sequence]: finished sequences generated by one step.
+        """
+        batch = self.request_handler.schedule()  # prefill batch
+        assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
+
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
+        # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
+        # NOTE For glide drafter models, we won't actually apply glide during prefill stage
+        drafter_out = self.drafter.speculate(input_token_ids, 1, None)
+        next_token_ids_spec = drafter_out.next_tokens
+        drafter_past_key_values = drafter_out.past_key_values
+
+        # 2. Prefill main model (Verifier) - fill past kv cache for main model
+        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
+        # append new inputs to the batch, temporarily
+        batch.append_batch_tokens(next_tokens)
+        self.request_handler.allocate_batch_spec_dec(batch, 1)
+        already_allocated_kv_len = batch.seq_lengths[0].item()
+        input_token_ids = batch.get_1D_inputs_spec_dec(1)
+
+        finished_sequences = self.request_handler.update()
+
+        while True:
+            # HACK Retrieve the running batch
+            #      Using RequestHandler.schedule here would re-allocate the same kv cache for the batch
+            batch = self.request_handler.running_bb  # running batch
+            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
+
+            # 3. Decoding - Drafter model speculates `n` tokens
+            glide_input = None
+            if self.use_glide:
+                glide_input = GlideInput(
+                    batch.get_block_table_tensor(),
+                    self.k_cache[-1],  # use the kv caches of the last layer
+                    self.v_cache[-1],
+                    batch.get_sequence_lengths(),
+                    n_spec_tokens=self.n_spec_tokens,
+                )
+
+            drafter_out = self.drafter.speculate(
+                input_token_ids,
+                self.n_spec_tokens,
+                drafter_past_key_values,
+                glide_input=glide_input,
+            )
+            next_token_ids_spec = drafter_out.next_tokens
+            drafter_past_key_values = drafter_out.past_key_values
+            drafter_spec_length = drafter_out.speculated_length
+
+            for next_token_id_spec in next_token_ids_spec:
+                self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
+            cur_length = batch.seq_lengths[0].item()
+            if already_allocated_kv_len < cur_length:
+                self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
+                already_allocated_kv_len = cur_length
+
+            # 4. Decoding - Main model verifies `n` tokens in parallel
+            if drafter_spec_length < batch.num_tokens_to_verify:
+                batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
+            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+
+            next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
+
+            # 5. Compare and process the results
+            diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
+            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
+
+            # revoke appended tokens for each Sequence in the current batch
+            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens
+
+            # append the last correct token generated by the main model
+            self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
+
+            # trim past key values of the drafter model
+            drafter_past_key_values = Drafter.trim_kv_cache(
+                drafter_past_key_values, drafter_spec_length - n_matches - 1
+            )
+
+            # prepare inputs for the next round of speculation
+            n = 1 if n_matches < drafter_spec_length else 2
+            input_token_ids = batch.get_1D_inputs_spec_dec(n)
+
+            self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
+            finished_sequences = self.request_handler.update()
+            if len(finished_sequences) > 0:
+                break
+
+        # Reset the number of speculated tokens for the batch. This handles the last round
+        # of speculation, in which the drafter may have speculated fewer tokens than the
+        # number configured for the engine.
+        batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
+
+        return finished_sequences
+
+    def generate(
+        self,
+        request_ids: Union[List[int], int] = None,
+        prompts: Union[List[str], str] = None,
+        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        return_token_ids: bool = False,
+        generation_config: Optional[GenerationConfig] = None,
+    ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
+        """
+        Execute generation for the given requests.
+
+        Args:
+            request_ids (List[int], optional): The request ID. Defaults to None.
+            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
+            return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
+            generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
+
+        Returns:
+            Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
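+
+        Example (an illustrative sketch; the prompts are placeholders):
+            ```python
+            output_str = engine.generate(prompts=["Hello, my name is"])
+            output_str, output_ids = engine.generate(prompts=["Hello, my name is"], return_token_ids=True)
+            ```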
+        """
+
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        prompts = [prompts] if isinstance(prompts, str) else prompts
+        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
+
+        with torch.inference_mode():
+            if prompts is not None or prompts_token_ids is not None:
+                self.add_request(
+                    request_ids=request_ids,
+                    prompts=prompts,
+                    prompts_token_ids=prompts_token_ids,
+                    **gen_config_dict,
+                )
+
+            output_seqs_list = []
+            total_tokens_list = []
+
+            # If the user provides a generation config, replace the existing one.
+            if generation_config is not None:
+                self.generation_config = generation_config
+                self.generation_config_dict = gen_config_dict
+
+            if self.use_spec_dec:
+                assert self.drafter is not None, "Drafter Model is not initialized."
+                while self.request_handler.check_unfinished_reqs():
+                    output_seqs_list += self.steps_spec_dec()
+            else:
+                while self.request_handler.check_unfinished_reqs():
+                    output_seqs_list += self.step()
+
+            output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
+
+            for seq in output_seqs_list:
+                total_tokens_list.append(seq.input_token_id + seq.output_token_id)
+
+            output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
+
+            if return_token_ids:
+                output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
+                return output_str, output_tokens_list
+            else:
+                return output_str
+
+    @property
+    def has_prompt_template(self) -> bool:
+        """ """
+        return self.inference_config.prompt_template is not None
+
+    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
+        """
+        This method will format the input prompt according to the prompt template given to the InferenceConfig.
+        """
+        assert (
+            self.has_prompt_template
+        ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
+
+        if isinstance(prompts, (list, tuple)):
+            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
+        elif isinstance(prompts, str):
+            return self.inference_config.prompt_template.format(input_text=prompts)
+        else:
+            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
+
+    def add_request(
+        self,
+        request_ids: Union[List[int], int] = None,
+        prompts: Union[List[str], str] = None,
+        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        **kwargs,
+    ) -> None:
+        """
+        Add requests.
+
+        Args:
+            request_ids (List[int], optional): The request ID. Defaults to None.
+            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (Union[List[List[int]], torch.Tensor, np.ndarray], optional): Token ids of the input prompts. Defaults to None.
+        """
+
+        # apply the prompt template to the input prompts
+
+        if self.has_prompt_template and prompts is not None:
+            prompts = self.format_prompt(prompts)
+
+        block_size = self.inference_config.block_size
+
+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
+        if prompts is not None and not isinstance(prompts, list):
+            prompts = [prompts]
+
+        if prompts_token_ids is None:
+            assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
+                "input_ids"
+            ]
+
+        # list of torch Tensor
+        if isinstance(prompts_token_ids, list):
+            if isinstance(prompts_token_ids[0], torch.Tensor):
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
+        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
+            prompts_token_ids = prompts_token_ids.tolist()
+        else:
+            raise TypeError(
+                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
+            )
+
+        assert (
+            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
+        ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
+
+        prompts_num = len(prompts_token_ids)
+
+        for i in range(prompts_num):
+            if request_ids:
+                assert isinstance(
+                    request_ids[0], int
+                ), f"The request_id type must be int, but got {type(request_ids[0])}"
+                assert len(request_ids) == prompts_num
+                request_id = request_ids[i]
+            else:
+                request_id = next(self.counter)
+            if prompts is None:
+                prompt = None
+            else:
+                prompt = prompts[i]
+
+            max_length = kwargs.get("max_length", None)
+            max_new_tokens = kwargs.get("max_new_tokens", None)
+            if max_length is None and max_new_tokens is None:
+                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
+            elif max_length is not None:
+                max_new_tokens = max_length - len(prompts_token_ids[i])
+
+            if not self.inference_config.enable_streamingllm:
+                assert (
+                    self.inference_config.max_output_len >= max_new_tokens
+                ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
+
+            sequence = Sequence(
+                request_id,
+                prompt,
+                prompts_token_ids[i],
+                block_size,
+                None,
+                self.tokenizer.eos_token_id,
+                self.tokenizer.pad_token_id,
+                max_output_len=max_new_tokens,
+                ignore_eos=self.inference_config.ignore_eos,
+            )
+            self.request_handler.add_sequence(sequence)
+
+    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
+        input_ids = batch.get_1D_inputs()
+        sequence_lengths = batch.get_sequence_lengths()
+
+        if batch.is_prompts:
+            n_tokens = sequence_lengths.sum().item()
+        else:
+            n_tokens = batch.current_batch_size
+            if batch.use_spec_dec:
+                n_tokens = batch.num_tokens_to_verify + 1
+                assert n_tokens == input_ids.size(0)
+                n_tokens = n_tokens * batch.current_batch_size
+        output_tensor = torch.zeros(
+            (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+
+        batch_token_ids = None
+        if (
+            self.generation_config.repetition_penalty != 1.0
+            or self.generation_config.no_repeat_ngram_size > 0
+            or self.generation_config.forced_eos_token_id is not None
+        ):
+            batch_token_ids = batch.batch_token_ids
+
+        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
+        use_cuda_graph = False
+        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
+            use_cuda_graph = True
+
+        input_meta_data = InputMetaData(
+            block_tables=batch.get_block_table_tensor(),
+            sequence_lengths=sequence_lengths,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            batch_size=batch.current_batch_size,
+            is_prompts=batch.is_prompts,
+            use_cuda_kernel=self.inference_config.use_cuda_kernel,
+            use_cuda_graph=use_cuda_graph,
+            high_precision=self.high_precision,
+            kv_seq_len=sequence_lengths.max().item(),
+            head_dim=batch.head_dim,
+            dtype=batch.dtype,
+            use_spec_dec=batch.use_spec_dec,
+            num_tokens_to_verify=batch.num_tokens_to_verify,
+            batch_token_ids=batch_token_ids,
+        )
+
+        return input_ids, output_tensor, input_meta_data
+
+    def step(self) -> List[Sequence]:
+        """
+        In each step, do the following:
+            1. Run RequestHandler.schedule() to get the batch used for inference.
+            2. Get the input token ids, input metadata, and output placeholder from the BatchBucket.
+            3. Run the model to generate the next tokens.
+            4. Update the waiting and running lists in the RequestHandler and collect finished sequences.
+            5. Return the finished sequences collected in this step.
+
+        Returns:
+            List[Sequence]: Finished sequences generated by one step.
+        """
+
+        batch = self.request_handler.schedule()
+
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
+        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
+        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
+        if self.inference_config.pad_input:
+            logits = logits[:, -1, :]
+
+        if self.inference_config.enable_streamingllm:
+            updated_block_ids = batch.streamingllm_update_batch(
+                self.inference_config.start_token_size, self.inference_config.generated_token_size
+            )
+            self.request_handler.streamingllm_free_block_tables(updated_block_ids)
+
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )
+        self.request_handler.append_next_tokens(next_tokens)
+        finished_sequences = self.request_handler.update()
+
+        return finished_sequences
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 512eaea71..393347c31 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
-from colossalai.inference.struct import RequestStatus, Sequence
+from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -98,7 +98,46 @@ class RunningList:
             self._decoding[seq_id] = self._prefill.pop(seq_id)
 
 
-class RequestHandler:
+class NaiveRequestHandler:
+    def __init__(self) -> None:
+        self.running_list: List[DiffusionSequence] = []
+        self.waiting_list: List[DiffusionSequence] = []
+
+    def _has_waiting(self) -> bool:
+        return any(lst for lst in self.waiting_list)
+
+    def _has_running(self) -> bool:
+        return any(lst for lst in self.running_list)
+
+    def check_unfinished_reqs(self):
+        return self._has_waiting() or self._has_running()
+
+    def add_sequence(self, seq: DiffusionSequence):
+        """
+        Add the request to waiting list.
+        """
+        assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
+        self.waiting_list.append(seq)
+
+    def _find_sequence(self, request_id: int) -> DiffusionSequence:
+        """
+        Find the request by request_id.
+        """
+        for seq in self.waiting_list + self.running_list:
+            if seq.request_id == request_id:
+                return seq
+        return None
+
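+    # NOTE: scheduling is FIFO over the waiting list; e.g. after add_sequence(seq_a)
+    # then add_sequence(seq_b), schedule() returns seq_a first, then seq_b (illustrative).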
+    def schedule(self):
+        ret = None
+        if self._has_waiting():
+            ret = self.waiting_list[0]
+            self.waiting_list = self.waiting_list[1:]
+        return ret
+
+
+class RequestHandler(NaiveRequestHandler):
     """
     RequestHandler is the core for handling existing requests and updating current batch.
     During generation process, we call schedule function each iteration to update current batch.
@@ -176,12 +215,12 @@ class RequestHandler:
             generated_token_size=inference_config.generated_token_size,
         )
 
+    def _has_running(self) -> bool:
+        return not self.running_bb.is_empty()
+
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
 
-    def _has_waiting(self) -> bool:
-        return any(lst for lst in self.waiting_list)
-
     def get_kvcache(self):
         return self.cache_manager.get_kv_cache()
 
@@ -318,7 +357,7 @@ class RequestHandler:
             if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
                 seq.mark_finished()
 
-    def check_unfinished_seqs(self) -> bool:
+    def check_unfinished_reqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
     def total_requests_in_batch_bucket(self) -> int:
diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/models/diffusion.py
new file mode 100644
index 000000000..9dc90733d
--- /dev/null
+++ b/colossalai/inference/modeling/models/diffusion.py
@@ -0,0 +1,54 @@
+import inspect
+import types
+
+import torch
+from torch import nn
+
+
+class DiffusionPipe(nn.Module):
+    """
+    This class converts a `DiffusionPipeline` into an `nn.Module` while preserving most of the original attributes, functions, and properties.
+    """
+
+    def __init__(self, source_obj) -> None:
+        super(DiffusionPipe, self).__init__()
+
+        for k, v in source_obj.__dict__.items():
+            if isinstance(v, nn.Module):
+                self.add_module(k, v)
+            else:
+                setattr(self, k, v)
+
+        skip_list = ["_execution_device", "to", "device"]  # this
+
+        for name, member in inspect.getmembers(source_obj.__class__):
+            if name in skip_list:
+                continue
+            if not name.startswith("__") and not name.endswith("__"):
+                if isinstance(member, property):
+                    setattr(self.__class__, name, member)
+                elif inspect.isfunction(member) or inspect.ismethod(member):
+                    bound_method = types.MethodType(member, self)
+                    setattr(self, name, bound_method)
+                elif not callable(member) and not isinstance(member, property):
+                    setattr(self, name, member)
+            elif name == "__call__":
+                bound_method = types.MethodType(member, self)
+                setattr(self, "_forward", bound_method)
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
+        Accelerate's module hooks.
+        """
+        # return self.device
+        return torch.device("cuda")
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, *args, **kwargs):
+        return self._forward(*args, **kwargs)
diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py
new file mode 100644
index 000000000..d5774946e
--- /dev/null
+++ b/colossalai/inference/modeling/models/pixart_alpha.py
@@ -0,0 +1,220 @@
+# Code adapted from:
+# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+
+from typing import Callable, List, Optional, Union
+
+import PIL.Image
+import torch
+from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
+    ASPECT_RATIO_256_BIN,
+    ASPECT_RATIO_512_BIN,
+    ASPECT_RATIO_1024_BIN,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+
+from colossalai.logging import get_dist_logger
+
+from .diffusion import DiffusionPipe
+
+logger = get_dist_logger(__name__)
+
+
+@torch.no_grad()
+def pixart_alpha_forward(
+    self: DiffusionPipe,
+    prompt: Union[str, List[str]] = None,
+    negative_prompt: str = "",
+    num_inference_steps: int = 20,
+    timesteps: List[int] = None,
+    sigmas: List[float] = None,
+    guidance_scale: float = 4.5,
+    num_images_per_prompt: Optional[int] = 1,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    eta: float = 0.0,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.Tensor] = None,
+    prompt_embeds: Optional[torch.Tensor] = None,
+    prompt_attention_mask: Optional[torch.Tensor] = None,
+    negative_prompt_embeds: Optional[torch.Tensor] = None,
+    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+    callback_steps: int = 1,
+    clean_caption: bool = True,
+    use_resolution_binning: bool = True,
+    max_sequence_length: int = 120,
+    **kwargs,
+) -> PIL.Image.Image:
+    # 1. Check inputs. Raise error if not correct
+    height = height or self.transformer.config.sample_size * self.vae_scale_factor
+    width = width or self.transformer.config.sample_size * self.vae_scale_factor
+    if use_resolution_binning:
+        if self.transformer.config.sample_size == 128:
+            aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+        elif self.transformer.config.sample_size == 64:
+            aspect_ratio_bin = ASPECT_RATIO_512_BIN
+        elif self.transformer.config.sample_size == 32:
+            aspect_ratio_bin = ASPECT_RATIO_256_BIN
+        else:
+            raise ValueError("Invalid sample size")
+        orig_height, orig_width = height, width
+        height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+    self.check_inputs(
+        prompt,
+        height,
+        width,
+        negative_prompt,
+        callback_steps,
+        prompt_embeds,
+        negative_prompt_embeds,
+        prompt_attention_mask,
+        negative_prompt_attention_mask,
+    )
+
+    # 2. Default height and width to transformer
+    if prompt is not None and isinstance(prompt, str):
+        batch_size = 1
+    elif prompt is not None and isinstance(prompt, list):
+        batch_size = len(prompt)
+    else:
+        batch_size = prompt_embeds.shape[0]
+
+    device = self._execution_device
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    do_classifier_free_guidance = guidance_scale > 1.0
+
+    # 3. Encode input prompt
+    (
+        prompt_embeds,
+        prompt_attention_mask,
+        negative_prompt_embeds,
+        negative_prompt_attention_mask,
+    ) = self.encode_prompt(
+        prompt,
+        do_classifier_free_guidance,
+        negative_prompt=negative_prompt,
+        num_images_per_prompt=num_images_per_prompt,
+        device=device,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds,
+        prompt_attention_mask=prompt_attention_mask,
+        negative_prompt_attention_mask=negative_prompt_attention_mask,
+        clean_caption=clean_caption,
+        max_sequence_length=max_sequence_length,
+    )
+    if do_classifier_free_guidance:
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+    # 4. Prepare timesteps
+    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)
+
+    # 5. Prepare latents.
+    latent_channels = self.transformer.config.in_channels
+    latents = self.prepare_latents(
+        batch_size * num_images_per_prompt,
+        latent_channels,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+        latents,
+    )
+
+    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+    # 6.1 Prepare micro-conditions.
+    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+    if self.transformer.config.sample_size == 128:
+        resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+        aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+        resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+        aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+        if do_classifier_free_guidance:
+            resolution = torch.cat([resolution, resolution], dim=0)
+            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
+        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+    # 7. Denoising loop
+    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+    with self.progress_bar(total=num_inference_steps) as progress_bar:
+        for i, t in enumerate(timesteps):
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            current_timestep = t
+            if not torch.is_tensor(current_timestep):
+                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+                # This would be a good case for the `match` statement (Python 3.10+)
+                is_mps = latent_model_input.device.type == "mps"
+                if isinstance(current_timestep, float):
+                    dtype = torch.float32 if is_mps else torch.float64
+                else:
+                    dtype = torch.int32 if is_mps else torch.int64
+                current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+            elif len(current_timestep.shape) == 0:
+                current_timestep = current_timestep[None].to(latent_model_input.device)
+            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+            current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+            # predict noise model_output
+            noise_pred = self.transformer(
+                latent_model_input,
+                encoder_hidden_states=prompt_embeds,
+                encoder_attention_mask=prompt_attention_mask,
+                timestep=current_timestep,
+                added_cond_kwargs=added_cond_kwargs,
+                return_dict=False,
+            )[0]
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # learned sigma
+            if self.transformer.config.out_channels // 2 == latent_channels:
+                noise_pred = noise_pred.chunk(2, dim=1)[0]
+            else:
+                noise_pred = noise_pred
+
+            # compute previous image: x_t -> x_t-1
+            if num_inference_steps == 1:
+                # For DMD one step sampling: https://arxiv.org/abs/2311.18828
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+            # call the callback, if provided
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                progress_bar.update()
+                if callback is not None and i % callback_steps == 0:
+                    step_idx = i // getattr(self.scheduler, "order", 1)
+                    callback(step_idx, t, latents)
+
+    output_type = "pil"  # TODO(@lry89757) temporarily image, please support more return output
+    if not output_type == "latent":
+        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        if use_resolution_binning:
+            image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+    else:
+        image = latents
+
+    if not output_type == "latent":
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+    # Offload all models
+    # self.maybe_free_model_hooks()
+
+    return image
diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py
new file mode 100644
index 000000000..d1c63a6dc
--- /dev/null
+++ b/colossalai/inference/modeling/models/stablediffusion3.py
@@ -0,0 +1,178 @@
+# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
+
+from .diffusion import DiffusionPipe
+
+
+# TODO(@lry89757) temporarily returns images only; support more output types later
+@torch.no_grad()
+def sd3_forward(
+    self: DiffusionPipe,
+    prompt: Union[str, List[str]] = None,
+    prompt_2: Optional[Union[str, List[str]]] = None,
+    prompt_3: Optional[Union[str, List[str]]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 28,
+    timesteps: List[int] = None,
+    guidance_scale: float = 7.0,
+    negative_prompt: Optional[Union[str, List[str]]] = None,
+    negative_prompt_2: Optional[Union[str, List[str]]] = None,
+    negative_prompt_3: Optional[Union[str, List[str]]] = None,
+    num_images_per_prompt: Optional[int] = 1,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    clip_skip: Optional[int] = None,
+    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+):
+    height = height or self.default_sample_size * self.vae_scale_factor
+    width = width or self.default_sample_size * self.vae_scale_factor
+
+    # 1. Check inputs. Raise error if not correct
+    self.check_inputs(
+        prompt,
+        prompt_2,
+        prompt_3,
+        height,
+        width,
+        negative_prompt=negative_prompt,
+        negative_prompt_2=negative_prompt_2,
+        negative_prompt_3=negative_prompt_3,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+    )
+
+    self._guidance_scale = guidance_scale
+    self._clip_skip = clip_skip
+    self._joint_attention_kwargs = joint_attention_kwargs
+    self._interrupt = False
+
+    # 2. Define call parameters
+    if prompt is not None and isinstance(prompt, str):
+        batch_size = 1
+    elif prompt is not None and isinstance(prompt, list):
+        batch_size = len(prompt)
+    else:
+        batch_size = prompt_embeds.shape[0]
+
+    device = self._execution_device
+
+    (
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    ) = self.encode_prompt(
+        prompt=prompt,
+        prompt_2=prompt_2,
+        prompt_3=prompt_3,
+        negative_prompt=negative_prompt,
+        negative_prompt_2=negative_prompt_2,
+        negative_prompt_3=negative_prompt_3,
+        do_classifier_free_guidance=self.do_classifier_free_guidance,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        device=device,
+        clip_skip=self.clip_skip,
+        num_images_per_prompt=num_images_per_prompt,
+    )
+
+    if self.do_classifier_free_guidance:
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+    # 4. Prepare timesteps
+    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+    self._num_timesteps = len(timesteps)
+
+    # 5. Prepare latent variables
+    num_channels_latents = self.transformer.config.in_channels
+    latents = self.prepare_latents(
+        batch_size * num_images_per_prompt,
+        num_channels_latents,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+        latents,
+    )
+
+    # 6. Denoising loop
+    with self.progress_bar(total=num_inference_steps) as progress_bar:
+        for i, t in enumerate(timesteps):
+            if self.interrupt:
+                continue
+
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+            timestep = t.expand(latent_model_input.shape[0])
+
+            noise_pred = self.transformer(
+                hidden_states=latent_model_input,
+                timestep=timestep,
+                encoder_hidden_states=prompt_embeds,
+                pooled_projections=pooled_prompt_embeds,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            # perform guidance
+            if self.do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents_dtype = latents.dtype
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+            if latents.dtype != latents_dtype:
+                if torch.backends.mps.is_available():
+                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                    latents = latents.to(latents_dtype)
+
+            if callback_on_step_end is not None:
+                callback_kwargs = {}
+                for k in callback_on_step_end_tensor_inputs:
+                    callback_kwargs[k] = locals()[k]
+                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                latents = callback_outputs.pop("latents", latents)
+                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                negative_pooled_prompt_embeds = callback_outputs.pop(
+                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                )
+
+            # call the callback, if provided
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                progress_bar.update()
+
+    if output_type == "latent":
+        image = latents
+
+    else:
+        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+        image = self.vae.decode(latents, return_dict=False)[0]
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+    return image
diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py
index fa0395590..02ffadd9f 100644
--- a/colossalai/inference/modeling/policy/__init__.py
+++ b/colossalai/inference/modeling/policy/__init__.py
@@ -1,16 +1,22 @@
 from .glide_llama import GlideLlamaModelPolicy
 from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
 from .nopadding_llama import NoPaddingLlamaModelInferPolicy
+from .pixart_alpha import PixArtAlphaInferPolicy
+from .stablediffusion3 import StableDiffusion3InferPolicy
 
 model_policy_map = {
     "nopadding_llama": NoPaddingLlamaModelInferPolicy,
     "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
     "glide_llama": GlideLlamaModelPolicy,
+    "StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
+    "PixArtAlphaPipeline": PixArtAlphaInferPolicy,
 }
 
 __all__ = [
     "NoPaddingLlamaModelInferPolicy",
     "NoPaddingBaichuanModelInferPolicy",
     "GlideLlamaModelPolicy",
+    "StableDiffusion3InferPolicy",
+    "PixArtAlphaInferPolicy",
     "model_polic_map",
 ]
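The two new policies are keyed by their pipeline class names rather than the snake_case aliases used for the LLM policies, which lets a policy be resolved directly from a loaded pipeline's type. A hedged sketch of that lookup, using a hypothetical `resolve_policy` helper that is not part of the patch:

```python
from colossalai.inference.modeling.policy import model_policy_map


def resolve_policy(model):
    """Pick an inference policy from the pipeline's class name, e.g. 'StableDiffusion3Pipeline'."""
    policy_cls = model_policy_map.get(type(model).__name__)
    if policy_cls is None:
        raise ValueError(f"No inference policy registered for {type(model).__name__}")
    return policy_cls()
```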
diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py
new file mode 100644
index 000000000..356056ba7
--- /dev/null
+++ b/colossalai/inference/modeling/policy/pixart_alpha.py
@@ -0,0 +1,34 @@
+from torch import nn
+
+from colossalai.inference.config import RPC_PARAM
+from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
+from colossalai.shardformer.policies.base_policy import Policy
+
+
+class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        policy = {}
+        self.append_or_create_method_replacement(
+            description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
+        )
+        return policy
+
+    def preprocess(self) -> nn.Module:
+        return self.model
+
+    def postprocess(self):
+        return self.model
+
+    def config_sanity_check(self):
+        pass
+
+    def to_rpc_param(self) -> str:
+        return __class__.__name__
+
+    @staticmethod
+    def from_rpc_param() -> "PixArtAlphaInferPolicy":
+        return PixArtAlphaInferPolicy()
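The policy's only real work is to rebind `DiffusionPipe.forward` to the PixArt-specific forward via a method-replacement description. Outside of ShardFormer, the effect can be illustrated with a plain method rebind; the toy classes below are stand-ins, not the actual ShardFormer code path.

```python
import types


class TinyPipe:
    """Toy stand-in for DiffusionPipe."""

    def forward(self, x):
        return x


def custom_forward(self, x):
    # Plays the role of pixart_alpha_forward being bound onto the pipeline wrapper.
    return x * 2


pipe = TinyPipe()
pipe.forward = types.MethodType(custom_forward, pipe)  # rebind forward on the instance
print(pipe.forward(3))  # 6
```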
diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py
new file mode 100644
index 000000000..c9877f7dc
--- /dev/null
+++ b/colossalai/inference/modeling/policy/stablediffusion3.py
@@ -0,0 +1,34 @@
+from torch import nn
+
+from colossalai.inference.config import RPC_PARAM
+from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
+from colossalai.shardformer.policies.base_policy import Policy
+
+
+class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        policy = {}
+        self.append_or_create_method_replacement(
+            description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
+        )
+        return policy
+
+    def preprocess(self) -> nn.Module:
+        return self.model
+
+    def postprocess(self):
+        return self.model
+
+    def config_sanity_check(self):
+        pass
+
+    def to_rpc_param(self) -> str:
+        return __class__.__name__
+
+    @staticmethod
+    def from_rpc_param() -> "StableDiffusion3InferPolicy":
+        return StableDiffusion3InferPolicy()
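As with the PixArt policy, `to_rpc_param`/`from_rpc_param` only move the class name across the RPC boundary and rebuild a fresh, stateless policy on the other side. A minimal round-trip, assuming both ends can import the policy class:

```python
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy

# Serialize: only the class name is sent over RPC.
param = StableDiffusion3InferPolicy().to_rpc_param()
assert param == "StableDiffusion3InferPolicy"

# Deserialize: the receiving side reconstructs a fresh policy instance.
policy = StableDiffusion3InferPolicy.from_rpc_param()
assert isinstance(policy, StableDiffusion3InferPolicy)
```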
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index 1a3094a27..65d284296 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -2,6 +2,7 @@ import enum
 from dataclasses import dataclass
 from typing import Any, List
 
+from colossalai.inference.config import DiffusionGenerationConfig
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -46,6 +47,17 @@ class RequestStatus(enum.Enum):
         return status == RequestStatus.WAITING
 
 
+@dataclass
+class DiffusionSequence:
+    """
+    Store the request id, prompt, and per-request generation config of a diffusion request.
+    """
+
+    request_id: int
+    prompt: str
+    generation_config: DiffusionGenerationConfig
+
+
 @dataclass
 class Sequence:
     """Store information of input sequence.
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py
index 332e84d37..f2a0fc037 100644
--- a/colossalai/inference/utils.py
+++ b/colossalai/inference/utils.py
@@ -5,10 +5,12 @@ Utils for model inference
 import math
 import os
 import re
+from enum import Enum
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
+from diffusers import DiffusionPipeline
 from torch import nn
 
 from colossalai.logging import get_dist_logger
@@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
     except ImportError:
         logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
         return False
+
+
+class ModelType(Enum):
+    DIFFUSION_MODEL = "Diffusion Model"
+    LLM = "Large Language Model (LLM)"
+    UNKNOWN = "Unknown Model Type"
+
+
+def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]) -> ModelType:
+    if isinstance(model_or_path, DiffusionPipeline):
+        return ModelType.DIFFUSION_MODEL
+    elif isinstance(model_or_path, nn.Module):
+        return ModelType.LLM
+    elif isinstance(model_or_path, str):
+        try:
+            from transformers import AutoConfig
+
+            AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+            return ModelType.LLM
+        except Exception:
+            # Not loadable as a transformers config, so it is not an LLM checkpoint.
+            pass
+
+        try:
+            from diffusers import DiffusionPipeline
+
+            DiffusionPipeline.load_config(model_or_path)
+            return ModelType.DIFFUSION_MODEL
+        except Exception:
+            # Not loadable as a diffusers pipeline config either.
+            pass
+
+        return ModelType.UNKNOWN
+    else:
+        return ModelType.UNKNOWN
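A typical use of this helper is to branch engine setup on the detected type; the checkpoint id below is a placeholder:

```python
from colossalai.inference.utils import ModelType, get_model_type

model_type = get_model_type("stabilityai/stable-diffusion-3-medium-diffusers")  # placeholder id
if model_type == ModelType.DIFFUSION_MODEL:
    print("route the request through the diffusion path")
elif model_type == ModelType.LLM:
    print("route the request through the LLM path")
else:
    print("unsupported or unrecognized checkpoint")
```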
diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py
new file mode 100644
index 000000000..fe989eed7
--- /dev/null
+++ b/examples/inference/stable_diffusion/sd3_generation.py
@@ -0,0 +1,75 @@
+import argparse
+
+from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
+from torch import bfloat16, float16, float32
+
+import colossalai
+from colossalai.cluster import DistCoordinator
+from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
+from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
+
+# Select the pipeline/policy pair to run: index 0 = Stable Diffusion 3, index 1 = PixArt-Alpha
+MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
+POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
+
+TORCH_DTYPE_MAP = {
+    "fp16": float16,
+    "fp32": float32,
+    "bf16": bfloat16,
+}
+
+
+def infer(args):
+    # ==============================
+    # Launch colossalai, setup distributed environment
+    # ==============================
+    colossalai.launch_from_torch()
+    coordinator = DistCoordinator()
+
+    # ==============================
+    # Load model and tokenizer
+    # ==============================
+    model_path_or_name = args.model
+    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
+
+    # ==============================
+    # Initialize InferenceEngine
+    # ==============================
+    coordinator.print_on_master(f"Initializing Inference Engine...")
+    inference_config = InferenceConfig(
+        dtype=args.dtype,
+        max_batch_size=args.max_batch_size,
+        tp_size=args.tp_size,
+        use_cuda_kernel=args.use_cuda_kernel,
+    )
+    engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
+
+    # ==============================
+    # Generation
+    # ==============================
+    coordinator.print_on_master(f"Generating...")
+    out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
+    out.save("cat.jpg")
+    coordinator.print_on_master(out)
+
+
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
+# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
+
+
+if __name__ == "__main__":
+    # ==============================
+    # Parse Arguments
+    # ==============================
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
+    parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
+    parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
+    parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
+    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+    args = parser.parse_args()
+
+    infer(args)
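The same script drives PixArt-Alpha by flipping the pipeline/policy selector; a hedged sketch of that variant follows (the checkpoint id in the launch command is an assumption, substitute your own):

```python
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline

from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy

# Select index 1 to run PixArt-Alpha instead of Stable Diffusion 3.
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][1]
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][1]

# Launch exactly as before, pointing -m at a PixArt checkpoint:
#   colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py \
#       -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
```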
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 27bbc3769..b54d1cf91 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -23,3 +23,4 @@ rpyc==6.0.0
 fastapi
 uvicorn==0.29.0
 galore_torch
+diffusers==0.29.0