Mirror of https://github.com/hpcaitech/ColossalAI.git
[inference] Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
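In broad strokes, the optimization is to allocate the flash-decoding intermediate ("mid") buffers up front, sized by the maximum number of kv cache blocks per sequence, and reuse them across calls instead of creating fresh temporaries. A minimal sketch of that allocation pattern, using the shapes and the kv_max_split_num formula that appear in the updated tests and benchmark below (the example sizes and the mid_output_lse layout are assumptions, not taken from this diff):

    import torch

    # Example sizes only; the tests below use comparable values.
    bsz, num_attn_heads, head_dim = 8, 2, 128
    block_size, max_seq_len_in_b = 16, 256

    # The number of kv splits is bounded by the number of cache blocks per sequence.
    kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Allocated once and passed into the decoding kernel on every step.
    output = torch.empty((bsz, 1, num_attn_heads, head_dim), dtype=torch.float16, device=device)
    mid_output = torch.empty(
        (bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=device
    )
    # Assumed layout: one log-sum-exp value per (batch, head, kv split).
    mid_output_lse = torch.empty(
        (bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=device
    )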
@@ -17,6 +17,7 @@ def check_config_and_inference():
    sample_params=None,
    block_table=None,
    eos_token_id=2,
    pad_token_id=2,
    max_output_len=256,
)
@@ -28,6 +29,7 @@ def check_config_and_inference():
    sample_params=None,
    block_table=None,
    eos_token_id=2,
    pad_token_id=2,
    max_output_len=256,
)
@@ -39,6 +41,7 @@ def check_config_and_inference():
    sample_params=None,
    block_table=None,
    eos_token_id=2,
    pad_token_id=2,
    max_output_len=256,
)
sequence.mark_running()
@@ -51,7 +54,12 @@ def check_config_and_inference():
assert sequence.output_len == 0
assert sequence.check_finish() == False

- batch = BatchInfo(is_prompts=False)
+ batch = BatchInfo(
+     max_batch_size=8,
+     kv_max_split_num=16,
+     num_heads=2,
+     head_dim=128,
+ )
batch.init_batch([sequence])
batch.add_seqs([sequence2, sequence3])
batch.add_seqs([sequence])
@@ -3,8 +3,7 @@ import random
import numpy as np
import pytest
import torch
- import transformers
- from transformers import AutoTokenizer, GenerationConfig
+ from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import InferenceConfig
@@ -22,8 +21,8 @@ def setup_seed(seed):
def check_inference_engine(test_cai=False):
    setup_seed(20)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-     model = transformers.LlamaForCausalLM(
-         transformers.LlamaConfig(
+     model = LlamaForCausalLM(
+         LlamaConfig(
            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
        )
    ).cuda()
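Aside: the direct-import construction shown above can be reproduced standalone; a minimal sketch (CPU-only here, dropping the .cuda() call the test uses):

    from transformers import LlamaConfig, LlamaForCausalLM

    # A deliberately tiny Llama model, mirroring the config values in the test above.
    config = LlamaConfig(
        vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
    )
    model = LlamaForCausalLM(config)  # call .cuda() afterwards if a GPU is available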
@@ -81,4 +80,4 @@ def test_inference_engine():

if __name__ == "__main__":
-     test_inference_engine()
+     test_inference_engine()
@@ -20,6 +20,7 @@ def check_running_list():
    input_token_id=[1, 2, 3],
    block_size=16,
    eos_token_id=0,
    pad_token_id=0,
    sample_params=None,
    block_table=1,
)
@@ -56,6 +57,7 @@ def check_request_handler():
    input_token_id=[1, 2, 3, 4, 5],
    block_size=16,
    eos_token_id=0,
    pad_token_id=0,
    sample_params=None,
    block_table=torch.tensor([-1, -1]),
)
@@ -91,6 +91,7 @@ def test_flash_decoding(
max_seq_len_in_b = kv_seq_lengths.max().item()
# The maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
mid_output = torch.empty(
    size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
@@ -106,6 +107,7 @@ def test_flash_decoding(
    block_tables,
    block_size,
    max_seq_len_in_b,
    output,
    mid_output,
    mid_output_lse,
    sm_scale=sm_scale,
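For background on why both mid_output and mid_output_lse are passed in: with flash decoding, each kv split produces a partial attention output together with its log-sum-exp, and the partials are combined in a final reduction. A rough sketch of that reduction in plain PyTorch, under the assumed (bsz, num_heads, kv_max_split_num) layout for mid_output_lse (the real kernel performs this step inside Triton):

    import torch

    def combine_kv_splits(mid_output: torch.Tensor, mid_output_lse: torch.Tensor) -> torch.Tensor:
        # mid_output:     (bsz, num_heads, kv_max_split_num, head_dim)
        # mid_output_lse: (bsz, num_heads, kv_max_split_num)  -- assumed layout
        # returns:        (bsz, num_heads, head_dim)
        lse_total = torch.logsumexp(mid_output_lse, dim=-1, keepdim=True)   # (bsz, heads, 1)
        weights = torch.exp(mid_output_lse - lse_total).unsqueeze(-1)       # (bsz, heads, splits, 1)
        return (weights * mid_output).sum(dim=2)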
@@ -184,6 +186,7 @@ def bench_kernel(
block_tables = block_tables.to(device=device)
# the maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
mid_output = torch.empty(
    size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
@@ -199,6 +202,7 @@ def bench_kernel(
    block_tables,
    block_size,
    max_seq_len_in_b,
    output,
    mid_output,
    mid_output_lse,
    sm_scale=sm_scale,