Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[Fix] Llama3 Load/Omit CheckpointIO Temporarily (#5717)
* Fix Llama3 Load error
* Omit Checkpoint IO Temporarily
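Background for the first fix (a hedged illustration, not part of the commit): Llama3 uses grouped-query attention, so its k_proj and v_proj weights are smaller than q_proj, while the fused load in NopadLlamaAttention builds qkv_weight with torch.stack([q_w.T, k_w.T, v_w.T], dim=0), which requires all three tensors to share one shape. The toy sketch below (made-up dimensions, hypothetical fuse_qkv helper) shows why the stack works for standard multi-head attention but fails under GQA, which is exactly the condition the added num_heads == num_key_value_heads guard checks.

import torch

def fuse_qkv(num_heads, num_key_value_heads, head_dim=4, hidden_size=16):
    # Build projection weights shaped like HF Llama checkpoints: (out_features, hidden_size).
    q_w = torch.randn(num_heads * head_dim, hidden_size)
    k_w = torch.randn(num_key_value_heads * head_dim, hidden_size)
    v_w = torch.randn(num_key_value_heads * head_dim, hidden_size)
    # Same fusion as the patch: stack the transposed projections along a new dim 0.
    return torch.stack([q_w.T, k_w.T, v_w.T], dim=0)

# MHA-style config (query heads == KV heads): the fused load works.
print(fuse_qkv(num_heads=4, num_key_value_heads=4).shape)  # torch.Size([3, 16, 16])

# GQA config as in Llama3 (more query heads than KV heads): the stack raises,
# which is why the patch guards the fused path with num_heads == num_key_value_heads.
try:
    fuse_qkv(num_heads=4, num_key_value_heads=1)
except RuntimeError as err:
    print("fused qkv load fails under GQA:", err)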
@@ -646,48 +646,49 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        if self.num_heads == self.num_key_value_heads:
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
 
-        key = "qkv_weight"
-        k1 = "q_proj.weight"
-        k2 = "k_proj.weight"
-        k3 = "v_proj.weight"
-        q_w = state_dict[prefix + k1]
-        k_w = state_dict[prefix + k2]
-        v_w = state_dict[prefix + k3]
+            key = "qkv_weight"
+            k1 = "q_proj.weight"
+            k2 = "k_proj.weight"
+            k3 = "v_proj.weight"
+            q_w = state_dict[prefix + k1]
+            k_w = state_dict[prefix + k2]
+            v_w = state_dict[prefix + k3]
 
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
+            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
+            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
 
-        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
 
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+            input_param = nn.Parameter(
+                qkv_w
+            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
 
-        param = local_state[key]
+            param = local_state[key]
 
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+            try:
+                with torch.no_grad():
+                    param.copy_(input_param)
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                )
 
-        strict = False  # to avoid unexpected_keys
+            strict = False  # to avoid unexpected_keys
 
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
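For reference, a minimal, self-contained sketch of the loading technique the hunk relies on: overriding nn.Module._load_from_state_dict so that separate q/k/v checkpoint keys are fused into a single qkv parameter. The toy module, key names, and sizes below are hypothetical, and the distributed-tensor handling from the real code (distribute_tensor, helper_layout) is omitted.

import torch
import torch.nn as nn

class ToyFusedQKV(nn.Module):
    # Hypothetical module: holds one fused (3, hidden, hidden) qkv parameter,
    # loaded from a checkpoint that stores q/k/v projections separately.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.qkv_weight = nn.Parameter(torch.empty(3, hidden_size, hidden_size))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        q = state_dict.pop(prefix + "q_proj.weight", None)
        k = state_dict.pop(prefix + "k_proj.weight", None)
        v = state_dict.pop(prefix + "v_proj.weight", None)
        if q is not None and k is not None and v is not None and q.shape == k.shape == v.shape:
            with torch.no_grad():
                # Same fusion as in the hunk: stack the transposed projections along dim 0.
                self.qkv_weight.copy_(torch.stack([q.T, k.T, v.T], dim=0))
        else:
            missing_keys.append(prefix + "qkv_weight")
        # Let the default loader handle everything else; passing strict=False here keeps it
        # from reporting qkv_weight as missing, since the checkpoint never contained that key.
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs
        )

hidden = 8
module = ToyFusedQKV(hidden)
checkpoint = {
    "q_proj.weight": torch.randn(hidden, hidden),
    "k_proj.weight": torch.randn(hidden, hidden),
    "v_proj.weight": torch.randn(hidden, hidden),
}
module.load_state_dict(checkpoint, strict=False)
print(module.qkv_weight.shape)  # torch.Size([3, 8, 8])

Relaxing strictness in the parent call plays the same role as the `strict = False  # to avoid unexpected_keys` line in the hunk: the fused parameter has no one-to-one counterpart among the original q/k/v checkpoint keys, so strict bookkeeping would otherwise flag spurious missing or unexpected entries.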