[inference] simplified config verification (#5346)

* [inference] simplified config verification

* polish

* polish
Frank Lee 2024-02-01 15:31:01 +08:00 committed by GitHub
parent df0aa49585
commit f8e456d202
2 changed files with 40 additions and 60 deletions


@@ -14,23 +14,32 @@ GibiByte = 1024**3
 logger = logging.Logger(__name__)
 
+_DTYPE_MAPPING = {
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+    "fp32": torch.float32,
+}
+
+_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
+
+
 @dataclass
 class InferenceConfig:
     """The inference configuration.
 
     Args:
-        micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
+        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
-        max_batch_size (int): Maximum batch size.
-        max_output_len (int): Maximum output length.
-        max_input_len (int): Maximum input length.
-        block_size (int): The number of blocks in a logical block.
+        max_batch_size (int): Maximum batch size, defaults to 8.
+        max_output_len (int): Maximum output length, defaults to 256.
+        max_input_len (int): Maximum input length, defaults to 256.
+        block_size (int): The number of blocks in a logical block, defaults to 16.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
-        tp_size (int): Tensor parallel size.
-        pp_size (int): Pipeline parallel size.
-        beam_width (int): The maximum beam width used to initialize KV Cache.
+        tp_size (int): Tensor parallel size, defaults to 1.
+        pp_size (int): Pipeline parallel size, defaults to 1.
+        beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
             During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
-        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill
+        prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill
            when the actual value exceeds this ratio.
         pad_input: Whether to pad all inputs to the max length.
         quant_mode (Optional[str]): Quantization mode.
@@ -43,7 +52,7 @@ class InferenceConfig:
     max_output_len: int = 256
     max_input_len: int = 256
     block_size: int = 16
-    dtype: Union[str, torch.dtype] = torch.float32
+    dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
     tp_size: int = 1
     pp_size: int = 1
     # TODO: beam search is not support for now
@@ -55,57 +64,24 @@ class InferenceConfig:
     revision: Optional[str] = None
 
     def __post_init__(self):
-        self._init_batch_size()
         self._verify_config()
-        self._get_dtype()
-
-    def _init_batch_size(self):
-        """
-        MAX_BATCH_SIZE is set to acurately utilize the memory of gpu.
-        We take a simple method to determine it by GPU memory size, user can still set it manually.
-        """
-        if self.max_batch_size is not None:
-            # already set by user
-            return
-
-        device = torch.device("cuda")
-        total_mem = torch.cuda.get_device_properties(device).total_memory // GibiByte
-        self.max_batch_size = 8
-
-        if 40 < total_mem <= 60:
-            self.max_batch_size = 16
-        elif 60 < total_mem <= 80:
-            self.max_batch_size = 32
-
-        logger.info(
-            f"The maximum batch size is automatically set to {self.max_batch_size} as no value is provided by the user."
-        )
 
     def _verify_config(self) -> None:
         """
         Verify the input config
         """
+        # check dtype
+        if isinstance(self.dtype, str):
+            # convert string dtype to torch dtype
+            assert (
+                self.dtype in _DTYPE_MAPPING
+            ), f"Expected the dtype string argument to be in {list(_DTYPE_MAPPING.keys())} but found an unknown dtype: {self.dtype}"
+            self.dtype = _DTYPE_MAPPING[self.dtype]
+        assert (
+            self.dtype in _ALLOWED_DTYPES
+        ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
+
+        # check distributed
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
-        assert self.dtype in [
-            "fp16",
-            "fp32",
-            "bf16",
-            torch.float32,
-            torch.float16,
-            torch.bfloat16,
-        ], f"dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16, but got {self.dtype}."
-        assert self.quant_mode in [
-            "smoothquant",
-            "gptq",
-            None,
-        ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
-
-    def _get_dtype(self) -> None:
-        if self.dtype == "fp32" or self.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif self.dtype == "fp16" or self.dtype == torch.float16:
-            self.dtype = torch.float16
-        else:
-            self.dtype = torch.bfloat16
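For reference, the verification added above boils down to a normalize-then-validate step: a string dtype is first mapped through _DTYPE_MAPPING, and the resulting (or directly supplied) torch.dtype must be one of _ALLOWED_DTYPES. Below is a minimal standalone sketch of that behaviour outside the InferenceConfig class; the helper name normalize_dtype is ours, not part of this commit.

import torch

# Same lookup tables as the ones added to the config module in this commit.
_DTYPE_MAPPING = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]


def normalize_dtype(dtype):
    """Map a string alias to a torch.dtype and reject unsupported values."""
    if isinstance(dtype, str):
        assert dtype in _DTYPE_MAPPING, f"unknown dtype string: {dtype}"
        dtype = _DTYPE_MAPPING[dtype]
    assert dtype in _ALLOWED_DTYPES, f"unsupported dtype: {dtype}"
    return dtype


print(normalize_dtype("fp16"))          # torch.float16
print(normalize_dtype(torch.bfloat16))  # torch.bfloat16
# normalize_dtype("int8") raises an AssertionError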


@@ -21,11 +21,15 @@ def setup_seed(seed):
 def check_inference_engine(test_cai=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = LlamaForCausalLM(
-        LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+    model = (
+        LlamaForCausalLM(
+            LlamaConfig(
+                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            )
         )
-    ).cuda()
+        .cuda()
+        .half()
+    )
 
     model = model.eval()
@@ -70,7 +74,7 @@ def run_dist(rank, world_size, port):
     transformer_outputs = check_inference_engine(False)
 
     for s1, s2 in zip(cai_outputs, transformer_outputs):
-        assert s1 == s2
+        assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
 
 
 @pytest.mark.dist
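The test change mirrors the new fp16 default: the Hugging Face reference model is now cast with .half() so both the ColossalAI engine and the transformers baseline generate in float16, and the assertion prints both outputs when they diverge. A rough standalone sketch of that reference-model setup, assuming a CUDA device and the transformers package are available (config values copied from the test):

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny Llama used purely as a reference model, cast to fp16 to match the
# engine's new default dtype (torch.float16).
config = LlamaConfig(
    vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
model = LlamaForCausalLM(config).cuda().half().eval()
assert next(model.parameters()).dtype == torch.float16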