Fix tests and naming

Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
char-1ee
2024-06-03 05:41:32 +00:00
parent 04386d9eff
commit eec77e5702
5 changed files with 154 additions and 250 deletions

View File

@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig
+from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@@ -72,9 +72,8 @@ class InferenceEngine:
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_inference_config = inference_config.to_model_inference_config()
-self.init_model(model_or_path, model_policy, self.model_inference_config)
+self.init_model(model_or_path, model_policy)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@@ -113,7 +112,6 @@ class InferenceEngine:
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
-model_inference_config: ModelInferenceConfig = None,
):
"""
Shard model or/and Load weight
@@ -178,7 +176,6 @@ class InferenceEngine:
self.model = self._shardformer(
model,
model_policy,
-model_inference_config,
None,
tp_group=tp_group,
)
@@ -299,7 +296,6 @@ class InferenceEngine:
self,
model: nn.Module,
model_policy: Policy,
-model_inference_config: ModelInferenceConfig,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module: