[inference] Optimize the usage of the mid-tensor space in flash attn (#5304)

* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adaptation to get_xine_cache

* add comment

* fix ci bugs

* fix some code

* rm duplicated code

* rm duplicated code

* fix code style

* add _get_dtype in config.py
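The `_get_dtype` helper mentioned in the last point is not shown in this excerpt. A minimal sketch of what such a dtype-normalization helper could look like (the mapping and signature below are assumptions, not the repository's actual code):

```python
import torch

# Hypothetical sketch of a dtype-resolution helper for an inference config;
# the real _get_dtype in config.py may accept different names or defaults.
_DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def _get_dtype(dtype) -> torch.dtype:
    """Normalize a user-supplied dtype (string or torch.dtype) to a torch.dtype."""
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str) and dtype in _DTYPE_MAP:
        return _DTYPE_MAP[dtype]
    raise ValueError(f"Unsupported dtype {dtype!r}; expected a torch.dtype or one of {list(_DTYPE_MAP)}")
```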
Author: yuehuayingxueluo
Date: 2024-01-26 14:00:10 +08:00
Committed by: GitHub
Parent: af8359c430
Commit: 4f28cb43c0
16 changed files with 199 additions and 57 deletions
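Judging from the commit message ("opt tmp tensor", "fix None logic for output tensor"), the core change is to pre-allocate the intermediate and output buffers used by the flash-attention path and reuse them across decoding steps instead of allocating fresh tensors on every call. A rough sketch of that pattern, with hypothetical names and shapes (not the repository's actual kernel API):

```python
import torch
import torch.nn.functional as F


class AttentionBuffers:
    """Hypothetical cache of pre-allocated intermediate tensors for decode-time attention."""

    def __init__(self, max_batch_size: int, num_heads: int, head_dim: int,
                 dtype: torch.dtype, device: torch.device):
        # Allocated once and reused on every decoding step, instead of torch.empty per call.
        self.output = torch.empty(max_batch_size, num_heads, head_dim, dtype=dtype, device=device)

    def get_output(self, batch_size: int) -> torch.Tensor:
        # Hand out a slice of the persistent buffer sized for the current batch.
        return self.output[:batch_size]


def decode_attention(q, k, v, buffers: AttentionBuffers) -> torch.Tensor:
    # q: (batch, num_heads, 1, head_dim); k, v: (batch, num_heads, kv_len, head_dim)
    out = buffers.get_output(q.size(0))
    # A real implementation would pass `out` to the flash-attention kernel as the
    # destination tensor; plain SDPA is used here only to keep the sketch runnable.
    out.copy_(F.scaled_dot_product_attention(q, k, v).squeeze(2))
    return out
```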


@@ -20,6 +20,7 @@ def check_running_list():
 input_token_id=[1, 2, 3],
 block_size=16,
 eos_token_id=0,
+pad_token_id=0,
 sample_params=None,
 block_table=1,
 )
@@ -56,6 +57,7 @@ def check_request_handler():
 input_token_id=[1, 2, 3, 4, 5],
 block_size=16,
 eos_token_id=0,
+pad_token_id=0,
 sample_params=None,
 block_table=torch.tensor([-1, -1]),
 )
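Both hunks add the now-required `pad_token_id` argument to the `Sequence` objects built in these tests. To illustrate the shape of the call, here is a self-contained stand-in covering only the fields visible in the hunks (this dataclass is a simplified sketch, not the repository's actual Sequence definition):

```python
from dataclasses import dataclass
from typing import Any, List, Union

import torch


@dataclass
class Sequence:
    # Simplified stand-in mirroring only the arguments visible in the diff above.
    input_token_id: List[int]
    block_size: int
    eos_token_id: int
    pad_token_id: int  # new required field introduced by this commit
    sample_params: Any
    block_table: Union[int, torch.Tensor]


seq = Sequence(
    input_token_id=[1, 2, 3, 4, 5],
    block_size=16,
    eos_token_id=0,
    pad_token_id=0,
    sample_params=None,
    block_table=torch.tensor([-1, -1]),
)
```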