Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>
char-1ee
2024-06-07 08:28:19 +00:00
parent eec77e5702
commit 5f398fc000
11 changed files with 238 additions and 136 deletions

colossalai/inference/config.py

@@ -10,6 +10,7 @@ import torch
 from transformers.generation import GenerationConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.utils import can_use_flash_attn2

 GibiByte = 1024**3
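
The newly imported can_use_flash_attn2 gates flash attention on the configured dtype. A minimal sketch of what such a check plausibly looks like (the exact upstream implementation may differ; the body below is an assumption, not the repository's code):

import torch

def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    # FlashAttention-2 kernels support only half-precision dtypes.
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    # Availability also depends on the flash-attn package being importable.
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        return False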
@@ -312,13 +313,14 @@ class InferenceConfig(RPC_PARAM):
             meta_config[type] = getattr(model_config, type)
         return GenerationConfig.from_dict(meta_config)

-    def to_model_inference_config(self) -> "ModelInferenceConfig":
-        model_inference_config = ModelInferenceConfig(
+    def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
+        use_flash_attn = can_use_flash_attn2(self.dtype)
+        model_inference_config = ModelShardInferenceConfig(
             dtype=self.dtype,
             use_cuda_kernel=self.use_cuda_kernel,
             use_spec_dec=self.use_spec_dec,
-            use_cuda_graph=self.use_cuda_graph,
+            use_flash_attn=use_flash_attn,
         )
         return model_inference_config
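
The renamed method derives the per-shard flag set once from the engine-level config. A hedged usage sketch, assuming InferenceConfig lives in colossalai.inference.config and accepts these fields as keyword arguments:

import torch
from colossalai.inference.config import InferenceConfig

# Build the engine-level config, then derive the shard-level config that is
# passed down to modules at initialization time.
inference_config = InferenceConfig(dtype=torch.float16, use_cuda_kernel=True)
shard_config = inference_config.to_model_shard_inference_config()
# shard_config.use_flash_attn now reflects can_use_flash_attn2(torch.float16).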
@@ -374,21 +376,20 @@ class InferenceConfig(RPC_PARAM):
         inference_config = cls(**inference_config_args)
         return inference_config

 @dataclass
-class ModelInferenceConfig():
+class ModelShardInferenceConfig:
     """
-    Configurations used when initializing/sharding a model for inference.
+    Configurations used during module initialization for inference modeling.

     Args:
         dtype (torch.dtype): The data type for weights and activations.
         use_cuda_kernel (bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
         use_spec_dec (bool): Whether to use speculative decoding.
+        use_flash_attn (bool): Whether to use flash attention.
-        use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, CUDA graph is disabled and the model always runs in eager mode. If True, CUDA graph and eager execution are used in hybrid.
     """

     dtype: torch.dtype = None
     use_cuda_kernel: bool = False
     use_spec_dec: bool = False
+    use_flash_attn: bool = False
-    use_cuda_graph: bool = False
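
To illustrate the commit's stated intent of passing shard configs into module init, a hypothetical consumer might look like the following; the class and parameter names here are illustrative assumptions, not taken from this commit:

from colossalai.inference.config import ModelShardInferenceConfig

class CustomAttentionLayer:
    # Hypothetical module: records backend choices once at construction,
    # instead of re-reading engine-level flags on every forward pass.
    def __init__(self, model_shard_infer_config: ModelShardInferenceConfig):
        self.dtype = model_shard_infer_config.dtype
        self.use_flash_attn = model_shard_infer_config.use_flash_attn
        self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel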