[Inference] User Experience: update the logic of default tokenizer and generation config. (#5337)

* add

* fix

* fix

* pause

* fix

* fix pytest

* align

* fix

* license

* fix

* fix

* fix readme

* fix some bugs

* remove tokenizer config
Author: Jianghai
Date: 2024-02-07 17:55:48 +08:00
Committed by: GitHub
Parent: 6fb4bcbb24
Commit: 1f8c7e7046
7 changed files with 62 additions and 23 deletions


@@ -8,6 +8,7 @@ from typing import Optional, Union
import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig

GibiByte = 1024**3
@@ -60,15 +61,22 @@ class InferenceConfig:
    max_input_len: int = 256
    block_size: int = 16
    dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
    tp_size: int = 1
    pp_size: int = 1
    # TODO: beam search is not supported for now
    do_sample: bool = False
    beam_width: int = 1
    # the ratio of prefill sequences to decoding sequences; a prefill step runs once the actual ratio exceeds this value
    prefill_ratio: Optional[float] = 1.2
    pad_input: bool = False
    quant_mode: Optional[str] = None
    revision: Optional[str] = None
    # generation options forwarded to the GenerationConfig
    early_stopping: Optional[bool] = False
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    min_p: Optional[float] = None
    # template used to format the input text; must contain '{input_text}'
    prompt_template: Optional[str] = None

    def __post_init__(self):
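As a usage note, the new sampling and template fields can be passed straight to the dataclass constructor. The sketch below is illustrative only: the import path colossalai.inference.config is an assumption, and __post_init__ validates tp_size * pp_size against the distributed world size, so this is not a standalone script.

import torch
from colossalai.inference.config import InferenceConfig  # assumed import path

# Illustrative only: enable sampling and provide a chat-style prompt template.
inference_config = InferenceConfig(
    max_input_len=256,
    dtype=torch.float16,
    do_sample=True,
    top_k=50,
    top_p=0.9,
    prompt_template="USER: {input_text}\n\nASSISTANT: ",
)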
@@ -93,7 +101,6 @@ class InferenceConfig:
        assert (
            self.tp_size * self.pp_size == dist.get_world_size()
        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"

        # check prompt template
        if self.prompt_template is None:
            return
@@ -105,3 +112,20 @@ class InferenceConfig:
        assert (
            "{input_text}" in self.prompt_template
        ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '"

    def to_generation_config(self, model_config) -> GenerationConfig:
        meta_config = {
            "max_length": self.max_input_len + self.max_output_len,
            "max_new_tokens": self.max_output_len,
            "early_stopping": self.early_stopping,
            "do_sample": self.do_sample,
            "num_beams": self.beam_width,
        }
        # carry over the sampling options defined on this config
        for type in ["top_k", "top_p", "min_p"]:
            if hasattr(self, type):
                meta_config[type] = getattr(self, type)
        # pick up the special token ids from the model config when available
        for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:
            if hasattr(model_config, type):
                meta_config[type] = getattr(model_config, type)

        return GenerationConfig.from_dict(meta_config)
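A hedged sketch of how to_generation_config might be called, continuing from the inference_config object in the earlier sketch; the checkpoint name is a placeholder, and only standard Hugging Face calls (AutoConfig.from_pretrained, GenerationConfig attribute access) are relied on.

from transformers import AutoConfig

# Placeholder checkpoint; any Hugging Face model config exposing the token ids would do.
model_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

# Build a GenerationConfig whose special token ids come from the model config and whose
# length/sampling settings come from the InferenceConfig defined above.
generation_config = inference_config.to_generation_config(model_config)
print(generation_config.max_new_tokens, generation_config.eos_token_id)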