[Inference] Fused the gate and up proj in MLP, and optimized the autograd process. (#5365)

* fused the gate and up proj in the MLP (see the MLP-fusion sketch after this list)

* fix code styles

* opt auto_grad

* rollback test_inference_engine.py

* modifications based on the review feedback.

* fix bugs in flash attn

* Change reshape to view (see the view-vs-reshape sketch after this list)

* fix test_rmsnorm_triton.py
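
For readers landing on this commit: the fusion in the first bullet merges the two input projections of a LLaMA-style gated MLP so a single GEMM feeds both the gate and up branches. Below is a minimal sketch of the idea, assuming the standard gate_proj/up_proj/down_proj naming; FusedGateUpMLP and from_unfused are illustrative names for this note, not APIs from this PR.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedGateUpMLP(nn.Module):
    """LLaMA-style MLP with gate_proj and up_proj fused into one matmul."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # One weight of shape (2 * intermediate_size, hidden_size) replaces
        # the two separate gate/up projections, halving the number of GEMMs.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    @classmethod
    def from_unfused(cls, gate_proj: nn.Linear, up_proj: nn.Linear, down_proj: nn.Linear):
        # Stack the two weights so a single matmul produces both halves.
        fused = cls(gate_proj.in_features, gate_proj.out_features)
        fused.gate_up_proj.weight.data = torch.cat(
            [gate_proj.weight.data, up_proj.weight.data], dim=0
        )
        fused.down_proj = down_proj
        return fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the single projection back into the gate and up halves.
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

The output is identical to running gate_proj and up_proj separately; the benefit is one larger matmul instead of two smaller kernel launches.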
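On the "Change reshape to view" bullet: Tensor.view() never copies data, it reinterprets the existing storage and raises a RuntimeError on non-contiguous tensors, while reshape() may fall back to a silent copy. A small sketch of the difference:

import torch

x = torch.randn(4, 8, 16)

# view() is zero-copy and fails loudly if the layout does not permit it,
# so using it where shapes are known to be contiguous guarantees no
# hidden extra copy on the hot path.
y = x.view(4, 128)        # zero-copy reinterpretation

t = x.transpose(1, 2)     # non-contiguous
z = t.reshape(4, 128)     # works, but materializes a copy here
# t.view(4, 128)          # would raise: view size is not compatible ...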
Author: yuehuayingxueluo
Date: 2024-02-06 19:38:25 +08:00
Committed by: GitHub
Parent: 1dedb57747
Commit: 35382a7fbf

10 changed files with 484 additions and 50 deletions


@@ -115,8 +115,9 @@ class InferenceEngine:
             tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

         Returns:
-            nn.Module: _description_
+            nn.Module: The model optimized by Shardformer.
         """
         shardconfig = ShardConfig(
             tensor_parallel_process_group=tp_group,
             pipeline_stage_manager=stage_manager,
@@ -149,25 +150,25 @@ class InferenceEngine:
         Returns:
             List[str]: Inference result returned by one generation.
         """
+        with torch.inference_mode():
-        self.generation_config = generation_config
-        if prompts is not None or prompts_token_ids is not None:
-            self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
+            self.generation_config = generation_config
+            if prompts is not None or prompts_token_ids is not None:
+                self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

-        output_seqs_list = []
-        output_tokens_list = []
+            output_seqs_list = []
+            output_tokens_list = []

-        while self.request_handler.check_unfinished_seqs():
-            output_seqs_list += self.step()
+            while self.request_handler.check_unfinished_seqs():
+                output_seqs_list += self.step()

-        output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
+            output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

-        for seq in output_seqs_list:
-            output_tokens_list.append(seq.input_token_id + seq.output_token_id)
+            for seq in output_seqs_list:
+                output_tokens_list.append(seq.input_token_id + seq.output_token_id)

-        output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
+            output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)

-        return output_str
+            return output_str

     def add_request(
         self,
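
The wrap in torch.inference_mode() above appears to be the "optimized the autograd process" part of the title: unlike torch.no_grad(), inference mode also disables version-counter and view tracking on tensors created inside it, shaving per-op autograd overhead during decoding. A minimal sketch of the observable behavior:

import torch

model = torch.nn.Linear(16, 16)
x = torch.randn(2, 16)

# Tensors produced inside inference_mode() are "inference tensors":
# they skip gradient tracking and version-counter bookkeeping entirely,
# which no_grad() alone does not do.
with torch.inference_mode():
    y = model(x)

print(y.requires_grad)        # False
print(torch.is_inference(y))  # True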