Mirror of https://github.com/hpcaitech/ColossalAI.git
[inference] Optimize the usage of the mid-tensor space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adaptation to get_xine_cache
* add comment
* fix CI bugs
* fix some code
* rm duplicated code
* add _get_dtype in config.py
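Per the title, the point of the change is to stop re-allocating the intermediate ("mid") tensors used by flash-decoding attention on every step: BatchInfo now receives max_batch_size, kv_max_split_num, num_heads, and head_dim (visible in the test diff below), which are exactly the dimensions needed to size those buffers once up front. A minimal sketch of the idea, assuming hypothetical names FDIntermTensors, mid_output, and mid_output_lse and illustrative shapes rather than the actual ColossalAI implementation:

    import torch

    class FDIntermTensors:
        """Sketch: pre-allocated intermediate buffers for flash-decoding.

        Flash-decoding splits each sequence's KV cache into at most
        kv_max_split_num chunks, computes partial attention per chunk, and
        then reduces the partials. Allocating the partial outputs and their
        log-sum-exp terms once, at maximum size, removes a per-step allocation.
        """

        def __init__(self, max_batch_size: int, kv_max_split_num: int,
                     num_heads: int, head_dim: int,
                     dtype: torch.dtype = torch.float32, device: str = "cpu"):
            # Partial attention output for each KV split.
            self.mid_output = torch.empty(
                max_batch_size, num_heads, kv_max_split_num, head_dim,
                dtype=dtype, device=device,  # use the model's device in practice
            )
            # Log-sum-exp of the scores in each split, used to combine partials.
            self.mid_output_lse = torch.empty(
                max_batch_size, num_heads, kv_max_split_num,
                dtype=dtype, device=device,
            )

    # Sized from the same batch parameters the updated BatchInfo takes:
    tensors = FDIntermTensors(max_batch_size=8, kv_max_split_num=16,
                              num_heads=2, head_dim=128)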
@@ -17,6 +17,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
 
@@ -28,6 +29,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
 
@@ -39,6 +41,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
     sequence.mark_running()
@@ -51,7 +54,12 @@ def check_config_and_inference():
     assert sequence.output_len == 0
     assert sequence.check_finish() == False
 
-    batch = BatchInfo(is_prompts=False)
+    batch = BatchInfo(
+        max_batch_size=8,
+        kv_max_split_num=16,
+        num_heads=2,
+        head_dim=128,
+    )
     batch.init_batch([sequence])
     batch.add_seqs([sequence2, sequence3])
     batch.add_seqs([sequence])
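The last item in the commit message, "add _get_dtype in config.py", is not part of the hunks shown above. A plausible shape for such a helper, offered purely as a hypothetical sketch (the actual ColossalAI code may differ):

    import torch

    # Hypothetical dtype-normalization helper; the mapping and error handling
    # are assumptions, not the actual _get_dtype from config.py.
    _DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}

    def _get_dtype(dtype):
        """Normalize a user-facing dtype string to a torch.dtype."""
        if isinstance(dtype, torch.dtype):
            return dtype
        if dtype not in _DTYPE_MAP:
            raise ValueError(f"unsupported dtype {dtype!r}, expected one of {list(_DTYPE_MAP)}")
        return _DTYPE_MAP[dtype]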