Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference]Add Nopadding Llama Modeling (#5327)
* add nopadding llama modeling
* add nopadding_llama.py
* rm unused codes
* fix bugs in test_xine_copy.py
* fix code style
@@ -57,7 +57,11 @@ class InferenceEngine:
         model.to(self.dtype)
 
         if model_policy is None:
-            model_policy = model_policy_map[self.model_config.model_type]()
+            if self.inference_config.pad_input:
+                model_type = "padding_" + self.model_config.model_type
+            else:
+                model_type = "nopadding_" + self.model_config.model_type
+            model_policy = model_policy_map[model_type]()
 
         pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
 
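The branch above decides which policy is fetched from the registry: padded batches resolve to a "padding_" key, while the new no-padding path (added in nopadding_llama.py) resolves to a "nopadding_" key. Below is a minimal sketch of that key selection, using placeholder policy classes and a stand-in model_policy_map dict rather than ColossalAI's real registry.

# Sketch only: placeholder policy classes and a stand-in registry dict;
# in ColossalAI the real model_policy_map maps policy names to policy classes.
class PaddingLlamaPolicy: ...
class NoPaddingLlamaPolicy: ...

model_policy_map = {
    "padding_llama": PaddingLlamaPolicy,
    "nopadding_llama": NoPaddingLlamaPolicy,
}

def select_policy(model_type: str, pad_input: bool):
    # Prefix the model type exactly as the engine does in the hunk above.
    key = ("padding_" if pad_input else "nopadding_") + model_type
    return model_policy_map[key]()

assert isinstance(select_policy("llama", pad_input=True), PaddingLlamaPolicy)
assert isinstance(select_policy("llama", pad_input=False), NoPaddingLlamaPolicy)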
@@ -168,7 +172,9 @@ class InferenceEngine:
 
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
+                "input_ids"
+            ]
 
         if isinstance(prompts_token_ids, list):
             pass
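Tokenization now follows the same switch: with pad_input enabled the prompts are padded to a common length, otherwise batch_encode_plus returns ragged per-prompt token-id lists that the no-padding path consumes directly. A small sketch of that behavioural difference with a Hugging Face tokenizer follows; the tokenizer id is only an example.

# Sketch, assuming a Hugging Face tokenizer; the model id is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

prompts = ["hello", "a much longer prompt with several more tokens"]

padded = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
ragged = tokenizer.batch_encode_plus(prompts, padding=False)["input_ids"]

assert len(set(map(len, padded))) == 1   # every row padded to the longest prompt
assert len(ragged[0]) != len(ragged[1])  # rows keep their own lengths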
@@ -237,7 +243,9 @@ class InferenceEngine:
             self.v_cache,
         )
 
-        logits = logits[:, -1, :]
+        if self.inference_config.pad_input:
+            logits = logits[:, -1, :]
 
         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
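Finally, the last-token slice is applied only for padded input: with padding, the model returns logits for every position, so the last position is selected before token search, whereas the no-padding modeling is expected to hand back one logits row per sequence already. A toy illustration of the two shapes, with made-up sizes:

# Shape sketch only; sizes are invented for the example.
import torch

batch_size, seq_len, vocab_size = 2, 5, 32000

# Padded path: logits cover every position, so sample from the last one.
padded_logits = torch.randn(batch_size, seq_len, vocab_size)
last_token_logits = padded_logits[:, -1, :]
assert last_token_logits.shape == (batch_size, vocab_size)

# No-padding path: one logits row per sequence, passed on without slicing.
nopad_logits = torch.randn(batch_size, vocab_size)
assert nopad_logits.shape == last_token_logits.shape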