mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
@@ -72,9 +72,8 @@ class InferenceEngine:
|
||||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_inference_config = inference_config.to_model_inference_config()
|
||||
|
||||
self.init_model(model_or_path, model_policy, self.model_inference_config)
|
||||
self.init_model(model_or_path, model_policy)
|
||||
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.generation_config_dict = self.generation_config.to_dict()
|
||||
@@ -113,7 +112,6 @@ class InferenceEngine:
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_inference_config: ModelInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard model or/and Load weight
|
||||
@@ -178,7 +176,6 @@ class InferenceEngine:
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_inference_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
@@ -299,7 +296,6 @@ class InferenceEngine:
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
model_inference_config: ModelInferenceConfig,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
) -> nn.Module:
|
||||
|
Reference in New Issue
Block a user