From 74c47921facd26dbd93172bf887abcad4eab2d5c Mon Sep 17 00:00:00 2001
From: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Date: Tue, 14 May 2024 20:17:43 +0800
Subject: [PATCH] [Fix] Llama3 Load/Omit CheckpointIO Temporarily (#5717)

* Fix Llama3 Load error

* Omit Checkpoint IO Temporarily
---
 colossalai/inference/core/engine.py           | 26 ++++---
 colossalai/inference/executor/rpc_worker.py   | 32 +++++----
 .../modeling/models/nopadding_llama.py        | 69 ++++++++++---------
 3 files changed, 65 insertions(+), 62 deletions(-)

diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 7b456b8be..047d7d79f 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -24,7 +24,7 @@ from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.inference.utils import get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -113,18 +113,15 @@ class InferenceEngine:
             model_policy (Policy): the policy to replace the model
         """
 
-        casuallm = None
         if isinstance(model_or_path, str):
             try:
                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                 arch = getattr(hf_config, "architectures")[0]
                 if arch in _supported_models.keys():
-                    casuallm = _supported_models[arch](hf_config)
-                    if isinstance(casuallm, AutoModelForCausalLM):
-                        # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                        model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
-                    else:
-                        model = _supported_models[arch](hf_config)
+                    # NOTE(lry89757) Currently we load the model using transformers-api,
+                    # but we will use lazy tensor and checkpoint io to accelerate
+                    # the model load process in the future.
+                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
                 else:
                     raise ValueError(f"Model {arch} is not supported.")
 
@@ -175,13 +172,14 @@ class InferenceEngine:
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )
 
-        if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
-            from colossalai.inference.core.plugin import InferCheckpoint_io
+        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
+        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
+        #     from colossalai.inference.core.plugin import InferCheckpoint_io
 
-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(model_or_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
+        #     cpt_io = InferCheckpoint_io()
+        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
+        #     assert if_has_index_file, "the model path is invalid"
+        #     cpt_io.load_model(self.model, model_index_file)
 
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py
index 4b84dcc85..7d8350ac0 100644
--- a/colossalai/inference/executor/rpc_worker.py
+++ b/colossalai/inference/executor/rpc_worker.py
@@ -1,4 +1,3 @@
-import os
 from typing import List, Tuple, Union
 
 import rpyc
@@ -19,7 +18,7 @@ from colossalai.inference.modeling.policy import (
     model_policy_map,
 )
 from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.inference.utils import get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -178,15 +177,19 @@ class rpcWorkerService(rpyc.Service):
         """
 
         if isinstance(model_or_path, str):
-            is_local = os.path.isdir(model_or_path)
+            # is_local = os.path.isdir(model_or_path)
             try:
                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                 arch = getattr(hf_config, "architectures")[0]
-                if is_local:
-                    model = _SUPPORTED_MODELS[arch](hf_config)
-                else:
-                    # load the real checkpoint
-                    model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                # NOTE(lry89757) Currently we load the model using transformers-api,
+                # but we will use lazy tensor and checkpoint io to accelerate
+                # the model load process in the future.
+                model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                # if is_local:
+                #     model = _SUPPORTED_MODELS[arch](hf_config)
+                # else:
+                #     # load the real checkpoint
+                #     model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
             except Exception as e:
                 logger.error(
                     f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -235,13 +238,14 @@ class rpcWorkerService(rpyc.Service):
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )
 
-        if isinstance(model_or_path, str) and is_local:
-            from colossalai.inference.core.plugin import InferCheckpoint_io
+        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
+        # if isinstance(model_or_path, str) and is_local:
+        #     from colossalai.inference.core.plugin import InferCheckpoint_io
 
-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(model_or_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
+        #     cpt_io = InferCheckpoint_io()
+        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
+        #     assert if_has_index_file, "the model path is invalid"
+        #     cpt_io.load_model(self.model, model_index_file)
 
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 9e54b7e26..f6f160eb7 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -646,48 +646,49 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        if self.num_heads == self.num_key_value_heads:
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
 
-        key = "qkv_weight"
-        k1 = "q_proj.weight"
-        k2 = "k_proj.weight"
-        k3 = "v_proj.weight"
-        q_w = state_dict[prefix + k1]
-        k_w = state_dict[prefix + k2]
-        v_w = state_dict[prefix + k3]
+            key = "qkv_weight"
+            k1 = "q_proj.weight"
+            k2 = "k_proj.weight"
+            k3 = "v_proj.weight"
+            q_w = state_dict[prefix + k1]
+            k_w = state_dict[prefix + k2]
+            v_w = state_dict[prefix + k3]
 
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
+            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
+            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
 
-        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
 
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+            input_param = nn.Parameter(
+                qkv_w
+            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
 
-        param = local_state[key]
+            param = local_state[key]
 
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+            try:
+                with torch.no_grad():
+                    param.copy_(input_param)
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                )
 
-        strict = False  # to avoid unexpected_keys
+            strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
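
Note on the Llama3 load error addressed above: Llama3 uses grouped-query attention, so k_proj.weight and v_proj.weight have fewer output rows than q_proj.weight, and the fused-QKV trick torch.stack([q_w.T, k_w.T, v_w.T], dim=0) in NopadLlamaAttention._load_from_state_dict only works when all three weights share a shape. That is why the new num_heads == num_key_value_heads guard skips the fused path for GQA checkpoints. A minimal standalone sketch of the failure mode follows; the shapes are assumed from a typical Llama3-8B-style config (32 query heads, 8 key/value heads, head_dim 128) and are illustrative only, not taken from this patch.

    import torch

    # Assumed GQA shapes, for illustration only (not part of the patch).
    hidden_size, num_heads, num_kv_heads, head_dim = 4096, 32, 8, 128

    q_w = torch.empty(num_heads * head_dim, hidden_size)     # (4096, 4096)
    k_w = torch.empty(num_kv_heads * head_dim, hidden_size)  # (1024, 4096)
    v_w = torch.empty(num_kv_heads * head_dim, hidden_size)  # (1024, 4096)

    try:
        # Mirrors the fused-QKV stacking above: torch.stack requires equal
        # shapes, so it raises for GQA checkpoints such as Llama3.
        torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
    except RuntimeError as err:
        print(f"qkv stacking fails for GQA weights: {err}")

When num_heads equals num_key_value_heads (classic multi-head attention, as in earlier Llama checkpoints), all three weights are the same size and the stack succeeds, which is exactly the case the guard keeps on the fused path.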