Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 21:40:02 +00:00)
[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)
* [fix] GQA calling of flash decoding triton
* fix kv cache alloc shape
* fix rotary triton - GQA
* fix sequence max length assigning
* Sequence max length logic
* fix scheduling and spec-dec
* skip without import error
* fix pytest - skip without ImportError

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
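These fixes target grouped-query attention (GQA), which Llama 3 uses: the model has fewer key/value heads than query heads, so the KV cache allocation and the Triton flash-decoding and rotary-embedding kernels must be sized and indexed by the KV head count rather than the full attention head count. Below is a minimal sketch of a GQA-aware KV-cache allocation, with hypothetical names and shapes (not ColossalAI's actual allocator):

```python
import torch

# Hedged illustration only; ColossalAI's real cache layout and argument names differ.
def alloc_kv_cache(num_blocks: int, block_size: int, num_kv_heads: int, head_dim: int,
                   dtype=torch.float16):
    # Llama-3-8B has 32 query heads but only 8 KV heads, so the cache is sized
    # by num_kv_heads (8), not by the query head count (32).
    k_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, dtype=dtype)
    v_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, dtype=dtype)
    return k_cache, v_cache

# Inside the flash-decoding kernel, each KV head is then shared by
# num_query_heads // num_kv_heads query heads (4 per group for Llama-3-8B).
```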
@@ -518,7 +518,13 @@ class InferenceEngine:
         """
         with torch.inference_mode():
             if prompts is not None or prompts_token_ids is not None:
-                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
+                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+                self.add_request(
+                    request_ids=request_ids,
+                    prompts=prompts,
+                    prompts_token_ids=prompts_token_ids,
+                    **gen_config_dict,
+                )
 
             output_seqs_list = []
             total_tokens_list = []
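In generate(), the call into add_request now also forwards the generation settings: the GenerationConfig is flattened with to_dict() and splatted as keyword arguments, so per-call limits such as max_new_tokens or max_length reach request creation. A hedged usage sketch (assumes an already-built InferenceEngine named `engine`; the full generate() signature may carry more parameters than this hunk shows):

```python
from transformers import GenerationConfig

# generation_config.to_dict() becomes gen_config_dict above and is splatted into
# add_request(**gen_config_dict), so "max_new_tokens" arrives there as a kwarg.
outputs = engine.generate(
    prompts=["What is grouped-query attention?"],
    generation_config=GenerationConfig(max_new_tokens=64),
)
```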
@@ -573,6 +579,7 @@ class InferenceEngine:
         request_ids: List[int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        **kwargs,
     ) -> None:
         """
         Add requests.
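add_request itself gains a **kwargs parameter, which is where the splatted generation settings land; the next hunk reads them back with kwargs.get(...). Calling it directly with a per-request budget would look roughly like this (hedged; the public signature may take more arguments than the three shown above):

```python
# Same hypothetical `engine` as above; max_new_tokens is consumed inside
# add_request via kwargs.get("max_new_tokens", None).
engine.add_request(
    request_ids=[0],
    prompts=["Summarize GQA in one sentence."],
    max_new_tokens=32,
)
```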
@@ -629,6 +636,13 @@ class InferenceEngine:
             else:
                 prompt = prompts[i]
 
+            max_length = kwargs.get("max_length", None)
+            max_new_tokens = kwargs.get("max_new_tokens", None)
+            if max_length is None and max_new_tokens is None:
+                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
+            elif max_length is not None:
+                max_new_tokens = max_length - len(prompts_token_ids[i])
+
             sequence = Sequence(
                 request_id,
                 prompt,
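The per-request output budget is resolved with a simple precedence: if max_length is given it wins and the budget is max_length minus the prompt length; otherwise an explicit max_new_tokens is used as-is; with neither, the engine falls back to generation_config.max_new_tokens or inference_config.max_output_len. The result is then handed to the Sequence as its max_output_len in the last hunk below. Restated as a standalone sketch (hedged; the real code works on self and kwargs exactly as in the hunk above):

```python
from typing import Optional

def resolve_max_new_tokens(prompt_len: int,
                           max_length: Optional[int] = None,
                           max_new_tokens: Optional[int] = None,
                           config_max_new_tokens: Optional[int] = None,
                           default_max_output_len: int = 256) -> int:
    """Mirrors the precedence in the hunk above as a free function."""
    if max_length is None and max_new_tokens is None:
        # Neither given: fall back to the generation config, then the engine default.
        return config_max_new_tokens or default_max_output_len
    if max_length is not None:
        # max_length caps prompt + generated tokens, so subtract the prompt length.
        return max_length - prompt_len
    return max_new_tokens

# Example: a 100-token prompt with max_length=2048 leaves 1948 new tokens.
assert resolve_max_new_tokens(prompt_len=100, max_length=2048) == 1948
```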
@@ -637,7 +651,7 @@ class InferenceEngine:
                 None,
                 self.tokenizer.eos_token_id,
                 self.tokenizer.pad_token_id,
-                self.inference_config.max_output_len,
+                max_output_len=max_new_tokens,
             )
             self.request_handler.add_sequence(sequence)
 