[inference] Optimize the usage of the mid tensor space in flash attn (#5304)

* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* adapt to get_xine_cache

* add comment

* fix CI bugs

* fix some code

* rm duplicated code

* rm duplicated code

* fix code style

* add _get_dtype in config.py
yuehuayingxueluo authored on 2024-01-26 14:00:10 +08:00, committed by GitHub
parent af8359c430
commit 4f28cb43c0
16 changed files with 199 additions and 57 deletions
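
The final bullet adds a _get_dtype helper to config.py, but its body is not part of the hunks quoted below. A minimal sketch of such a helper, assuming it only resolves a string dtype from the inference config into a torch.dtype (the mapping and the fallback are assumptions, not taken from the commit), might look like:

import torch

# Hypothetical sketch of the _get_dtype helper mentioned in the commit message;
# the real implementation in config.py may differ.
_DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}

def _get_dtype(dtype) -> torch.dtype:
    # Pass through an already-resolved torch.dtype; otherwise look up the string,
    # falling back to float32 (fallback behavior assumed).
    if isinstance(dtype, torch.dtype):
        return dtype
    return _DTYPE_MAP.get(dtype, torch.float32)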

View File

@@ -17,6 +17,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
@@ -28,6 +29,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
@@ -39,6 +41,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
     sequence.mark_running()
@@ -51,7 +54,12 @@ def check_config_and_inference():
     assert sequence.output_len == 0
     assert sequence.check_finish() == False
-    batch = BatchInfo(is_prompts=False)
+    batch = BatchInfo(
+        max_batch_size=8,
+        kv_max_split_num=16,
+        num_heads=2,
+        head_dim=128,
+    )
     batch.init_batch([sequence])
     batch.add_seqs([sequence2, sequence3])
     batch.add_seqs([sequence])
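
The test now constructs BatchInfo with sizing arguments (max_batch_size, kv_max_split_num, num_heads, head_dim). Presumably these exist so the batch can pre-allocate the flash-decoding intermediate buffers once, instead of creating them on every decoding step, which is the point of this commit. A minimal sketch of that pattern follows; everything beyond the four argument names is an assumption, not the real colossalai BatchInfo:

import torch
from dataclasses import dataclass

@dataclass
class FDWorkspaceSketch:
    # Illustrative stand-in for a batch object that owns the flash-decoding workspace.
    max_batch_size: int
    kv_max_split_num: int
    num_heads: int
    head_dim: int
    device: str = "cuda"

    def __post_init__(self):
        # Allocate the split-KV intermediates once, sized for the largest batch,
        # so every decoding step reuses the same buffers.
        self.mid_output = torch.empty(
            (self.max_batch_size, self.num_heads, self.kv_max_split_num, self.head_dim),
            dtype=torch.float32,
            device=self.device,
        )
        self.mid_output_lse = torch.empty(
            (self.max_batch_size, self.num_heads, self.kv_max_split_num),
            dtype=torch.float32,
            device=self.device,
        )

A call mirroring the updated test, FDWorkspaceSketch(max_batch_size=8, kv_max_split_num=16, num_heads=2, head_dim=128), would reserve those buffers up front.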

View File

@@ -3,8 +3,7 @@ import random
 import numpy as np
 import pytest
 import torch
-import transformers
-from transformers import AutoTokenizer, GenerationConfig
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
 import colossalai
 from colossalai.inference.config import InferenceConfig
@@ -22,8 +21,8 @@ def setup_seed(seed):
 def check_inference_engine(test_cai=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = transformers.LlamaForCausalLM(
-        transformers.LlamaConfig(
+    model = LlamaForCausalLM(
+        LlamaConfig(
             vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
     ).cuda()
@@ -81,4 +80,4 @@ def test_inference_engine():
 if __name__ == "__main__":
-    test_inference_engine()
+    test_inference_engine()
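
Reassembled from the two hunks above, the updated test builds its toy model through the directly imported classes instead of the transformers module namespace (GenerationConfig is also imported by the file but is not needed for this fragment):

from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

# Small random-weight Llama used by the test; .cuda() assumes a GPU is available.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
    )
).cuda()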

View File

@@ -20,6 +20,7 @@ def check_running_list():
         input_token_id=[1, 2, 3],
         block_size=16,
         eos_token_id=0,
+        pad_token_id=0,
         sample_params=None,
         block_table=1,
     )
@@ -56,6 +57,7 @@ def check_request_handler():
         input_token_id=[1, 2, 3, 4, 5],
         block_size=16,
         eos_token_id=0,
+        pad_token_id=0,
         sample_params=None,
         block_table=torch.tensor([-1, -1]),
     )
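
Both request-handler checks now pass pad_token_id when building a Sequence, alongside eos_token_id. The pad token is presumably what the batch uses to fill unused positions when variable-length token id lists are packed into a single tensor. A generic illustration of that padding step, not code taken from this repository:

import torch

def pad_input_ids(token_id_lists, pad_token_id):
    # Right-pad every sequence to the longest length in the batch so they stack
    # into a single (batch, max_len) tensor; pad_token_id marks the filler slots.
    max_len = max(len(ids) for ids in token_id_lists)
    padded = [ids + [pad_token_id] * (max_len - len(ids)) for ids in token_id_lists]
    return torch.tensor(padded, dtype=torch.long)

# Mirrors the values in the tests above, where pad_token_id=0.
batch_ids = pad_input_ids([[1, 2, 3], [1, 2, 3, 4, 5]], pad_token_id=0)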

View File

@@ -91,6 +91,7 @@ def test_flash_decoding(
     max_seq_len_in_b = kv_seq_lengths.max().item()
     # The maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
+    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
@@ -106,6 +107,7 @@ def test_flash_decoding(
         block_tables,
         block_size,
         max_seq_len_in_b,
+        output,
         mid_output,
         mid_output_lse,
         sm_scale=sm_scale,
@@ -184,6 +186,7 @@ def bench_kernel(
     block_tables = block_tables.to(device=device)
     # the maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
+    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
@@ -199,6 +202,7 @@ def bench_kernel(
         block_tables,
         block_size,
         max_seq_len_in_b,
+        output,
         mid_output,
         mid_output_lse,
         sm_scale=sm_scale,
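
Taken together, the four hunks above show the updated calling convention: the caller now allocates the final output buffer up front, next to the split-KV intermediates mid_output and mid_output_lse, and hands all three to flash_decoding_attention instead of letting the wrapper create fresh tensors on every call. A sketch of one decoding step under that convention follows; the leading arguments (q, k_cache, v_cache, kv_seq_lengths), the mid_output_lse shape, and the import path are not shown in the hunks and are assumptions here:

import torch

from colossalai.kernel.triton import flash_decoding_attention  # import path assumed

def decode_one_step(q, k_cache, v_cache, kv_seq_lengths, block_tables,
                    block_size, max_seq_len_in_b, num_attn_heads, head_dim, sm_scale):
    bsz = q.size(0)
    # Maximum number of KV splits, computed exactly as in the test and benchmark above.
    kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size

    # Caller-owned buffers; in the optimized path these would be allocated once
    # (e.g. per batch) and reused across steps rather than re-created here.
    output = torch.empty((bsz, 1, num_attn_heads, head_dim), dtype=q.dtype, device=q.device)
    mid_output = torch.empty(
        (bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
    )
    mid_output_lse = torch.empty(  # shape assumed; not shown in the hunks
        (bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device
    )

    flash_decoding_attention(
        q,                 # leading arguments assumed, matching the test's variable names
        k_cache,
        v_cache,
        kv_seq_lengths,
        block_tables,
        block_size,
        max_seq_len_in_b,
        output,
        mid_output,
        mid_output_lse,
        sm_scale=sm_scale,
    )
    return output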