adapted to pad_context_forward

This commit is contained in:
yuehuayingxueluo
2024-01-09 13:52:53 +08:00
committed by FrankLeeeee
parent 47e53eaa1c
commit fa4fbdbffb
9 changed files with 42 additions and 41 deletions

View File

@@ -1,6 +1,5 @@
"""
Our config consists of one part:
1. inference_config: configs for inference, it is a unified api that wraps all the configs for inference.
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
"""
import logging
@@ -94,9 +93,12 @@ class InferenceConfig:
torch.float32,
torch.float16,
torch.bfloat16,
], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
assert self.quant_mode in [
"smoothquant",
"gptq",
None,
], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
assert (
self.max_input_len + self.max_output_len <= self.max_seq_len
), "The sum of max_input_len and max_output_len must be smaller than max_seq_len."
), f"The sum of max_input_len {self.max_input_len} and max_output_len {self.max_output_len} must be smaller than max_seq_len {self.max_seq_len}."