mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-31 15:25:21 +00:00
[Fix/Inference] Fix format of input prompts and input model in inference engine (#5395)
* Fix bugs in inference_engine * fix bugs in engine.py * rm CUDA_VISIBLE_DEVICES * add request_ids in generate * fix bug in engine.py * add logger.debug for BatchBucket
This commit is contained in:
parent
2a718c8be8
commit
bc1da87366
@ -447,3 +447,6 @@ class BatchBucket:
|
||||
def fd_inter_tensor(self) -> None:
|
||||
assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided"
|
||||
return self.fd_interm_tensor
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})"
|
||||
|
@ -57,6 +57,7 @@ class InferenceEngine:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
model = model.eval()
|
||||
model = model.cuda()
|
||||
model.to(self.dtype)
|
||||
|
||||
if model_policy is None:
|
||||
@ -133,12 +134,13 @@ class InferenceEngine:
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model.cuda()
|
||||
return shard_model
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
request_ids: List[int] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> List[str]:
|
||||
@ -148,6 +150,7 @@ class InferenceEngine:
|
||||
Args:
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
return_token_ids (bool): Whether to return output token ids. Defaults to False.
|
||||
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
@ -157,7 +160,7 @@ class InferenceEngine:
|
||||
with torch.inference_mode():
|
||||
self.generation_config = generation_config
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
|
||||
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
|
||||
|
||||
output_seqs_list = []
|
||||
total_tokens_list = []
|
||||
@ -204,7 +207,7 @@ class InferenceEngine:
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
requests_id: List[int] = None,
|
||||
request_ids: List[int] = None,
|
||||
prompts: List[str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
) -> None:
|
||||
@ -212,7 +215,7 @@ class InferenceEngine:
|
||||
Add requests.
|
||||
|
||||
Args:
|
||||
requests_id (List[int], optional): The request ID. Defaults to None.
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
"""
|
||||
@ -223,6 +226,9 @@ class InferenceEngine:
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
|
||||
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
|
||||
@ -245,8 +251,14 @@ class InferenceEngine:
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if requests_id:
|
||||
request_id = requests_id[i]
|
||||
if request_ids:
|
||||
if not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
if prompts == None:
|
||||
|
@ -157,7 +157,7 @@ class Sequence:
|
||||
f"prompt={self.prompt}, "
|
||||
f"status={self.status.name}, "
|
||||
f"sample_params={self.sample_params}, "
|
||||
f"input_len={self.input_len}),"
|
||||
f"input_len={self.input_len},"
|
||||
f"output_len={self.output_len})"
|
||||
)
|
||||
|
||||
|
@ -27,7 +27,7 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
|
||||
for input_len in 128 512 1024; do
|
||||
for output_len in 128 256; do
|
||||
for bsz in 16 32 64; do
|
||||
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --model_path "/home/caidi/llama_model/" | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
|
||||
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
|
||||
done
|
||||
done
|
||||
done
|
||||
|
@ -5,8 +5,11 @@ from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.kv_cache import KVCacheManager
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
@ -83,6 +86,7 @@ def test_bucket(test_config):
|
||||
num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2
|
||||
)
|
||||
block_tables = bb.add_seqs([seq1, seq2])
|
||||
logger.debug(f"bb information: {bb}")
|
||||
assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence)
|
||||
assert torch.all(block_tables < 0), "Initialized block_tables should be negative values"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user