[Inference]Add Nopadding Llama Modeling (#5327)

* add nopadding llama modeling

* add nopadding_llama.py

* rm unused codes

* fix bugs in test_xine_copy.py

* fix code style
yuehuayingxueluo
2024-01-30 10:31:46 +08:00
committed by GitHub
parent c7c104cb7c
commit e8f0642f28
9 changed files with 386 additions and 49 deletions


@@ -57,7 +57,11 @@ class InferenceEngine:
         model.to(self.dtype)
         if model_policy is None:
-            model_policy = model_policy_map[self.model_config.model_type]()
+            if self.inference_config.pad_input:
+                model_type = "padding_" + self.model_config.model_type
+            else:
+                model_type = "nopadding_" + self.model_config.model_type
+            model_policy = model_policy_map[model_type]()
         pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
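
The engine now picks a model policy keyed on whether inputs are padded. A minimal sketch of that selection logic in isolation, with placeholder values standing in for the real policy registry entries:

# Sketch only: the map values are placeholders, not the actual ColossalAI policy classes.
model_policy_map = {
    "padding_llama": "PaddingLlamaPolicy (placeholder)",
    "nopadding_llama": "NoPaddingLlamaPolicy (placeholder)",
}

def select_policy(model_type: str, pad_input: bool):
    # Mirror the diff: prefix the HF model_type with "padding_" or "nopadding_".
    key = ("padding_" if pad_input else "nopadding_") + model_type
    return model_policy_map[key]

print(select_policy("llama", pad_input=False))  # -> the no-padding policy entry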
@@ -168,7 +172,9 @@ class InferenceEngine:
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
+                "input_ids"
+            ]
         if isinstance(prompts_token_ids, list):
             pass
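
The tokenizer call now forwards the engine's pad_input setting as the padding argument. A small, illustrative check of the difference, assuming a standard Hugging Face tokenizer (the checkpoint name is just an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # example checkpoint
tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers define no pad token by default

prompts = ["Hello", "A noticeably longer prompt than the first"]
ragged = tokenizer.batch_encode_plus(prompts, padding=False)["input_ids"]
padded = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]

assert len(ragged[0]) != len(ragged[1])  # no padding: per-sequence lengths differ
assert len(padded[0]) == len(padded[1])  # padding: rows padded to the batch max length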
@@ -237,7 +243,9 @@ class InferenceEngine:
             self.v_cache,
         )
-        logits = logits[:, -1, :]
+        if self.inference_config.pad_input:
+            logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
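
The last-position slice is now applied only on the padded path; the diff implies the no-padding forward already returns one logits row per sequence. A shape-only sketch of that distinction (batch size, sequence length, and vocabulary size are arbitrary):

import torch

batch, seq_len, vocab = 2, 8, 32000

# Padded path: forward yields logits for every position, so only the final
# position of each sequence is kept before sampling.
padded_logits = torch.randn(batch, seq_len, vocab)
assert padded_logits[:, -1, :].shape == (batch, vocab)

# No-padding path (as the diff implies): logits arrive already reduced to one
# row per sequence, so no slicing is applied.
nopad_logits = torch.randn(batch, vocab)
assert nopad_logits.shape == (batch, vocab)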