Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
char-1ee
2024-06-07 08:28:19 +00:00
parent eec77e5702
commit 5f398fc000
11 changed files with 238 additions and 136 deletions

View File

@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@@ -72,8 +72,9 @@ class InferenceEngine:
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.init_model(model_or_path, model_policy)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@@ -98,9 +99,7 @@ class InferenceEngine:
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec
# TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
# We can add a SpecDecConfig class to store these configs.
self.drafter_model = None
self.drafter = None
self.use_glide = False
@@ -109,9 +108,10 @@ class InferenceEngine:
self._verify_args()
def init_model(
self,
model_or_path: Union[nn.Module, str],
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
@@ -120,6 +120,7 @@ class InferenceEngine:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
if isinstance(model_or_path, str):
@@ -133,7 +134,7 @@ class InferenceEngine:
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
@@ -176,6 +177,7 @@ class InferenceEngine:
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
@@ -296,6 +298,7 @@ class InferenceEngine:
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
@@ -321,6 +324,7 @@ class InferenceEngine:
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
@@ -357,8 +361,7 @@ class InferenceEngine:
engine.clear_spec_dec()
```
"""
self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None: