mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[Inference]Lazy Init Support (#5785)
* lazy init support * lazy init llama support * :lazy init support for baichuan * aligh rpc * add note for baichuan --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule):
|
||||
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
|
||||
"""
|
||||
ParallelModule.__init__(self)
|
||||
self.o_proj = attn_oproj
|
||||
|
||||
self.config = config
|
||||
self.num_heads = num_heads
|
||||
@@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.process_group = process_group
|
||||
self.W_pack = W_pack
|
||||
self.o_proj = attn_oproj
|
||||
self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
|
||||
self.attention_backend = get_attention_backend(model_shard_infer_config)
|
||||
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
|
||||
|
@@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
|
||||
self.gate_up_weight = nn.Parameter(
|
||||
torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
|
||||
)
|
||||
self.gate_up_dict = {
|
||||
"gate_proj.weight": None,
|
||||
"up_proj.weight": None,
|
||||
} # used and delattr in load/shard of gate/up weight
|
||||
self.down_proj = mlp_dproj
|
||||
self.process_group = process_group
|
||||
|
||||
@@ -321,44 +325,47 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
|
||||
):
|
||||
# NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)
|
||||
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
if hasattr(self, "gate_up_dict"):
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
key = "gate_up_weight"
|
||||
k1 = "gate_proj.weight"
|
||||
k2 = "up_proj.weight"
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
for weight_name in self.gate_up_dict:
|
||||
prefix_weight_name = prefix + weight_name
|
||||
if prefix_weight_name in state_dict.keys():
|
||||
w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
|
||||
self.gate_up_dict[weight_name] = w.T
|
||||
|
||||
gate_w = state_dict[prefix + k1]
|
||||
up_w = state_dict[prefix + k2]
|
||||
if None not in self.gate_up_dict.values():
|
||||
# we've got all the weights of gate/up
|
||||
gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0)
|
||||
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
|
||||
up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
|
||||
input_param = nn.Parameter(
|
||||
gate_up_w
|
||||
) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
|
||||
gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
|
||||
key = "gate_up_weight"
|
||||
param = local_state.get(key, None)
|
||||
|
||||
input_param = nn.Parameter(
|
||||
gate_up_w
|
||||
) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
param = local_state[key]
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
del self.gate_up_dict
|
||||
|
||||
strict = False # to avoid unexpected_keys
|
||||
strict = False # to avoid unexpected_keys
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
@@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
||||
self.helper_layout = (
|
||||
attn_qproj_w.dist_layout
|
||||
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
|
||||
self.qkv_dict = {
|
||||
"q_proj.weight": None,
|
||||
"k_proj.weight": None,
|
||||
"v_proj.weight": None,
|
||||
} # used and delattr in load/shard of qkv weight
|
||||
else:
|
||||
self.helper_layout = (
|
||||
attn_qproj_w.dist_layout
|
||||
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
|
||||
self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
|
||||
self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
|
||||
self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
|
||||
@@ -577,49 +592,83 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
if self.num_heads == self.num_key_value_heads:
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
|
||||
if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"):
|
||||
# NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
key = "qkv_weight"
|
||||
k1 = "q_proj.weight"
|
||||
k2 = "k_proj.weight"
|
||||
k3 = "v_proj.weight"
|
||||
q_w = state_dict[prefix + k1]
|
||||
k_w = state_dict[prefix + k2]
|
||||
v_w = state_dict[prefix + k3]
|
||||
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
|
||||
k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
|
||||
v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
|
||||
# NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json
|
||||
# Here we need the weight of q,k,v to stack the weights of q,k,v into one qkv weight.
|
||||
# Unfortunately, it is highly like that all weights of q,k,v are not in the same sharded checkpoint file(like meta-llama/llama3-70B)
|
||||
# so here we will stack them when we really collect all the three weights.
|
||||
for weight_name in self.qkv_dict:
|
||||
prefix_weight_name = prefix + weight_name
|
||||
if prefix_weight_name in state_dict.keys():
|
||||
w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
|
||||
self.qkv_dict[weight_name] = w.T
|
||||
|
||||
qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
|
||||
if None not in self.qkv_dict.values():
|
||||
# we've got all the weights of q, k, v
|
||||
qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0)
|
||||
|
||||
input_param = nn.Parameter(
|
||||
qkv_w
|
||||
) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
input_param = nn.Parameter(
|
||||
qkv_w
|
||||
) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
|
||||
param = local_state[key]
|
||||
param = local_state[key]
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
|
||||
strict = False # to avoid unexpected_keys
|
||||
del self.qkv_dict
|
||||
|
||||
else:
|
||||
|
||||
def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"):
|
||||
if prefix + origin_weight_name in state_dict.keys():
|
||||
attn_qproj_w = state_dict[prefix + origin_weight_name]
|
||||
w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec)
|
||||
input_param = nn.Parameter(w.T)
|
||||
param = local_state[local_weight_name]
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
key = local_weight_name
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
|
||||
if prefix + "q_proj.weight" in state_dict.keys():
|
||||
_load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight")
|
||||
|
||||
if prefix + "k_proj.weight" in state_dict.keys():
|
||||
_load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight")
|
||||
|
||||
if prefix + "v_proj.weight" in state_dict.keys():
|
||||
_load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight")
|
||||
|
||||
strict = False # to avoid unexpected_keys
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
Reference in New Issue
Block a user