mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 12:43:02 +00:00
[inference] refactor examples and fix schedule (#5077)
* [setup] refactor infer setup * [hotfix] fix infenrece behavior on 1 1 gpu * [exmaple] refactor inference examples
This commit is contained in:
@@ -33,13 +33,16 @@ class InferenceEngine:
|
||||
Args:
|
||||
tp_size (int): the size of tensor parallelism.
|
||||
pp_size (int): the size of pipeline parallelism.
|
||||
dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'.
|
||||
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
|
||||
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
|
||||
micro_batch_size (int): the micro batch size.
|
||||
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. It will be determined by the model type if not provided.
|
||||
micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
max_batch_size (int): the maximum batch size.
|
||||
max_input_len (int): the maximum input length.
|
||||
max_output_len (int): the maximum output length.
|
||||
quant (str): the quantization method, should be one of 'smoothquant', 'gptq', None.
|
||||
verbose (bool): whether to return the time cost of each step.
|
||||
|
||||
"""
|
||||
|
||||
|
Reference in New Issue
Block a user