mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[Inference/Feat] Add quant kvcache interface (#5700)
* add quant kvcache interface
* delete unused output
* complete args comments
This commit is contained in:
@@ -88,6 +88,7 @@ class InferenceConfig:
         max_output_len (int): Maximum output length, defaults to 256.
         max_input_len (int): Maximum input length, defaults to 256.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        kv_cache_dtype (Optional[str]): The data type of kv_cache, defaults to None.
         prompt_template (Optional[str]): The prompt template for generation, defaults to None.
         do_sample (bool): Whether to use sampling for generation, defaults to False.
         beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
@@ -122,6 +123,7 @@ class InferenceConfig:

     # general configs
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
+    kv_cache_dtype: Optional[str] = None

     # generation configs
     prompt_template: Optional[str] = None
@@ -177,6 +179,12 @@ class InferenceConfig:
             self.dtype in _ALLOWED_DTYPES
         ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"

+        if self.kv_cache_dtype:
+            assert (
+                self.use_cuda_kernel and self.kv_cache_dtype == "fp8"
+            ), f"FP8 kv_cache is only supported with use_cuda_kernel open now"
+            self.kv_cache_dtype = torch.uint8
+
         # skip using casting when the data type is float32
         if self.dtype == torch.float32:
             self.high_precision = False
Reference in New Issue
Block a user