mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[inference] added inference template (#5375)
This commit is contained in:
@@ -23,6 +23,12 @@ _DTYPE_MAPPING = {
|
||||
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
|
||||
_DEFAULT_PROMPT_TEMPLATES = {
|
||||
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
|
||||
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceConfig:
|
||||
"""The inference configuration.
|
||||
@@ -44,6 +50,7 @@ class InferenceConfig:
|
||||
pad_input: Whether to pad all inputs to the max length.
|
||||
quant_mode (Optional[str]): Quantization mode.
|
||||
revision (Optional[str]): The specific version (a branch name, a commit id, or a tag name) of the model to use.
|
||||
prompt_template (Optional[str]): The prompt template for formatting the input text. Some built-in templates include 'llama' and 'vicuna'. Otherwise, the template should contain '{input_text}' for formatting the input text.
|
||||
"""
|
||||
|
||||
micro_batch_size: int = 1
|
||||
@@ -62,6 +69,7 @@ class InferenceConfig:
|
||||
pad_input: bool = False
|
||||
quant_mode: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
prompt_template: Optional[str] = None
|
||||
|
||||
# Dataclass post-init hook: delegates to self._verify_config() so the
# configuration is validated as soon as an instance has been constructed.
def __post_init__(self):
|
||||
self._verify_config()
|
||||
@@ -85,3 +93,15 @@ class InferenceConfig:
|
||||
assert (
|
||||
self.tp_size * self.pp_size == dist.get_world_size()
|
||||
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
|
||||
|
||||
# check prompt template
|
||||
if self.prompt_template is None:
|
||||
return
|
||||
|
||||
if self.prompt_template in _DEFAULT_PROMPT_TEMPLATES:
|
||||
self.prompt_template = _DEFAULT_PROMPT_TEMPLATES[self.prompt_template]
|
||||
else:
|
||||
# make sure the template can be formatted with input_text
|
||||
assert (
|
||||
"{input_text}" in self.prompt_template
|
||||
), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"
|
||||
|
Reference in New Issue
Block a user