Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 01:28:31 +00:00)
[Inference] Lazy Init Support (#5785)

* lazy init support
* lazy init llama support
* lazy init support for baichuan
* align rpc
* add note for baichuan

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -24,8 +24,9 @@ from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
 from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -122,16 +123,24 @@ class InferenceEngine:
-            model_inference_config: the configuration for modeling initialization when inference.
+            model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
         """
 
+        pretrained_path = None
         if isinstance(model_or_path, str):
+            import colossalai.interface.pretrained as pretrained_utils
 
             try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                 arch = getattr(hf_config, "architectures")[0]
                 if arch in _supported_models.keys():
-                    # NOTE(lry89757) Currently we load the model using transformers-api,
-                    # but we will use lazy tensor and checkpoint io to accelerate
-                    # the model load process in the future.
-                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                    if arch is "BaichuanForCausalLM":
+                        self.logger.warning(
+                            "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
+                        )
+                    ctx = LazyInitContext(default_device="cuda")
+                    with ctx:
+                        model = _supported_models[arch].from_pretrained(
+                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                        )
+                    pretrained_path = pretrained_utils.get_pretrained_path(model)
                 else:
                     # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
                     raise ValueError(f"Model {arch} is not supported.")
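For reference, the lazy-init load path introduced above can be reproduced outside the engine. The following is a minimal sketch, not the engine code itself: `model_dir` is a hypothetical checkpoint path, `LlamaForCausalLM` stands in for whatever `_supported_models[arch]` resolves to, and the architecture check uses `==` (the diff writes `is`).

```python
# Minimal sketch of the lazy-init load path added above (assumes CUDA is
# available and `model_dir` points at a checkpoint transformers can resolve).
import torch
from transformers import AutoConfig, LlamaForCausalLM

import colossalai.interface.pretrained as pretrained_utils
from colossalai.lazy import LazyInitContext

model_dir = "path/to/llama-checkpoint"  # hypothetical path
dtype = torch.float16

hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=dtype)
arch = hf_config.architectures[0]
if arch == "BaichuanForCausalLM":  # string equality check
    print("Note: with lazy init, Baichuan outputs may differ slightly from transformers.")

# Build the module under LazyInitContext so parameters are created as lazy
# tensors targeted at CUDA rather than fully materialized up front.
with LazyInitContext(default_device="cuda"):
    model = LlamaForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=dtype)

# Record where the checkpoint lives; the engine reloads the real weights from
# this path after sharding (see the next hunk).
pretrained_path = pretrained_utils.get_pretrained_path(model)
print(pretrained_path)
```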
@@ -189,14 +198,13 @@ class InferenceEngine:
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )
 
-        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
-        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
-        #     from colossalai.inference.core.plugin import InferCheckpoint_io
+        if pretrained_path:
+            from colossalai.inference.core.plugin import InferCheckpoint_io
 
-        #     cpt_io = InferCheckpoint_io()
-        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
-        #     assert if_has_index_file, "the model path is invalid"
-        #     cpt_io.load_model(self.model, model_index_file)
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(pretrained_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)
 
         free_gpu_memory, _ = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
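The weight restore added in this hunk can be sketched in isolation as well. This is a hedged, standalone version: `load_pretrained_weights` is a hypothetical wrapper name, `model` is assumed to be the already-sharded module, and `pretrained_path` is the path recorded during lazy init; `has_index_file` and `InferCheckpoint_io` are the helpers the diff itself imports.

```python
# Hedged sketch of the post-shard weight loading added above.
from colossalai.inference.core.plugin import InferCheckpoint_io
from colossalai.inference.utils import has_index_file


def load_pretrained_weights(model, pretrained_path):
    if not pretrained_path:
        # nn.Module input (or no recorded path): nothing to restore lazily.
        return
    cpt_io = InferCheckpoint_io()
    # Look for the checkpoint index file under the pretrained path.
    if_has_index_file, model_index_file = has_index_file(pretrained_path)
    assert if_has_index_file, "the model path is invalid"
    # Fill the lazily initialized, sharded parameters with the real weights.
    cpt_io.load_model(model, model_index_file)
```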
@@ -73,7 +73,9 @@ class RPCInferenceEngine(InferenceEngine):
 
         try:
             if isinstance(model_or_path, str):
-                self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                self.model_config = AutoConfig.from_pretrained(
+                    model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                )
             elif isinstance(model_or_path, nn.Module):
                 self.logger.error(
                     f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
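For context on the `torch_dtype` argument added here: `AutoConfig.from_pretrained` lets keyword arguments that match config attributes override the loaded values, so the requested dtype ends up recorded on the config that the RPC workers receive. A small check, assuming a recent transformers version and using "gpt2" purely as an example checkpoint:

```python
# Sketch: passing torch_dtype to AutoConfig.from_pretrained stores it on the
# returned config, so downstream workers can build the model in that precision.
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("gpt2", torch_dtype=torch.float16)
print(config.torch_dtype)  # expected: torch.float16
```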