[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)

* [fix] GQA call into the flash-decoding Triton kernel

* fix KV cache allocation shape

* fix rotary Triton kernel for GQA

* fix sequence max-length assignment

* sequence max-length logic

* fix scheduling and spec-dec

* skip tests when the import fails

* fix pytest: skip on ImportError instead of failing

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Yuanheng Zhao
Date: 2024-04-23 13:09:55 +08:00
Committed by: GitHub
Parent: ccf72797e3
Commit: 5d4c1fe8f5
9 changed files with 183 additions and 194 deletions


@@ -314,10 +314,11 @@ class RequestHandler:
     def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
         for seq in batch.seqs_li:
-            if (
-                seq.output_token_id[-1] == generation_config.eos_token_id
-                or seq.output_len >= generation_config.max_length
-            ):
+            max_length = generation_config.max_length
+            max_new_tokens = generation_config.max_new_tokens
+            if max_length is not None:
+                max_new_tokens = max_length - seq.input_len
+            if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
                 seq.mark_finished()

     def check_unfinished_seqs(self) -> bool:
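
The diff above changes the stopping rule to distinguish `max_length` (a cap on prompt plus generated tokens, the usual HuggingFace `GenerationConfig` semantics) from `max_new_tokens` (a cap on generated tokens only): when `max_length` is set, the effective generation budget becomes `max_length - input_len`. Below is a minimal, self-contained sketch of that rule; the `Sequence` and `GenerationConfig` stand-ins are simplified assumptions, not the project's real classes.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class GenerationConfig:
    eos_token_id: int = 2
    max_length: Optional[int] = None  # cap on prompt + generated tokens
    max_new_tokens: int = 16          # cap on generated tokens only


@dataclass
class Sequence:
    input_len: int
    output_token_id: List[int] = field(default_factory=list)
    finished: bool = False

    @property
    def output_len(self) -> int:
        return len(self.output_token_id)

    def mark_finished(self) -> None:
        self.finished = True


def update_finished(seq: Sequence, cfg: GenerationConfig) -> None:
    # If max_length is set, it bounds prompt + generation, so the budget for
    # new tokens is max_length - input_len; otherwise use max_new_tokens.
    # The replaced code compared output_len (generated tokens) directly
    # against max_length (total tokens), which terminated sequences early.
    max_new_tokens = cfg.max_new_tokens
    if cfg.max_length is not None:
        max_new_tokens = cfg.max_length - seq.input_len
    if seq.output_token_id[-1] == cfg.eos_token_id or seq.output_len >= max_new_tokens:
        seq.mark_finished()


seq = Sequence(input_len=8, output_token_id=[5, 7, 9])
update_finished(seq, GenerationConfig(max_length=11))
assert seq.finished  # 3 generated tokens >= 11 - 8
```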
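
The "skip on ImportError" bullets in the commit message refer to guarding Triton-dependent tests so that collection skips them instead of erroring when Triton is absent. A minimal sketch of the standard pytest pattern follows; the test name is illustrative, not the repository's actual test.

```python
import pytest

try:
    import triton  # noqa: F401

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


@pytest.mark.skipif(not HAS_TRITON, reason="requires triton")
def test_flash_decoding_kernel():
    # Body elided; the real tests live in the repository's test suite.
    assert HAS_TRITON
```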