Refactor modeling by adding attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
@@ -169,7 +169,8 @@ class InferenceConfig(RPC_PARAM):
         no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, n-grams of that size can appear at most once in the generated text.
         repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition. Defaults to 1.0.
         ignore_eos (bool): Whether to ignore the EOS token and continue generating tokens when encountering it.
-        n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
+        use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False.
+        max_n_spec_tokens (int): The maximum number of speculating tokens, defaults to 5.
         glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
         block_size (int): The number of blocks in a logical block, defaults to 16.
         tp_size (int): Tensor parallel size, defaults to 1.
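As a quick usage sketch (hypothetical values; InferenceConfig lives in colossalai/inference/config.py, so the import below assumes that path), these generation controls are set when the config is constructed:

from colossalai.inference.config import InferenceConfig

config = InferenceConfig(
    no_repeat_ngram_size=2,   # no 2-gram may repeat in the generated text
    repetition_penalty=1.2,   # >1.0 discourages token repetition
    ignore_eos=False,         # stop at the EOS token as usual
    block_size=16,            # tokens per logical KV-cache block
    tp_size=1,                # tensor parallel size
)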
@@ -214,6 +215,7 @@ class InferenceConfig(RPC_PARAM):
     ignore_eos: bool = False

+    # speculative decoding configs
     use_spec_dec: bool = False
     max_n_spec_tokens: int = 5
     glimpse_large_kv: bool = False

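The speculative-decoding fields above map one-to-one onto constructor arguments; a hedged sketch with illustrative values:

spec_config = InferenceConfig(
    use_spec_dec=True,       # enable speculative decoding
    max_n_spec_tokens=5,     # draft at most 5 tokens per speculation step
    glimpse_large_kv=False,  # drafter model does not use the large KV cache
)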
@@ -310,6 +312,15 @@ class InferenceConfig(RPC_PARAM):
             meta_config[type] = getattr(model_config, type)

         return GenerationConfig.from_dict(meta_config)

+    def to_model_inference_config(self) -> "ModelInferenceConfig":
+        model_inference_config = ModelInferenceConfig(
+            dtype=self.dtype,
+            use_cuda_kernel=self.use_cuda_kernel,
+            use_spec_dec=self.use_spec_dec,
+            use_cuda_graph=self.use_cuda_graph,
+        )
+        return model_inference_config
+
     def to_rpc_param(self) -> dict:
         kwargs = {
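The new to_model_inference_config() narrows the engine-level InferenceConfig down to the four fields needed when initializing and sharding the model. A minimal sketch of the intended call, reusing the config object from the earlier example:

model_cfg = config.to_model_inference_config()
assert model_cfg.dtype == config.dtype                # copied over verbatim
assert model_cfg.use_spec_dec == config.use_spec_dec  # same flag, narrower object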
@@ -362,3 +373,22 @@ class InferenceConfig(RPC_PARAM):
         # Set the attributes from the parsed arguments.
         inference_config = cls(**inference_config_args)
         return inference_config
+
+@dataclass
+class ModelInferenceConfig:
+    """
+    Configurations used when initializing/sharding the model for inference.
+
+    Args:
+        dtype (torch.dtype): The data type for weights and activations.
+        use_cuda_kernel (bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
+        use_spec_dec (bool): Indicate whether to use speculative decoding.
+        use_flash_attn (bool): Indicate whether to use flash attention.
+        use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, CUDA graph is disabled and the model always runs in eager mode. If True, the model runs in a hybrid of CUDA graph and eager execution.
+    """
+    dtype: torch.dtype = None
+    use_cuda_kernel: bool = False
+    use_spec_dec: bool = False
+    use_flash_attn: bool = False
+    use_cuda_graph: bool = False
+
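ModelInferenceConfig can also be built directly, e.g. in tests; a hedged sketch with illustrative field values:

import torch
from colossalai.inference.config import ModelInferenceConfig

model_cfg = ModelInferenceConfig(
    dtype=torch.float16,    # half-precision weights and activations
    use_cuda_kernel=True,   # prefer the fused CUDA kernels
    use_flash_attn=True,    # select the flash-attention backend
    use_cuda_graph=False,   # always run in eager mode
)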