diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index a1b54fa1c..f0918c88c 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -24,8 +24,9 @@ from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
 from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -122,16 +123,24 @@ class InferenceEngine:
             model_inference_config: the configuration for modeling initialization when inference.
             model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
         """
-
+        pretrained_path = None
         if isinstance(model_or_path, str):
+            import colossalai.interface.pretrained as pretrained_utils
+
             try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                 arch = getattr(hf_config, "architectures")[0]
                 if arch in _supported_models.keys():
-                    # NOTE(lry89757) Currently we load the model using transformers-api,
-                    # but we will use lazy tensor and checkpoint io to accelerate
-                    # the model load process in the future.
-                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                    if arch == "BaichuanForCausalLM":
+                        self.logger.warning(
+                            "Attention! We use lazy init by default, which can make model loading faster. For the Baichuan model, the output may differ slightly from transformers."
+                        )
+                    ctx = LazyInitContext(default_device="cuda")
+                    with ctx:
+                        model = _supported_models[arch].from_pretrained(
+                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                        )
+                    pretrained_path = pretrained_utils.get_pretrained_path(model)
                 else:
                     # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
                     raise ValueError(f"Model {arch} is not supported.")
@@ -189,14 +198,13 @@ class InferenceEngine:
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )

-        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
-        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
-        #     from colossalai.inference.core.plugin import InferCheckpoint_io
+        if pretrained_path:
+            from colossalai.inference.core.plugin import InferCheckpoint_io

-        #     cpt_io = InferCheckpoint_io()
-        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
-        #     assert if_has_index_file, "the model path is invalid"
-        #     cpt_io.load_model(self.model, model_index_file)
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(pretrained_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)

         free_gpu_memory, _ = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py
index 439c4b0b5..87222a744 100644
--- a/colossalai/inference/core/rpc_engine.py
+++ b/colossalai/inference/core/rpc_engine.py
@@ -73,7 +73,9 @@ class RPCInferenceEngine(InferenceEngine):

         try:
             if isinstance(model_or_path, str):
-                self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                self.model_config = AutoConfig.from_pretrained(
+                    model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                )
             elif isinstance(model_or_path, nn.Module):
                 self.logger.error(
                     f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
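
Note (illustration, not part of the patch): the engine change above replaces the eager transformers load with lazy init plus index-file based checkpoint loading. A minimal sketch of that flow, assuming a local HF-style checkpoint directory with a *.index.json weight map; the path and the AutoModelForCausalLM entry point are stand-ins, and the engine actually picks the class from _supported_models and shards the model before loading:

    import colossalai.interface.pretrained as pretrained_utils
    from transformers import AutoModelForCausalLM

    from colossalai.inference.core.plugin import InferCheckpoint_io
    from colossalai.inference.utils import has_index_file
    from colossalai.lazy import LazyInitContext

    model_or_path = "/path/to/local/llama/checkpoint"  # hypothetical path

    # Build the model skeleton lazily: no real weights are materialized yet.
    with LazyInitContext(default_device="cuda"):
        model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True)
    pretrained_path = pretrained_utils.get_pretrained_path(model)

    # Stream the real weights in through the sharded-checkpoint index
    # (the engine does this only after shardformer has sharded the model).
    if pretrained_path:
        cpt_io = InferCheckpoint_io()
        if_has_index_file, model_index_file = has_index_file(pretrained_path)
        assert if_has_index_file, "the model path is invalid"
        cpt_io.load_model(model, model_index_file)
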
diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py
index 913b8667d..a5199cb74 100644
--- a/colossalai/inference/executor/rpc_worker.py
+++ b/colossalai/inference/executor/rpc_worker.py
@@ -18,8 +18,9 @@ from colossalai.inference.modeling.policy import (
     model_policy_map,
 )
 from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
 from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -178,20 +179,23 @@ class rpcWorkerService(rpyc.Service):
             model_policy (Policy): the policy to replace the model
         """

+        pretrained_path = None
        if isinstance(model_or_path, str):
-            # is_local = os.path.isdir(model_or_path)
+            import colossalai.interface.pretrained as pretrained_utils
+
             try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                 arch = getattr(hf_config, "architectures")[0]
-                # NOTE(lry89757) Currently we load the model using transformers-api,
-                # but we will use lazy tensor and checkpoint io to accelerate
-                # the model load process in the future.
-                model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
-                # if is_local:
-                #     model = _SUPPORTED_MODELS[arch](hf_config)
-                # else:
-                #     # load the real checkpoint
-                #     model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                if arch == "BaichuanForCausalLM":
+                    self.logger.warning(
+                        "Attention! We use lazy init by default, which can make model loading faster. For the Baichuan model, the output may differ slightly from transformers."
+                    )
+                ctx = LazyInitContext(default_device="cuda")
+                with ctx:
+                    model = _SUPPORTED_MODELS[arch].from_pretrained(
+                        model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                    )
+                pretrained_path = pretrained_utils.get_pretrained_path(model)
             except Exception as e:
                 logger.error(
                     f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -240,14 +244,13 @@ class rpcWorkerService(rpyc.Service):
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )

-        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
-        # if isinstance(model_or_path, str) and is_local:
-        #     from colossalai.inference.core.plugin import InferCheckpoint_io
+        if pretrained_path:
+            from colossalai.inference.core.plugin import InferCheckpoint_io

-        #     cpt_io = InferCheckpoint_io()
-        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
-        #     assert if_has_index_file, "the model path is invalid"
-        #     cpt_io.load_model(self.model, model_index_file)
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(pretrained_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)

         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
diff --git a/colossalai/inference/modeling/layers/baichuan_tp_linear.py b/colossalai/inference/modeling/layers/baichuan_tp_linear.py
index 50806a14b..75260f59b 100644
--- a/colossalai/inference/modeling/layers/baichuan_tp_linear.py
+++ b/colossalai/inference/modeling/layers/baichuan_tp_linear.py
@@ -1,8 +1,10 @@
 from typing import List, Union

+import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup

+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Linear1D_Col
 from colossalai.shardformer.layer.parallel_module import ParallelModule

@@ -12,17 +14,51 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
     def from_native_module(
         module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
     ) -> ParallelModule:
+        LazyInitContext.materialize(module)
         module.in_features = module.weight.size(1)
         module.out_features = module.weight.size(0)
         module.bias = None
         module.weight.data = nn.functional.normalize(
             module.weight
-        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        )  # NOTE(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
         # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.

-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+            )
+
+        lmhead_1d = BaichuanLMHeadLinear1D_Col(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            process_group=process_group,
+            weight=module.weight,
+            bias_=module.bias,
             **kwargs,
         )
+
+        return lmhead_1d
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"])
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index 3bab671c4..dfc53d9f6 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule):
             attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
         """
         ParallelModule.__init__(self)
-        self.o_proj = attn_oproj

         self.config = config
         self.num_heads = num_heads
@@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
         self.head_dim = self.hidden_size // self.num_heads
         self.process_group = process_group
         self.W_pack = W_pack
+        self.o_proj = attn_oproj
         self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
         self.attention_backend = get_attention_backend(model_shard_infer_config)
         self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
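
Note (illustration, not part of the patch): the new _load_from_state_dict override on BaichuanLMHeadLinear1D_Col is what keeps the lm-head normalization correct under lazy init: the weight seen in from_native_module may only be a placeholder, so the real checkpoint tensor is normalized again at load time. The idea in isolation, on a plain nn.Linear:

    import torch
    import torch.nn as nn

    head = nn.Linear(8, 16, bias=False)
    state_dict = {"weight": torch.randn(16, 8)}

    # Normalizing the checkpoint tensor right before it is copied into the module
    # has the same effect as normalizing a fully materialized weight up front.
    state_dict["weight"] = nn.functional.normalize(state_dict["weight"])
    head.load_state_dict(state_dict)
    assert torch.allclose(head.weight.norm(dim=1), torch.ones(16))
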
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 445ec59ce..c7c7473ac 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
             self.gate_up_weight = nn.Parameter(
                 torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
             )
+            self.gate_up_dict = {
+                "gate_proj.weight": None,
+                "up_proj.weight": None,
+            }  # used and delattr in load/shard of gate/up weight
             self.down_proj = mlp_dproj
             self.process_group = process_group

@@ -321,44 +325,47 @@
     ):
         # NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        if hasattr(self, "gate_up_dict"):
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}

-        key = "gate_up_weight"
-        k1 = "gate_proj.weight"
-        k2 = "up_proj.weight"
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            for weight_name in self.gate_up_dict:
+                prefix_weight_name = prefix + weight_name
+                if prefix_weight_name in state_dict.keys():
+                    w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+                    self.gate_up_dict[weight_name] = w.T

-        gate_w = state_dict[prefix + k1]
-        up_w = state_dict[prefix + k2]
+            if None not in self.gate_up_dict.values():
+                # we've got all the weights of gate/up
+                gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0)

-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
-        up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
+                input_param = nn.Parameter(
+                    gate_up_w
+                )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)

-        gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
+                key = "gate_up_weight"
+                param = local_state.get(key, None)

-        input_param = nn.Parameter(
-            gate_up_w
-        )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-        param = local_state[key]
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )

-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+            del self.gate_up_dict

-        strict = False  # to avoid unexpected_keys
+            strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
@@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             self.helper_layout = (
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
+            self.qkv_dict = {
+                "q_proj.weight": None,
+                "k_proj.weight": None,
+                "v_proj.weight": None,
+            }  # used and delattr in load/shard of qkv weight
         else:
+            self.helper_layout = (
+                attn_qproj_w.dist_layout
+            )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
             self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
             self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
             self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
@@ -577,49 +592,83 @@
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        if self.num_heads == self.num_key_value_heads:
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        device_mesh = self.helper_layout.device_mesh
+        sharding_spec = self.helper_layout.sharding_spec
+
+        if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"):
             # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-            for hook in self._load_state_dict_pre_hooks.values():
-                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-            local_state = {k: v for k, v in local_name_params if v is not None}

             key = "qkv_weight"
-            k1 = "q_proj.weight"
-            k2 = "k_proj.weight"
-            k3 = "v_proj.weight"
-            q_w = state_dict[prefix + k1]
-            k_w = state_dict[prefix + k2]
-            v_w = state_dict[prefix + k3]
-            device_mesh = self.helper_layout.device_mesh
-            sharding_spec = self.helper_layout.sharding_spec
-            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+            # NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json
+            # Here we need the weights of q, k and v to stack them into one qkv weight.
+            # Unfortunately, it is highly likely that the q, k and v weights do not all live in the same sharded checkpoint file (e.g. meta-llama/llama3-70B),
+            # so we only stack them once all three weights have really been collected.
+            for weight_name in self.qkv_dict:
+                prefix_weight_name = prefix + weight_name
+                if prefix_weight_name in state_dict.keys():
+                    w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+                    self.qkv_dict[weight_name] = w.T

-            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+            if None not in self.qkv_dict.values():
+                # we've got all the weights of q, k, v
+                qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0)

-            input_param = nn.Parameter(
-                qkv_w
-            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+                input_param = nn.Parameter(
+                    qkv_w
+                )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)

-            param = local_state[key]
+                param = local_state[key]

-            try:
-                with torch.no_grad():
-                    param.copy_(input_param)
-            except Exception as ex:
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "whose dimensions in the model are {} and "
-                    "whose dimensions in the checkpoint are {}, "
-                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-                )
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )

-            strict = False  # to avoid unexpected_keys
+                del self.qkv_dict
+
+        else:
+
+            def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"):
+                if prefix + origin_weight_name in state_dict.keys():
+                    attn_qproj_w = state_dict[prefix + origin_weight_name]
+                    w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec)
+                    input_param = nn.Parameter(w.T)
+                    param = local_state[local_weight_name]
+                    try:
+                        with torch.no_grad():
+                            param.copy_(input_param)
+                    except Exception as ex:
+                        key = local_weight_name
+                        error_msgs.append(
+                            'While copying the parameter named "{}", '
+                            "whose dimensions in the model are {} and "
+                            "whose dimensions in the checkpoint are {}, "
+                            "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                        )
+
+            if prefix + "q_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight")
+
+            if prefix + "k_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight")
+
+            if prefix + "v_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight")
+
+        strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
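
Note (illustration, not part of the patch): the qkv_dict / gate_up_dict staging above exists because a sharded checkpoint is loaded file by file, so q/k/v (or gate/up) may arrive in different _load_from_state_dict calls. A stripped-down version of the accumulate-then-stack pattern, with distribute_tensor and the helper layout left out:

    import torch

    class FusedQKVLoader:
        """Collects q/k/v projections that may arrive from different shard files."""

        def __init__(self):
            self.qkv_dict = {"q_proj.weight": None, "k_proj.weight": None, "v_proj.weight": None}
            self.qkv_weight = None

        def consume(self, partial_state_dict):
            # Stash whichever projections this shard happens to contain.
            for name in self.qkv_dict:
                if name in partial_state_dict:
                    self.qkv_dict[name] = partial_state_dict[name].T
            # Only fuse once all three projections have been collected.
            if None not in self.qkv_dict.values():
                self.qkv_weight = torch.stack(list(self.qkv_dict.values()), dim=0)
                del self.qkv_dict  # same cleanup the patch performs after stacking

    loader = FusedQKVLoader()
    loader.consume({"q_proj.weight": torch.randn(32, 32)})  # first shard file
    loader.consume({"k_proj.weight": torch.randn(32, 32), "v_proj.weight": torch.randn(32, 32)})  # second shard file
    assert loader.qkv_weight.shape == (3, 32, 32)
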
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index dc3634238..0f6595a7c 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -674,6 +674,8 @@ class FusedLinear1D_Col(ParallelModule):
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
             n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
         """
+        LazyInitContext.materialize(module)
+
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features
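
Note (illustration, not part of the patch): the from_native_module changes in this patch (here and in baichuan_tp_linear.py) guard against lazily initialized inputs by calling LazyInitContext.materialize(module) before touching module.weight. A minimal sketch of that pattern, assuming a CUDA device is available:

    import torch.nn as nn

    from colossalai.lazy import LazyInitContext

    with LazyInitContext(default_device="cuda"):
        lm_head = nn.Linear(512, 1024, bias=False)  # weight is a lazy placeholder here

    LazyInitContext.materialize(lm_head)  # now lm_head.weight is a real tensor on the default device
    print(lm_head.weight.size())
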