Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 10:34:41 +00:00
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)
  * support only tp
  * enable tp
* add support for bloom (#5008)
* [refactor] refactor gptq and smoothquant llama (#5012)
  * refactor gptq and smoothquant llama
  * fix import error
  * fix linear import torch-int
  * fix smoothquant llama import error
  * fix import accelerate error
  * fix bug
  * fix import smooth cuda
  * fix smoothcuda
* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)
  * merge chatglm with pp and tp
* [Refactor] remove useless inference code (#5022)
  * remove useless code
  * fix quant model
  * fix test import bug
  * mv original inference legacy
  * fix chatglm2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
  * refactor policy search and quant type controlling in inference
* [inference] update readme (#5051)
  * update readme
  * update readme
  * fix architecture
  * fix table
  * fix table
* [inference] update example (#5053)
  * update example
  * fix run.sh
  * fix rebase bug
  * fix some errors
  * update readme
  * add some features
  * update interface
  * update readme
  * update benchmark
  * add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
colossalai/legacy/inference/dynamic_batching/io_struct.py (new file, 166 lines added)
@@ -0,0 +1,166 @@
# Adapted from https://github.com/ModelTC/lightllm

from typing import Dict, List, Tuple

from .sampling_params import SamplingParams


class Req:
    """A single generation request tracked by the dynamic batching engine."""

    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
        self.request_id = request_id
        self.prompt_ids = prompt_ids
        self.input_len = len(prompt_ids)
        self.max_output_len = sample_params.max_new_tokens
        self.sample_params = sample_params
        self.output_ids = []
        self.output_metadata_list = []
        self.has_generate_finished = False
        self.aborted = False
        self.prompts = prompts

    def to_rpc_obj(self):
        return {
            "request_id": self.request_id,
            "input_id": self.prompt_ids,
            "output_len": self.max_output_len,
            "sampling_param": self.sample_params.to_dict(),
        }

    def stop_sequences_matched(self):
        # should we add stop sequences to the sample params?
        if self.sample_params.stop_sequences is not None:
            for stop_token_ids in self.sample_params.stop_sequences:
                stop_len = len(stop_token_ids)
                if (
                    stop_len > 0
                    and len(self.output_ids) >= stop_len
                    and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
                ):
                    return True
        return False

    def __repr__(self):
        return f"Req(request_id={self.request_id}, prompt_ids={self.prompt_ids})"
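
As a reading aid, a minimal usage sketch for Req follows (it is not part of the committed file). SamplingParams is defined in .sampling_params and its constructor is not shown in this diff, so the sketch uses a hypothetical stand-in that exposes only the attributes Req actually reads (max_new_tokens, stop_sequences, ignore_eos) plus to_dict(); the token ids are arbitrary.

# Hypothetical stand-in for SamplingParams, for illustration only.
class _StubSamplingParams:
    def __init__(self, max_new_tokens=16, stop_sequences=None, ignore_eos=False):
        self.max_new_tokens = max_new_tokens
        self.stop_sequences = stop_sequences  # list of stop token-id sequences, e.g. [[50256]]
        self.ignore_eos = ignore_eos

    def to_dict(self):
        return {
            "max_new_tokens": self.max_new_tokens,
            "stop_sequences": self.stop_sequences,
            "ignore_eos": self.ignore_eos,
        }


params = _StubSamplingParams(max_new_tokens=8, stop_sequences=[[50256]])
req = Req(request_id="req-0", prompt_ids=[101, 102, 103], sample_params=params, prompts="Hello, world")
assert req.input_len == 3
req.output_ids.append(50256)         # pretend the model just emitted a stop token
assert req.stop_sequences_matched()  # suffix of output_ids matches one stop sequence
print(req.to_rpc_obj())              # dict form of the request built by to_rpc_obj()
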
class Batch:
    """A batch of requests that are scheduled and decoded together."""

    def __init__(self, batch_id, reqs: List[Req]):
        self.batch_id = batch_id
        self.reqs = reqs
        self.id_to_reqs = {req.request_id: req for req in reqs}

    def input_tokens(self):
        batch_input_tokens = 0
        for req in self.reqs:
            batch_input_tokens += req.input_len
        return batch_input_tokens

    def calcu_max_tokens(self):
        tokens = 0
        for req in self.reqs:
            tokens += req.input_len + req.max_output_len
        return tokens

    def calcu_used_tokens(self):
        tokens = 0
        for req in self.reqs:
            tokens += req.input_len + len(req.output_ids)
        return tokens

    def mark_finished_req(self, eos_id, engine_max_output_len):
        # A request is finished if it hit a stop sequence, the engine-wide output cap,
        # EOS (unless ignore_eos is set), its own output cap, or was aborted.
        has_new_finish = False
        for req in self.reqs:
            if req.stop_sequences_matched():
                req.has_generate_finished = True
                has_new_finish = True
            if len(req.output_ids) >= engine_max_output_len:
                req.has_generate_finished = True
                has_new_finish = True
            if req.output_ids[-1] == eos_id and not req.sample_params.ignore_eos:
                req.has_generate_finished = True
                has_new_finish = True
            if len(req.output_ids) >= req.max_output_len or req.aborted:
                req.has_generate_finished = True
                has_new_finish = True
        return has_new_finish

    def filter_finished(self) -> List[Req]:
        """
        Filter finished requests from the batch; the finished ones are removed from 'reqs'.
        """
        # TODO: the logic of return should be defined here.
        unfinished_req = []
        finished_req = []
        for req in self.reqs:
            if not req.has_generate_finished:
                unfinished_req.append(req)
            else:
                finished_req.append(req)
        self.reqs = unfinished_req
        self.id_to_reqs = {req.request_id: req for req in self.reqs}
        return finished_req

    def is_clear(self):
        return len(self.reqs) == 0

    def merge(self, mini_batch):
        for _req in mini_batch.reqs:
            self.reqs.append(_req)
        self.id_to_reqs = {req.request_id: req for req in self.reqs}

    def __repr__(self):
        return f"Batch(batch_id={self.batch_id}, reqs={self.reqs})"

    def __len__(self):
        return len(self.reqs)
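
Below is a sketch of the kind of scheduling loop these Batch helpers support, again illustrative rather than taken from the repository: model_step is a made-up stand-in for a real decode step, and _StubSamplingParams is the helper from the previous sketch.

import random

def model_step(batch: Batch, eos_id: int) -> None:
    # Stand-in for one real decoding step: append one new token id per running request.
    for req in batch.reqs:
        req.output_ids.append(random.choice([eos_id, 42]))

eos_id = 2
reqs = [Req(f"req-{i}", [101, 102, 103], _StubSamplingParams(max_new_tokens=4)) for i in range(2)]
batch = Batch(batch_id=0, reqs=reqs)

while not batch.is_clear():
    model_step(batch, eos_id)
    if batch.mark_finished_req(eos_id, engine_max_output_len=4):
        for req in batch.filter_finished():
            print(f"{req.request_id} finished after {len(req.output_ids)} generated tokens")
    print("tokens currently held:", batch.calcu_used_tokens())
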
class BatchTokenIdOut:
    def __init__(self):
        self.reqs_infs: List[
            Tuple[str, int, Dict, bool, bool]
        ] = []  # [req_id, new_token_id, gen_metadata, finished_state, abort_state]


class BatchStrOut:
    def __init__(self):
        self.reqs_infs: List[
            Tuple[str, str, Dict, bool, bool]
        ] = []  # [req_id, token_str, gen_metadata, finished_state, abort_state]


class AbortReq:
    def __init__(self, req_id):
        self.req_id = req_id


class RequestOutput:
    """The output data of a request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
        prompt_token_ids: The token IDs of the prompt.
        outputs: The output sequences of the request.
    """

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        outputs,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs

    def __repr__(self) -> str:
        return (
            f"RequestOutput(request_id={self.request_id}, "
            f"prompt={self.prompt!r}, "
            f"prompt_token_ids={self.prompt_token_ids}, "
            f"outputs={self.outputs})"
        )
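
To round off, an illustrative sketch of the output-side containers, not part of the diff: one detokenized step result appended to a BatchStrOut, and a finished request wrapped in the RequestOutput handed back to the caller. The contents of the metadata dict and the shape of outputs are assumptions; this file does not constrain them.

str_out = BatchStrOut()
# [req_id, token_str, gen_metadata, finished_state, abort_state]
str_out.reqs_infs.append(("req-0", " world", {"id": 995}, True, False))

final = RequestOutput(
    request_id="req-0",
    prompt="Hello,",
    prompt_token_ids=[101, 102],
    outputs=["Hello, world"],  # whatever sequence container the caller uses
)
print(final)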