[inference] refactor examples and fix schedule (#5077)

* [setup] refactor infer setup

* [hotfix] fix inference behavior on 1 1 gpu

* [example] refactor inference examples
This commit is contained in:
Hongxin Liu
2023-11-21 10:46:03 +08:00
committed by GitHub
parent 4e3959d316
commit 1cd7efc520
9 changed files with 209 additions and 274 deletions


@@ -33,13 +33,16 @@ class InferenceEngine:
Args:
tp_size (int): the size of tensor parallelism.
pp_size (int): the size of pipeline parallelism.
dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'.
model (`nn.Module`): the model in non-pipeline style; it will be modified in place by `ShardFormer`.
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by `ShardFormer` to shard the model. It will be inferred from the model type if not provided.
micro_batch_size (int): the micro batch size. Only effective when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batches. Normally it should equal the number of pipeline stages.
max_batch_size (int): the maximum batch size.
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.
quant (str): the quantization method, should be one of 'smoothquant', 'gptq', None.
verbose (bool): whether to return the time cost of each step.
"""