Add padding llama model

yuehuayingxueluo
2023-12-25 14:07:43 +08:00
committed by FrankLeeeee
parent 0e616462a7
commit 86853a37d5
5 changed files with 262 additions and 11 deletions

@@ -46,6 +46,7 @@ class InferenceEngine:
     ) -> None:
         assert inference_config, "Please provide inference_config."
         self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.inference_config = inference_config
         self.model_config = model.config
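
Llama tokenizers ship without a pad token, so batched encoding with padding fails unless one is assigned; reusing the EOS token as the pad token is the usual workaround. A minimal sketch of the effect, assuming a Hugging Face tokenizer (the checkpoint name below is purely illustrative):

from transformers import AutoTokenizer

# Illustrative checkpoint; any Llama-style tokenizer without a pad token behaves the same.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

if tokenizer.pad_token is None:
    # Reuse EOS as the padding token so batched encoding can pad shorter prompts.
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer.batch_encode_plus(
    ["a short prompt", "a noticeably longer prompt with several more words"],
    padding=True,
)
print(batch["input_ids"])  # the shorter prompt is padded with the EOS id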
@@ -169,9 +170,7 @@ class InferenceEngine:
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = []
-            for prompt in prompts:
-                prompts_token_ids.append(self.tokenizer.encode(prompt))
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
         prompts_num = len(prompts_token_ids)
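
With the pad token set in __init__, the per-prompt encoding loop can be replaced by a single batched call; batch_encode_plus returns a dict whose "input_ids" entry holds one token-id list per prompt. A small sketch of the equivalence, again with an illustrative tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
prompts = ["hello world", "colossal inference"]

# Old path: encode each prompt in its own call.
per_prompt_ids = [tokenizer.encode(p) for p in prompts]

# New path: a single batched call; "input_ids" holds one id list per prompt.
batched_ids = tokenizer.batch_encode_plus(prompts)["input_ids"]

# Without padding or truncation the two should agree token for token.
assert per_prompt_ids == batched_ids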
@@ -212,11 +211,14 @@ class InferenceEngine:
         self.logger.info("Running generation step")
         output_list = []
-        self.request_handler.schedule()
+        batch, k_cache, v_cache = self.request_handler.schedule()
-        # Uncomment if the development of RequestHandler is completed.
-        # logits = self.model(batch)
-        # self.request_handler.search_tokens(logits, self.generation_config)
+        logits = self.model(
+            batch,
+            k_cache,
+            v_cache,
+        )
+        self.request_handler.search_tokens(logits, self.generation_config)
         finished_sequences = self.request_handler.update()
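
The generation step now expects schedule() to return the running batch together with its key/value caches, feeds all three to the model, and lets the request handler pick the next tokens. A simplified sketch of that control flow (the RequestHandler interface and the model signature are placeholders mirroring the diff, not the project's final API):

def generation_step(engine):
    # Schedule returns the batch to run plus its KV caches (assumed interface).
    batch, k_cache, v_cache = engine.request_handler.schedule()

    # Forward pass reuses the caches allocated for the scheduled sequences.
    logits = engine.model(batch, k_cache, v_cache)

    # Choose next tokens according to the generation config (greedy or sampling).
    engine.request_handler.search_tokens(logits, engine.generation_config)

    # Collect sequences that finished (EOS or max length) during this step.
    return engine.request_handler.update()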