[Inference] Fused the gate and up proj in MLP, and optimized the autograd process. (#5365)
* fused the gate and up proj in mlp (see the sketch below)
* fix code styles
* opt auto_grad
* rollback test_inference_engine.py
* modifications based on the review feedback
* fix bugs in flash attn
* Change reshape to view
* fix test_rmsnorm_triton.py
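For context, the first item refers to replacing the two separate gate and up projections of a SwiGLU-style MLP with a single, twice-as-wide projection whose output is split in two, so each token pays for one GEMM instead of two. Below is a minimal sketch of that idea; the module and parameter names are hypothetical and not taken from the ColossalAI source.

    import torch
    from torch import nn

    class FusedGateUpMLP(nn.Module):
        """Illustrative SwiGLU MLP with the gate and up projections fused into one linear layer."""

        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            # A single weight of shape (2 * intermediate_size, hidden_size) replaces the
            # separate gate_proj and up_proj weights, so one matmul produces both outputs.
            self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
            self.act_fn = nn.SiLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Split the fused output back into its gate and up halves.
            gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
            return self.down_proj(self.act_fn(gate) * up)

An existing checkpoint could be loaded into such a module by concatenating the pretrained gate and up weights along the output dimension; whether ColossalAI does exactly this is not shown in the excerpt below.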
@@ -115,8 +115,9 @@ class InferenceEngine:
             tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
 
         Returns:
-            nn.Module: _description_
+            nn.Module: The model optimized by Shardformer.
         """
+
         shardconfig = ShardConfig(
             tensor_parallel_process_group=tp_group,
             pipeline_stage_manager=stage_manager,
@@ -149,25 +150,25 @@ class InferenceEngine:
 
         Returns:
             List[str]: Inference result returned by one generation.
         """
-        self.generation_config = generation_config
-        if prompts is not None or prompts_token_ids is not None:
-            self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
-
-        output_seqs_list = []
-        output_tokens_list = []
-
-        while self.request_handler.check_unfinished_seqs():
-            output_seqs_list += self.step()
-
-        output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
-
-        for seq in output_seqs_list:
-            output_tokens_list.append(seq.input_token_id + seq.output_token_id)
-
-        output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
-
-        return output_str
+        with torch.inference_mode():
+            self.generation_config = generation_config
+            if prompts is not None or prompts_token_ids is not None:
+                self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
+
+            output_seqs_list = []
+            output_tokens_list = []
+
+            while self.request_handler.check_unfinished_seqs():
+                output_seqs_list += self.step()
+
+            output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
+
+            for seq in output_seqs_list:
+                output_tokens_list.append(seq.input_token_id + seq.output_token_id)
+
+            output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
+
+            return output_str
 
     def add_request(
         self,
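The second hunk is the "opt auto_grad" item: the body of generate() is wrapped in torch.inference_mode(), so no autograd graph is built for any tensor created during decoding. A small self-contained sketch of the effect; the toy model and input are made up for illustration.

    import torch
    from torch import nn

    model = nn.Linear(16, 16)     # hypothetical stand-in for the inference model
    x = torch.randn(2, 16)

    # Without inference_mode, the output is hooked into the autograd graph.
    y = model(x)
    print(y.requires_grad)        # True

    # Inside inference_mode, autograd tracking and version-counter bookkeeping
    # are skipped entirely, which removes per-op overhead during generation.
    with torch.inference_mode():
        y = model(x)
    print(y.requires_grad)        # False
    print(y.is_inference())       # True

Compared with torch.no_grad(), inference_mode is stricter: tensors created inside it can never be used in autograd later, which is acceptable here because generation never backpropagates.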