mirror of https://github.com/hpcaitech/ColossalAI.git

Fixed a bug in the inference frame

commit 62fd08ee44, committed by FrankLeeeee
parent 86853a37d5
@@ -49,6 +49,7 @@ class InferenceEngine:
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
+
         if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
             self.dtype = torch.float32
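
The constructor resolves the configured dtype into a torch dtype; only the fp32 branch is visible in this hunk. A minimal sketch of that resolution, assuming the config accepts either a short string alias or a torch dtype (the fp32 fallback for unknown values is an assumption, not necessarily the engine's policy):

    import torch

    # Accept either a string alias or a torch dtype, as the fp32 branch
    # above suggests; the keys beyond fp32 are assumed for illustration.
    _DTYPE_MAP = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        torch.float32: torch.float32,
        torch.float16: torch.float16,
        torch.bfloat16: torch.bfloat16,
    }

    def resolve_dtype(dtype) -> torch.dtype:
        # Default to fp32 for unrecognized values (illustrative choice).
        return _DTYPE_MAP.get(dtype, torch.float32)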
@@ -76,6 +77,7 @@ class InferenceEngine:
         self.logger = get_dist_logger(__name__)

         self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
         self.counter = count()

     def _verify_config(self) -> None:
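
This added line is the heart of the fix: the request handler owns the KV cache, and the engine now takes stable references once at construction instead of receiving fresh ones from schedule() on every step (compare the last hunk below). A toy sketch of that ownership pattern, with purely illustrative shapes and names:

    from typing import List, Tuple
    import torch

    class ToyRequestHandler:
        # Owns per-layer key/value buffers and hands out references once.
        def __init__(self, num_layers: int, num_blocks: int, block_size: int, head_dim: int):
            shape = (num_blocks, block_size, head_dim)
            self._k = [torch.zeros(shape) for _ in range(num_layers)]
            self._v = [torch.zeros(shape) for _ in range(num_layers)]

        def get_kvcache(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
            # Callers keep these references for the lifetime of the engine.
            return self._k, self._v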
@@ -170,7 +172,11 @@ class InferenceEngine:

         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
+
+        assert (
+            len(prompts_token_ids[0]) < self.inference_config.max_input_len
+        ), "The length of input prompts must be less than max_input_len."

         prompts_num = len(prompts_token_ids)

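Without padding=True, batch_encode_plus returns ragged lists of token ids, so the new length assertion on prompts_token_ids[0] (and any later stacking into a batch tensor) is only meaningful once every row is padded to the longest prompt. A small runnable check of the difference; "gpt2" is just a stand-in model:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # mirrors the engine's setup above

    prompts = ["hello", "a much longer prompt than the first one"]
    ragged = tokenizer.batch_encode_plus(prompts)["input_ids"]
    padded = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]

    assert len(set(map(len, ragged))) > 1   # rows have different lengths
    assert len(set(map(len, padded))) == 1  # all rows padded to the longest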
@@ -183,13 +189,14 @@ class InferenceEngine:
                 prompt = None
             else:
                 prompt = prompts[i]
+            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
                 prompts_token_ids[i],
                 block_size,
                 None,
-                None,
+                block_table,
                 self.tokenizer.eos_token_id,
                 self.inference_config.max_output_len,
             )
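
Each sequence now gets its own block table: a tensor of max_seq_len block indices initialized to -1, where -1 reads naturally as "no physical cache block mapped yet", in the style of paged KV caches. A minimal sketch of that convention on CPU (the helper is illustrative, not part of the engine):

    import torch

    max_seq_len = 16
    block_table = torch.full([max_seq_len], -1)

    def map_block(table: torch.Tensor, logical_idx: int, physical_idx: int) -> None:
        # Record that logical block `logical_idx` now lives in physical
        # block `physical_idx`; entries still at -1 are unmapped.
        table[logical_idx] = physical_idx

    map_block(block_table, 0, 7)
    assert block_table[0].item() == 7 and block_table[1].item() == -1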
@@ -211,14 +218,15 @@ class InferenceEngine:
         self.logger.info("Running generation step")

         output_list = []
-        batch, k_cache, v_cache = self.request_handler.schedule()
+        batch = self.request_handler.schedule()

         logits = self.model(
             batch,
-            k_cache,
-            v_cache,
+            self.k_cahce,
+            self.v_cache,
         )
-        self.request_handler.search_tokens(logits, self.generation_config)
+
+        self.request_handler.search_tokens(self.generation_config, logits)

         finished_sequences = self.request_handler.update()
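
Taken together, one generation step now reads: schedule a batch, run the model against the engine-held KV cache, sample the next tokens via search_tokens (with the corrected argument order, generation config before logits), and collect finished sequences. Condensed into a sketch (names follow the diff, including the k_cahce spelling; this is not the actual implementation):

    def generation_step(engine):
        # Scheduler picks the batch; the KV cache no longer travels with it.
        batch = engine.request_handler.schedule()
        logits = engine.model(batch, engine.k_cahce, engine.v_cache)
        # Sampling: config first, logits second, per the fixed call above.
        engine.request_handler.search_tokens(engine.generation_config, logits)
        # Returns the sequences that finished this step.
        return engine.request_handler.update()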