mirror of https://github.com/hpcaitech/ColossalAI.git
adapted to pad_context_forward
commit fa4fbdbffb
parent 47e53eaa1c
committed by FrankLeeeee
@@ -1,6 +1,5 @@
 """
-Our config consists of one part:
-1. inference_config: configs for inference, it is a unified api that wraps all the configs for inference.
+Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
 
 import logging
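The new docstring presents InferenceConfig as one unified API for all inference options. A minimal usage sketch, assuming the class accepts the fields validated in the hunk below as keyword arguments (the import path and the default values shown here are assumptions, not taken from this commit):

    from colossalai.inference.config import InferenceConfig  # path assumed

    # One config object carries every inference option; validation in
    # __post_init__ (see the hunk below) rejects inconsistent settings.
    config = InferenceConfig(
        dtype="fp16",        # 'fp16', 'fp32', 'bf16', or a torch dtype
        quant_mode=None,     # None, 'smoothquant', or 'gptq'
        max_input_len=512,   # prompt length budget
        max_output_len=512,  # generation length budget
        max_seq_len=1024,    # must cover max_input_len + max_output_len
    )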
@@ -94,9 +93,12 @@ class InferenceConfig:
             torch.float32,
             torch.float16,
             torch.bfloat16,
-        ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
-        assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
-        assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
+        ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
+        assert self.quant_mode in [
+            "smoothquant",
+            "gptq",
+            None,
+        ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
         assert (
             self.max_input_len + self.max_output_len <= self.max_seq_len
-        ), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."
+        ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."
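The reworked assertions interpolate the offending values into their messages, so a bad config reports what it actually received instead of only restating the rule. A self-contained sketch of the same pattern; the field names mirror the diff, but the class name and defaults here are illustrative, not the library's:

    from dataclasses import dataclass
    from typing import Optional, Union

    import torch

    @dataclass
    class _ConfigSketch:
        dtype: Union[str, torch.dtype] = "fp16"
        quant_mode: Optional[str] = None
        max_input_len: int = 256
        max_output_len: int = 256
        max_seq_len: int = 512

        def __post_init__(self):
            # Each check embeds the received value via an f-string,
            # matching the style introduced by this commit.
            assert self.dtype in [
                "fp16", "fp32", "bf16",
                torch.float32, torch.float16, torch.bfloat16,
            ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
            assert self.quant_mode in [
                "smoothquant", "gptq", None,
            ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
            assert (
                self.max_input_len + self.max_output_len <= self.max_seq_len
            ), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."

    # Violating the length constraint now reports the actual numbers:
    try:
        _ConfigSketch(max_input_len=2048, max_output_len=2048, max_seq_len=2048)
    except AssertionError as e:
        print(e)  # "The sum of max_input_len 2048 and max_output_len 2048 ..."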