Mirror of https://github.com/hpcaitech/ColossalAI.git
Refactor modeling by adding attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData
+from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig
 from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
@@ -72,8 +72,9 @@ class InferenceEngine:

         self.verbose = verbose
         self.logger = get_dist_logger(__name__)
+        self.model_inference_config = inference_config.to_model_inference_config()

-        self.init_model(model_or_path, model_policy)
+        self.init_model(model_or_path, model_policy, self.model_inference_config)

         self.generation_config = inference_config.to_generation_config(self.model_config)
         self.generation_config_dict = self.generation_config.to_dict()
@@ -97,7 +98,10 @@ class InferenceEngine:
             self.capture_model(self.k_cache, self.v_cache)

         # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
-        self.use_spec_dec = False
+        self.use_spec_dec = self.inference_config.use_spec_dec
+
+        # TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine
+        # We can add a SpecDecConfig class to store these configs.
         self.drafter_model = None
         self.drafter = None
         self.use_glide = False
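For context, a minimal usage sketch (not part of the diff) of how speculative decoding is now toggled through the config rather than via `enable_spec_dec()`. The `use_spec_dec` flag on `InferenceConfig` is taken from the hunk above; the commented engine construction is an assumption and omits the other constructor arguments.

from colossalai.inference.config import InferenceConfig

# Assumed sketch: enable speculative decoding at config time.
# The engine then mirrors it via `self.use_spec_dec = self.inference_config.use_spec_dec`.
inference_config = InferenceConfig(use_spec_dec=True)
# engine = InferenceEngine(model_or_path, tokenizer, inference_config)  # hypothetical call, other args omitted
# Drafter-model wiring is still a TODO per the comment in the diff (a SpecDecConfig may be added later).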
@@ -105,13 +109,19 @@ class InferenceEngine:

         self._verify_args()

-    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
+    def init_model(
+        self,
+        model_or_path: Union[nn.Module, str],
+        model_policy: Union[Policy, Type[Policy]] = None,
+        model_inference_config: ModelInferenceConfig = None,
+    ):
         """
         Shard model or/and Load weight

         Args:
             model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
-            model_policy (Policy): the policy to replace the model
+            model_policy (Policy): the policy to replace the model.
+            model_inference_config: the configuration for modeling initialization when inference.
         """

         if isinstance(model_or_path, str):
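A hedged sketch of how the new `init_model` path fits together, following the lines above: the engine derives a `ModelInferenceConfig` from the `InferenceConfig` and forwards it down to the shardformer step. The concrete fields of `ModelInferenceConfig` are not shown in this diff, so none are assumed here.

from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig()  # assumed default construction
model_inference_config = inference_config.to_model_inference_config()  # helper introduced by this change
# Inside InferenceEngine.__init__ (see the hunk at -72,8 above):
#     self.init_model(model_or_path, model_policy, self.model_inference_config)
# and init_model() in turn hands model_inference_config to self._shardformer().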
@@ -124,7 +134,8 @@ class InferenceEngine:
                     # the model load process in the future.
                     model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
                 else:
-                    raise ValueError(f"Model {arch} is not supported.")
+                    # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
+                    raise ValueError(f"Model {arch} is not supported.")

             except Exception as e:
                 self.logger.error(
@@ -167,6 +178,7 @@ class InferenceEngine:
             self.model = self._shardformer(
                 model,
                 model_policy,
+                model_inference_config,
                 None,
                 tp_group=tp_group,
             )
@@ -187,7 +199,7 @@ class InferenceEngine:
         # assert if_has_index_file, "the model path is invalid"
         # cpt_io.load_model(self.model, model_index_file)

-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
         if self.verbose:
             self.logger.info(
@@ -287,6 +299,7 @@ class InferenceEngine:
         self,
         model: nn.Module,
         model_policy: Policy,
+        model_inference_config: ModelInferenceConfig,
         stage_manager: PipelineStageManager = None,
         tp_group: ProcessGroupMesh = None,
     ) -> nn.Module:
@@ -348,6 +361,8 @@ class InferenceEngine:
             engine.clear_spec_dec()
         ```
         """
+        self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.")
+
         if drafter_model is None and self.drafter is None:
             raise ValueError("Drafter not initialized. Please provide a Drafter Model")
         if n_spec_tokens is not None:
@@ -517,19 +532,19 @@ class InferenceEngine:
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
-    ) -> List[str]:
+    ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
         """
         Executing the inference step.

         Args:
-            prompts (Union[List[str], optional): Input prompts. Defaults to None.
-            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
             request_ids (List[int], optional): The request ID. Defaults to None.
-            return_token_ids (bool): Whether to return output token ids. Defaults to False.
-            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
+            prompts (Union[List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
+            return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
+            generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.

         Returns:
-            List[str]: Inference result returned by one generation.
+            Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
         """

         gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
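A hedged usage sketch of the widened return type of `generate()` shown above; engine construction is elided and the prompt is purely illustrative.

# Default behavior: a list of decoded strings.
output_texts = engine.generate(prompts=["Hello, world"])

# With return_token_ids=True: a tuple of decoded strings and generated token ids,
# matching the new Union[List[str], Tuple[List[str], List[List[int]]]] annotation.
texts, token_ids = engine.generate(prompts=["Hello, world"], return_token_ids=True)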