diff --git a/colossalai/inference/build.sh b/colossalai/inference/build.sh
deleted file mode 100644
index 6a73f6f0b..000000000
--- a/colossalai/inference/build.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/usr/bin/env bash
-
-# install triton 
-pip install triton
-pip install transformers
-
-# install lightllm and flash-attention 
-mkdir 3rdParty
-cd 3rdParty
-git clone https://github.com/ModelTC/lightllm 
-cd lightllm
-git checkout 28c1267cfca536b7b4f28e921e03de735b003039
-pip install -e . 
-cd ..
-
-git clone -recursive https://github.com/Dao-AILab/flash-attention
-cd flash-attention
-pip install -e . 
-
-cd ../../
-
-
-
-
diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py
index 2dd1858d6..b7bc94d0e 100644
--- a/colossalai/inference/engine/modeling/llama.py
+++ b/colossalai/inference/engine/modeling/llama.py
@@ -27,9 +27,15 @@ except:
     print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
     HAS_LIGHTLLM_KERNEL = False
 
+try:
+    from colossalai.kernel.triton.flash_decoding import token_flash_decoding
+    HAS_TRITON_FLASH_DECODING_KERNEL = True
+except:
+    print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    HAS_TRITON_FLASH_DECODING_KERNEL = False
+    
 try:
     from flash_attn import flash_attn_with_kvcache
-
     HAS_FLASH_KERNEL = True
 except:
     HAS_FLASH_KERNEL = False
@@ -42,7 +48,6 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
-
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -67,7 +72,6 @@ def llama_triton_context_attention(
                 attn_output,
                 infer_state.start_loc,
                 infer_state.seq_len,
-                # infer_state.cache_manager.past_key_values_length,
                 infer_state.max_len_in_batch,
             )
         else:
@@ -78,7 +82,6 @@ def llama_triton_context_attention(
                 attn_output,
                 infer_state.start_loc,
                 infer_state.seq_len,
-                # infer_state.cache_manager.past_key_values_length,
                 infer_state.max_len_in_batch,
             )
     else:
@@ -90,13 +93,20 @@ def llama_triton_context_attention(
             attn_output,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
 
-
-def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
-    assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
+    if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
+        token_flash_decoding(q = query_states, 
+                                o_tensor = attn_output, 
+                                infer_state = infer_state, 
+                                q_head_num = q_head_num, 
+                                head_dim = head_dim, 
+                                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], 
+                                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
+        return 
+    
     if num_key_value_groups == 1:
         token_attention_fwd(
             query_states,
@@ -106,7 +116,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
             infer_state.block_loc,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
     else:
@@ -118,7 +127,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
             infer_state.block_loc,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
             infer_state.other_kv_index,
         )
@@ -451,10 +459,14 @@ class LlamaInferenceForwards:
                 )
 
             if HAS_LIGHTLLM_KERNEL:
+                
                 attn_output = torch.empty_like(query_states)
-                llama_triton_token_attention(
-                    query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
-                )
+                llama_triton_token_attention(query_states = query_states, 
+                                             attn_output = attn_output, 
+                                             infer_state = infer_state, 
+                                             num_key_value_groups = self.num_key_value_groups, 
+                                             q_head_num = q_len * self.num_heads, 
+                                             head_dim = self.head_dim)
             else:
                 self.num_heads // self.num_key_value_heads
                 cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py
index 1ad7a80eb..3d9a23d2f 100644
--- a/colossalai/kernel/triton/context_attention.py
+++ b/colossalai/kernel/triton/context_attention.py
@@ -137,6 +137,7 @@ if HAS_TRITON:
             tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
             return
     else:
+        # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
         @triton.jit
         def _context_flash_attention_kernel_2(
             Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py
new file mode 100644
index 000000000..9b7b27fa1
--- /dev/null
+++ b/colossalai/kernel/triton/flash_decoding.py
@@ -0,0 +1,50 @@
+# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
+import torch
+try:
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    HAS_LIGHTLLM_KERNEL = False
+
+
+if HAS_LIGHTLLM_KERNEL:
+    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
+        BLOCK_SEQ = 256
+        batch_size = infer_state.batch_size
+        max_len_in_batch = infer_state.max_len_in_batch
+
+
+        calcu_shape1 = (batch_size, q_head_num, head_dim)
+
+        if getattr(infer_state, 'mid_o', None) is None:
+            infer_state.mid_o = torch.empty([batch_size, 
+                                            q_head_num, 
+                                            max_len_in_batch // BLOCK_SEQ + 1, 
+                                            head_dim], 
+                                            dtype=torch.float32, 
+                                            device="cuda")
+            infer_state.mid_o_logexpsum = torch.empty([batch_size, 
+                                            q_head_num,
+                                            max_len_in_batch // BLOCK_SEQ + 1], 
+                                            dtype=torch.float32, 
+                                            device="cuda")
+
+        mid_o = infer_state.mid_o
+        mid_o_logexpsum = infer_state.mid_o_logexpsum
+
+        flash_decode_stage1(q.view(calcu_shape1),
+                                    cache_k,
+                                    cache_v,
+                                    infer_state.block_loc,
+                                    infer_state.seq_len,
+                                    infer_state.max_len_in_batch,
+                                    mid_o,
+                                    mid_o_logexpsum,
+                                    BLOCK_SEQ)
+        flash_decode_stage2(mid_o,
+                            mid_o_logexpsum, 
+                            infer_state.seq_len, 
+                            o_tensor.view(calcu_shape1), 
+                            BLOCK_SEQ)
diff --git a/examples/inference/hybrid_llama.py b/examples/inference/hybrid_llama.py
index bdfa4e5e8..1bd34afef 100644
--- a/examples/inference/hybrid_llama.py
+++ b/examples/inference/hybrid_llama.py
@@ -75,11 +75,11 @@ def run_tp_pipeline_inference(rank, world_size, port, args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
-    parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size")
-    parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Tensor parallel size")
-    parser.add_argument("-b", "--batch_size", type=int, default=8, help="Maximum batch size")
-    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
-    parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length")
+    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("-pp", "--pp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("-b", "--batch_size", type=int, default=64, help="Maximum batch size")
+    parser.add_argument("--max_input_len", type=int, default=512, help="Maximum input length")
+    parser.add_argument("--max_output_len", type=int, default=256, help="Maximum output length")
     parser.add_argument("--micro_batch_size", type=int, default=2, help="Micro batch size")
 
     args = parser.parse_args()
diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt
index 461dcb23b..3151504df 100644
--- a/requirements/requirements-infer.txt
+++ b/requirements/requirements-infer.txt
@@ -2,6 +2,6 @@ transformers==4.34.0
 packaging
 ninja
 auto-gptq==0.5.0
-git+https://github.com/ModelTC/lightllm.git@28c1267cfca536b7b4f28e921e03de735b003039
+git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
 git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
 git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9