[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)

* [fix] fix GQA call into the flash-decoding Triton kernel

* fix KV cache allocation shape for GQA (sketched below)

* fix rotary embedding Triton kernel for GQA

* fix sequence max length assignment

* adjust sequence max length logic

* fix scheduling and spec-dec

* skip tests instead of raising an import error

* fix pytest: skip on ImportError instead of failing (see the skip pattern after the diff)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Yuanheng Zhao
Date: 2024-04-23 13:09:55 +08:00
Committed by: GitHub
Parent: ccf72797e3
Commit: 5d4c1fe8f5
9 changed files with 183 additions and 194 deletions
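For context on the KV cache bullet above: with grouped-query attention (used by Llama 3), several query heads share one key/value head, so the cache has to be sized by the number of KV heads rather than the number of query heads. A minimal sketch of the idea, with illustrative names and shapes rather than ColossalAI's actual allocator:

```python
import torch


def alloc_kv_cache(num_blocks: int, block_size: int, num_kv_heads: int, head_dim: int,
                   dtype: torch.dtype = torch.float16):
    # Size the head dimension by num_kv_heads, not num_attention_heads:
    # under GQA a whole group of query heads attends to the same KV head.
    shape = (num_blocks, num_kv_heads, block_size, head_dim)
    return torch.zeros(shape, dtype=dtype), torch.zeros(shape, dtype=dtype)


# Llama-3-8B-like setup: 32 query heads grouped onto 8 KV heads.
k_cache, v_cache = alloc_kv_cache(num_blocks=16, block_size=32, num_kv_heads=8, head_dim=128)
assert k_cache.shape == (16, 8, 32, 128)
```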


@@ -133,8 +133,9 @@ def check_spec_dec(num_layers, max_length):
     assert not engine.use_spec_dec
     assert engine.drafter is None and engine.drafter_model is None
 
+    max_new_tokens = max_length - dummy_inputs.size(1)
     assert len(out) == 1
-    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens
 
     # test GLIDE model
     glide_config = GlideLlamaConfig(
@@ -152,7 +153,7 @@ def check_spec_dec(num_layers, max_length):
     engine.clear_spec_dec()
 
     assert len(out) == 1
-    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens
 
 
 def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
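The two hunks above tighten what the test expects back from the engine: the returned token ids should contain only the newly generated tokens, whose count is the length budget minus the prompt length (the `max_new_tokens = max_length - dummy_inputs.size(1)` line). A tiny worked example of the relation the assertions now encode, with made-up values and `prompt_len` standing in for `dummy_inputs.size(1)`:

```python
prompt_len = 8                             # dummy_inputs.size(1) in the test
max_length = 64                            # total budget: prompt plus generated tokens
max_new_tokens = max_length - prompt_len   # tokens the engine may actually generate

# Stand-in for out_token_ids[0]: the assertion checks its length against
# max_new_tokens (56 here), not against max_length (64).
generated_ids = list(range(max_new_tokens))
assert len(generated_ids) == max_new_tokens
```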
@@ -186,7 +187,7 @@ def test_tp_engine(prompt_template, do_sample):
 @parameterize("num_layers", [1])
-@parameterize("max_length", [100])
+@parameterize("max_length", [64])
 def test_spec_dec(num_layers, max_length):
     spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
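The "skip on ImportError" bullets refer to letting Triton-dependent tests skip cleanly when the kernels cannot be imported, instead of failing at collection time. A common pytest pattern for this, shown as a sketch; the guard variable and test name are hypothetical, not necessarily what the PR uses:

```python
import pytest

try:
    import triton  # noqa: F401

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


@pytest.mark.skipif(not HAS_TRITON, reason="requires the Triton library")
def test_flash_decoding_gqa():
    # Exercise the GQA flash-decoding kernel here; body omitted in this sketch.
    ...
```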