Mirror of https://github.com/hpcaitech/ColossalAI.git
Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>
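The diff below threads a new `ModelShardInferenceConfig` from the engine's `InferenceConfig` down to module initialization. As a rough orientation, here is a minimal, self-contained sketch of that flow; all classes are simplified stand-ins rather than the real ColossalAI implementations, and the `use_cuda_kernel` field is a hypothetical placeholder, since the diff does not show the config's actual fields.

```python
# Minimal sketch of the config flow introduced by this commit.
# All classes are simplified stand-ins, NOT the real ColossalAI definitions;
# the use_cuda_kernel field is hypothetical.
from dataclasses import dataclass


@dataclass
class ModelShardInferenceConfig:
    # Hypothetical per-shard knob consumed at module-init time; the real class
    # lives in colossalai.inference.config and its fields are not shown here.
    use_cuda_kernel: bool = False


@dataclass
class InferenceConfig:
    # Simplified stand-in for colossalai.inference.config.InferenceConfig.
    use_cuda_kernel: bool = False

    def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
        # Same helper name the diff calls in InferenceEngine.__init__.
        return ModelShardInferenceConfig(use_cuda_kernel=self.use_cuda_kernel)


class InferenceEngine:
    # Toy engine showing only the config plumbing added by this commit.
    def __init__(self, model_or_path, model_policy, inference_config):
        # Hunk at -72: derive the shard-level config once, up front.
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
        # Hunks at -72 / -109 / -120: hand it to model initialization explicitly.
        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

    def init_model(self, model_or_path, model_policy, model_shard_infer_config=None):
        # Hunks at -176 / -296 / -321: forwarded into sharding, ending up in
        # ShardConfig.extra_kwargs so policies can see it while rebuilding modules.
        print(f"sharding {model_or_path!r} with {model_shard_infer_config}")


InferenceEngine("meta-llama/Llama-2-7b-hf", None, InferenceConfig(use_cuda_kernel=True))
```

One reading of the change is that module-level code only needs the small shard-level config rather than the full `InferenceConfig`, though the diff itself does not state that rationale.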
@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData
+from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
 from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
@@ -72,8 +72,9 @@ class InferenceEngine:

         self.verbose = verbose
         self.logger = get_dist_logger(__name__)
+        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

-        self.init_model(model_or_path, model_policy)
+        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

         self.generation_config = inference_config.to_generation_config(self.model_config)
         self.generation_config_dict = self.generation_config.to_dict()
@@ -98,9 +99,7 @@ class InferenceEngine:

         # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
         self.use_spec_dec = self.inference_config.use_spec_dec

-        # TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
-        # We can add a SpecDecConfig class to store these configs.
         self.drafter_model = None
         self.drafter = None
         self.use_glide = False
@@ -109,9 +108,10 @@ class InferenceEngine:
         self._verify_args()

     def init_model(
         self,
         model_or_path: Union[nn.Module, str],
         model_policy: Union[Policy, Type[Policy]] = None,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
     ):
         """
         Shard model or/and Load weight
@@ -120,6 +120,7 @@ class InferenceEngine:
             model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
             model_policy (Policy): the policy to replace the model.
             model_inference_config: the configuration for modeling initialization when inference.
+            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
         """

         if isinstance(model_or_path, str):
@@ -133,7 +134,7 @@ class InferenceEngine:
                     model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
                 else:
+                    # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
                     raise ValueError(f"Model {arch} is not supported.")

             except Exception as e:
                 self.logger.error(
@@ -176,6 +177,7 @@ class InferenceEngine:
         self.model = self._shardformer(
             model,
             model_policy,
+            model_shard_infer_config,
             None,
             tp_group=tp_group,
         )
@@ -296,6 +298,7 @@ class InferenceEngine:
         self,
         model: nn.Module,
         model_policy: Policy,
+        model_shard_infer_config: ModelShardInferenceConfig = None,
         stage_manager: PipelineStageManager = None,
         tp_group: ProcessGroupMesh = None,
     ) -> nn.Module:
@@ -321,6 +324,7 @@ class InferenceEngine:
             enable_flash_attention=False,
             enable_jit_fused=False,
             enable_sequence_parallelism=False,
+            extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
         )
         shardformer = ShardFormer(shard_config=shardconfig)
         shard_model, _ = shardformer.optimize(model, model_policy)
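On the consuming side, whatever policy ShardFormer applies can read the config back out of `shard_config.extra_kwargs`. The sketch below illustrates that retrieval; only the `"model_shard_infer_config"` key is taken from the diff, while the class and method names are stand-ins rather than the actual ColossalAI policy API.

```python
# Standalone sketch of the consumer side; only the "model_shard_infer_config"
# key in extra_kwargs comes from the diff, the rest is illustrative.
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ShardConfigLike:
    # Mirrors only the extra_kwargs channel the hunk above populates; not the real ShardConfig.
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)


class ExamplePolicy:
    """Toy stand-in for a shardformer policy that configures modules at shard time."""

    def __init__(self, shard_config: ShardConfigLike):
        self.shard_config = shard_config

    def module_policy(self):
        # Pull the per-shard inference config the engine injected via extra_kwargs.
        shard_infer_cfg = self.shard_config.extra_kwargs.get("model_shard_infer_config")
        # A real policy would pass shard_infer_cfg into the replacement modules
        # (attention layers, etc.) it builds here.
        return {"model_shard_infer_config": shard_infer_cfg}


cfg = ShardConfigLike(extra_kwargs={"model_shard_infer_config": {"use_cuda_kernel": True}})
print(ExamplePolicy(cfg).module_policy())
```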
@@ -357,8 +361,7 @@ class InferenceEngine:
             engine.clear_spec_dec()
         ```
         """
-
-
+        self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")
         if drafter_model is None and self.drafter is None:
             raise ValueError("Drafter not initialized. Please provide a Drafter Model")
         if n_spec_tokens is not None:
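The warning added in the hunk above steers users toward configuring speculative decoding up front. The following self-contained sketch shows that configuration-driven direction; only the `use_spec_dec` flag and the `enable_spec_dec` name appear in the diff, and the idea that the engine triggers the setup itself from the flag is an assumption made for illustration.

```python
# Sketch of the configuration-driven path the new warning points to.
# Names mirror the diff (use_spec_dec on InferenceConfig, enable_spec_dec on the
# engine); everything else is a simplified stand-in, not the real implementation.
from dataclasses import dataclass


@dataclass
class InferenceConfig:
    # use_spec_dec is the flag the warning recommends; other fields omitted.
    use_spec_dec: bool = False


class InferenceEngine:
    def __init__(self, inference_config: InferenceConfig):
        self.inference_config = inference_config
        # Mirrors the hunk at -98: the flag is read straight from the config.
        self.use_spec_dec = self.inference_config.use_spec_dec
        self.drafter_model = None
        self.drafter = None
        if self.use_spec_dec:
            # Assumption for illustration only: the engine drives drafter setup
            # from the flag, so callers no longer call enable_spec_dec() directly.
            self.enable_spec_dec()

    def enable_spec_dec(self, drafter_model=None, n_spec_tokens=None):
        print("speculative decoding enabled")


InferenceEngine(InferenceConfig(use_spec_dec=True))
```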