Add padding llama model

committed by FrankLeeeee
parent 0e616462a7
commit 86853a37d5

The commit touches InferenceEngine in three places: it registers a pad token on the tokenizer at construction time, replaces the per-prompt encoding loop with a single batched call, and wires the scheduler's output (the batch plus its KV caches) into the model forward and the token search in the generation step.
@@ -46,6 +46,7 @@ class InferenceEngine:
     ) -> None:
         assert inference_config, "Please provide inference_config."
         self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.inference_config = inference_config
         self.model_config = model.config
 
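A Llama tokenizer ships without a pad token, so any padded batch encoding fails until one is assigned; reusing the EOS token as the pad token, as the constructor now does, is the standard workaround. A minimal sketch of the failure mode and the fix, assuming a Hugging Face tokenizer (the checkpoint name is only illustrative):

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any Llama-family tokenizer behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    assert tokenizer.pad_token is None  # no pad token is defined out of the box

    # Padded batch encoding would raise a ValueError at this point. Reusing
    # EOS as the pad token, exactly as the commit does, makes it work:
    tokenizer.pad_token = tokenizer.eos_token
    batch = tokenizer.batch_encode_plus(
        ["Hello world", "A noticeably longer prompt that forces padding"],
        padding=True,
    )
    print(batch["input_ids"])  # the shorter prompt is padded with the EOS id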
@@ -169,9 +170,7 @@ class InferenceEngine:
 
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = []
-            for prompt in prompts:
-                prompts_token_ids.append(self.tokenizer.encode(prompt))
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
 
         prompts_num = len(prompts_token_ids)
 
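Dropping the loop does not change the tokens: without an explicit padding argument, batch_encode_plus returns the same per-prompt id lists that repeated encode calls would, just in one pass and wrapped in a BatchEncoding dict keyed by "input_ids". A quick equivalence check, reusing the tokenizer from the sketch above:

    prompts = ["first prompt", "a second, slightly longer prompt"]

    # Old approach: encode every prompt separately.
    looped = [tokenizer.encode(p) for p in prompts]

    # New approach: one batched call; no padding is applied unless requested.
    batched = tokenizer.batch_encode_plus(prompts)["input_ids"]

    assert looped == batched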
@@ -212,11 +211,14 @@ class InferenceEngine:
         self.logger.info("Running generation step")
 
         output_list = []
-        self.request_handler.schedule()
+        batch, k_cache, v_cache = self.request_handler.schedule()
 
-        # Uncomment if the development of RequestHandler is completed.
-        # logits = self.model(batch)
-        # self.request_handler.search_tokens(logits, self.generation_config)
+        logits = self.model(
+            batch,
+            k_cache,
+            v_cache,
+        )
+        self.request_handler.search_tokens(logits, self.generation_config)
 
         finished_sequences = self.request_handler.update()
 
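With the placeholder comments gone, the step has a complete data path: schedule() is expected to return the scheduled batch together with its key/value caches, the model forward turns those into logits, search_tokens picks the next token for each sequence under the generation config, and update() retires finished sequences. A hypothetical skeleton of that contract (the names mirror the diff; everything else is an assumption, not ColossalAI's actual API):

    # Hypothetical skeleton showing how the pieces compose in one step.
    def step(self):
        self.logger.info("Running generation step")

        # Scheduler selects runnable sequences and hands back their KV caches.
        batch, k_cache, v_cache = self.request_handler.schedule()

        # Forward pass consumes the batch; assumed to read/write the caches in place.
        logits = self.model(batch, k_cache, v_cache)

        # Choose the next token for every sequence under the generation config.
        self.request_handler.search_tokens(logits, self.generation_config)

        # Retire sequences that hit EOS or the length limit.
        finished_sequences = self.request_handler.update()
        return finished_sequences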