Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>
@@ -10,6 +10,7 @@ import torch
 from transformers.generation import GenerationConfig
 
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.inference.utils import can_use_flash_attn2
 
 GibiByte = 1024**3
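The added import brings in the dtype check that the new shard config consumes. Below is a minimal sketch of the gate it enables; the exact criteria live in colossalai.inference.utils, so the expected return values noted in the comments are assumptions.

import torch

from colossalai.inference.utils import can_use_flash_attn2

# Expected to be True only for half-precision dtypes with flash-attn 2 available (assumption).
print(can_use_flash_attn2(torch.float16))
# Expected to be False for full precision (assumption).
print(can_use_flash_attn2(torch.float32))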
@@ -312,13 +313,14 @@ class InferenceConfig(RPC_PARAM):
                 meta_config[type] = getattr(model_config, type)
 
         return GenerationConfig.from_dict(meta_config)
 
-    def to_model_inference_config(self) -> "ModelInferenceConfig":
-        model_inference_config = ModelInferenceConfig(
+    def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
+        use_flash_attn = can_use_flash_attn2(self.dtype)
+        model_inference_config = ModelShardInferenceConfig(
             dtype=self.dtype,
             use_cuda_kernel=self.use_cuda_kernel,
             use_spec_dec=self.use_spec_dec,
-            use_cuda_graph=self.use_cuda_graph,
+            use_flash_attn=use_flash_attn,
         )
         return model_inference_config
 
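To illustrate the intended flow, here is a minimal usage sketch: build an InferenceConfig, then derive the per-shard config that is handed to modules at init time. The constructor arguments shown are InferenceConfig fields already referenced in this diff; the final print line is illustrative only.

import torch

from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    dtype=torch.float16,   # half precision, so flash-attn 2 may be usable
    use_cuda_kernel=True,
    use_spec_dec=False,
)

# Derive the config passed to modules during shard/module init.
model_shard_infer_config = inference_config.to_model_shard_inference_config()
print(model_shard_infer_config.use_flash_attn)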
@@ -374,21 +376,20 @@ class InferenceConfig(RPC_PARAM):
         inference_config = cls(**inference_config_args)
         return inference_config
 
 
 @dataclass
-class ModelInferenceConfig():
+class ModelShardInferenceConfig:
     """
-    Configurations used when initializing/sharding model for inference.
+    Configurations used during init of module for inference modeling.
 
     Args:
         dtype (torch.dtype): The data type for weights and activations.
         use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally
         use_spec_dec (bool): Indicate whether to use speculative decoding.
         use_flash_attn (bool): Indicate whether to use flash attention.
-        use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
     """
+
     dtype: torch.dtype = None
     use_cuda_kernel: bool = False
     use_spec_dec: bool = False
     use_flash_attn: bool = False
-    use_cuda_graph: bool = False
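As a consumer-side sketch, a module can read these flags once when it is initialized. The DemoAttention class below is hypothetical and not part of the ColossalAI codebase; it only shows the shape of a module that takes the shard config at init.

import torch.nn as nn

from colossalai.inference.config import ModelShardInferenceConfig


class DemoAttention(nn.Module):
    def __init__(self, model_shard_infer_config: ModelShardInferenceConfig):
        super().__init__()
        # Resolve backend choices once at init instead of re-checking flags every forward pass.
        self.use_flash_attn = model_shard_infer_config.use_flash_attn
        self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
        self.dtype = model_shard_infer_config.dtype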