[Inference] User Experience: update the logic of default tokenizer and generation config. (#5337)
* add
* fix
* fix
* pause
* fix
* fix pytest
* align
* fix
* license
* fix
* fix
* fix readme
* fix some bugs
* remove tokenizer config
@@ -8,6 +8,7 @@ from typing import Optional, Union
import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig

GibiByte = 1024**3

@@ -60,15 +61,22 @@ class InferenceConfig:
    max_input_len: int = 256
    block_size: int = 16
    dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default

    tp_size: int = 1
    pp_size: int = 1
    # TODO: beam search is not supported for now
    do_sample: bool = False
    beam_width: int = 1
    # the ratio of prefill sequences to decoding sequences; we do a prefill step once the actual value exceeds this ratio
    prefill_ratio: Optional[float] = 1.2
    pad_input: bool = False
    quant_mode: Optional[str] = None
    revision: Optional[str] = None
    early_stopping: Optional[bool] = False

    top_k: Optional[int] = None
    top_p: Optional[float] = None
    min_p: Optional[float] = None
    prompt_template: Optional[str] = None

    def __post_init__(self):
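For orientation, here is a minimal usage sketch of how these fields might be set when building a config. The import path `colossalai.inference.config.InferenceConfig` and the exact set of accepted keyword arguments are assumptions based on this diff, not verified against the full file.

```python
# Hypothetical usage sketch, assuming InferenceConfig is a dataclass exposed at
# colossalai.inference.config and accepts the fields above as keyword arguments.
# Note: __post_init__ checks tp_size * pp_size against dist.get_world_size(),
# so this would typically be constructed inside a torchrun / colossalai launch.
import torch
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    max_input_len=256,
    block_size=16,
    dtype=torch.float16,
    tp_size=1,
    pp_size=1,
    do_sample=False,  # beam search is not supported yet, so beam_width stays 1
    top_k=50,
    top_p=0.9,
    prompt_template="USER: {input_text}\n\nASSISTANT: ",
)
```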
@@ -93,7 +101,6 @@ class InferenceConfig:
        assert (
            self.tp_size * self.pp_size == dist.get_world_size()
        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"

        # check prompt template
        if self.prompt_template is None:
            return
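The assertion above ties the parallel layout to the launch configuration: tensor-parallel size times pipeline-parallel size must equal the number of launched processes. A standalone sketch of the same invariant follows; the helper name is mine, not part of the diff.

```python
import torch.distributed as dist

def check_parallel_sizes(tp_size: int, pp_size: int) -> None:
    # Mirrors the __post_init__ check: TP size * PP size must match the
    # global world size of the distributed job (treated as 1 when the
    # process group is not initialized).
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    assert tp_size * pp_size == world_size, (
        f"TP size({tp_size}) * PP size({pp_size}) should be equal to "
        f"the global world size ({world_size})"
    )
```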
@@ -105,3 +112,20 @@ class InferenceConfig:
        assert (
            "{input_text}" in self.prompt_template
        ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"
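The `{input_text}` placeholder is a plain `str.format` field, so a template that passes the assertion can be applied to a prompt as in this minimal sketch; the actual formatting call inside the inference engine is not shown in this diff.

```python
prompt_template = "USER: {input_text}\n\nASSISTANT: "
assert "{input_text}" in prompt_template  # same check as __post_init__

prompt = prompt_template.format(input_text="What is the capital of France?")
# -> "USER: What is the capital of France?\n\nASSISTANT: "
```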
    def to_generation_config(self, model_config) -> GenerationConfig:
        meta_config = {
            "max_length": self.max_input_len + self.max_output_len,
            "max_new_tokens": self.max_output_len,
            "early_stopping": self.early_stopping,
            "do_sample": self.do_sample,
            "num_beams": self.beam_width,
        }
        for type in ["top_k", "top_p", "min_p"]:
            if hasattr(self, type):
                meta_config[type] = getattr(self, type)
        for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:
            if hasattr(model_config, type):
                meta_config[type] = getattr(model_config, type)

        return GenerationConfig.from_dict(meta_config)
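`to_generation_config` merges the sampling fields defined on the inference config with token-id fields taken from the model config into a Hugging Face `GenerationConfig`. A hedged usage sketch, reusing the `inference_config` from the earlier example; the model name and surrounding setup are placeholders, not from the diff.

```python
from transformers import AutoConfig

# Pull token ids (pad/bos/eos) from a real model config and combine them with
# the sampling settings defined on InferenceConfig.
model_config = AutoConfig.from_pretrained("gpt2")
generation_config = inference_config.to_generation_config(model_config)

print(generation_config.max_new_tokens)  # equals inference_config.max_output_len
print(generation_config.eos_token_id)    # copied from model_config when present
```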