[Inference] Lazy Init Support (#5785)

* lazy init support

* lazy init llama support

* lazy init support for baichuan

* align rpc

* add note for baichuan

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Runyu Lu
Date: 2024-06-27 18:02:15 +08:00
Committed by: GitHub
Parent: d9d5e7ea1f
Commit: 3c7cda0c9a
7 changed files with 205 additions and 105 deletions
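
Background for the change: under lazy initialization a module's parameters start out as placeholders and only become real tensors once the module is materialized or the checkpoint is loaded, so any weight rewrite done at policy-conversion time can act on a tensor that does not yet hold checkpoint values. A minimal sketch of the pattern, using only the LazyInitContext API that appears in this diff (the context manager and LazyInitContext.materialize); the toy two-layer model is illustrative and not part of this PR:

import torch.nn as nn

from colossalai.lazy import LazyInitContext

# Build the model lazily: parameters are created as lazy placeholders instead of
# being allocated and initialized eagerly.
with LazyInitContext():
    toy_model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))  # hypothetical toy model

# Option 1: materialize before touching weights (what from_native_module does below).
LazyInitContext.materialize(toy_model)

# Option 2: leave the weights lazy and defer any weight transformation to
# _load_from_state_dict, which is the approach this commit takes for the Baichuan
# LM head and the fused Llama qkv/gate_up parameters.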


@@ -1,8 +1,10 @@
 from typing import List, Union
 
+import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Linear1D_Col
 from colossalai.shardformer.layer.parallel_module import ParallelModule
@@ -12,17 +14,51 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
     def from_native_module(
         module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
     ) -> ParallelModule:
+        LazyInitContext.materialize(module)
         module.in_features = module.weight.size(1)
         module.out_features = module.weight.size(0)
         module.bias = None
         module.weight.data = nn.functional.normalize(
             module.weight
-        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        )  # NOTE(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
         # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
 
-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
-            **kwargs,
-        )
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+
+        lmhead_1d = BaichuanLMHeadLinear1D_Col(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            process_group=process_group,
+            weight=module.weight,
+            bias_=module.bias,
+            **kwargs,
+        )
+
+        return lmhead_1d
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"])
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
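
The point of the new _load_from_state_dict above: with lazy init, module.weight inside from_native_module may not be the real checkpoint weight, so normalizing it there is unreliable; normalizing the tensor coming out of the checkpoint at load time always acts on real data. A self-contained sketch of that idea in plain PyTorch (NormalizedLinear is a hypothetical stand-in, not a class from this PR):

import torch
import torch.nn as nn


class NormalizedLinear(nn.Linear):
    # Toy stand-in for BaichuanLMHeadLinear1D_Col: the weight must be L2-normalized,
    # and the normalization is applied to the checkpoint tensor at load time rather
    # than to the (possibly lazy/placeholder) module weight at construction time.
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        weight_key = prefix + "weight"
        if weight_key in state_dict:
            state_dict[weight_key] = nn.functional.normalize(state_dict[weight_key])
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )


head = NormalizedLinear(8, 16, bias=False)
ckpt = {"weight": torch.randn(16, 8)}
head.load_state_dict(ckpt)
assert torch.allclose(head.weight, nn.functional.normalize(ckpt["weight"]))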


@@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule):
             attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
         """
         ParallelModule.__init__(self)
-        self.o_proj = attn_oproj
         self.config = config
         self.num_heads = num_heads
@@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
         self.head_dim = self.hidden_size // self.num_heads
         self.process_group = process_group
         self.W_pack = W_pack
+        self.o_proj = attn_oproj
         self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
         self.attention_backend = get_attention_backend(model_shard_infer_config)
         self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)


@@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
         self.gate_up_weight = nn.Parameter(
             torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
         )
+        self.gate_up_dict = {
+            "gate_proj.weight": None,
+            "up_proj.weight": None,
+        }  # used and delattr in load/shard of gate/up weight
         self.down_proj = mlp_dproj
         self.process_group = process_group
@@ -321,44 +325,47 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
     ):
         # NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "gate_up_weight"
-        k1 = "gate_proj.weight"
-        k2 = "up_proj.weight"
-
-        gate_w = state_dict[prefix + k1]
-        up_w = state_dict[prefix + k2]
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
-        up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
-
-        gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
-
-        input_param = nn.Parameter(
-            gate_up_w
-        )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
-
-        strict = False  # to avoid unexpected_keys
+        if hasattr(self, "gate_up_dict"):
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
+
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            for weight_name in self.gate_up_dict:
+                prefix_weight_name = prefix + weight_name
+                if prefix_weight_name in state_dict.keys():
+                    w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+                    self.gate_up_dict[weight_name] = w.T
+
+            if None not in self.gate_up_dict.values():
+                # we've got all the weights of gate/up
+                gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0)
+
+                input_param = nn.Parameter(
+                    gate_up_w
+                )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+                key = "gate_up_weight"
+                param = local_state.get(key, None)
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
+
+                del self.gate_up_dict
+
+            strict = False  # to avoid unexpected_keys
 
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
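
What the rewritten gate/up loading above changes: instead of assuming gate_proj.weight and up_proj.weight are both present in the state_dict of a single load call, each call stores whichever piece it sees in gate_up_dict, and the fused gate_up_weight is only assembled (and the dict deleted) once both pieces have arrived. A stripped-down sketch of that collect-then-stack pattern without the tensor-parallel parts (no distribute_tensor or helper_layout); FusedGateUp and the sizes are illustrative only:

import torch
import torch.nn as nn


class FusedGateUp(nn.Module):
    # Toy stand-in for NopadLlamaMLP's fused gate_up_weight: the parameter is a stack
    # of two checkpoint weights that may live in different shard files.
    def __init__(self, hidden: int = 8, inter: int = 16):
        super().__init__()
        self.gate_up_weight = nn.Parameter(torch.empty(2, hidden, inter))
        # pending transposed weights, keyed by their checkpoint names
        self.gate_up_dict = {"gate_proj.weight": None, "up_proj.weight": None}

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        if hasattr(self, "gate_up_dict"):
            for name in self.gate_up_dict:
                if prefix + name in state_dict:
                    self.gate_up_dict[name] = state_dict[prefix + name].T
            if None not in self.gate_up_dict.values():
                with torch.no_grad():
                    self.gate_up_weight.copy_(torch.stack(list(self.gate_up_dict.values()), dim=0))
                del self.gate_up_dict  # all pieces consumed; later loads skip this branch
        # pass strict=False down: the raw gate_proj/up_proj keys never match a local param
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs
        )


m = FusedGateUp()
# the two weights arrive in two separate "shard" loads
m.load_state_dict({"gate_proj.weight": torch.randn(16, 8)}, strict=False)
m.load_state_dict({"up_proj.weight": torch.randn(16, 8)}, strict=False)
assert not hasattr(m, "gate_up_dict")  # fused weight was assembled on the second load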
@@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             self.helper_layout = (
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
+            self.qkv_dict = {
+                "q_proj.weight": None,
+                "k_proj.weight": None,
+                "v_proj.weight": None,
+            }  # used and delattr in load/shard of qkv weight
         else:
+            self.helper_layout = (
+                attn_qproj_w.dist_layout
+            )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
             self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
             self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
             self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
@@ -577,49 +592,83 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        if self.num_heads == self.num_key_value_heads:
-            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-            for hook in self._load_state_dict_pre_hooks.values():
-                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-            local_state = {k: v for k, v in local_name_params if v is not None}
-
-            key = "qkv_weight"
-            k1 = "q_proj.weight"
-            k2 = "k_proj.weight"
-            k3 = "v_proj.weight"
-            q_w = state_dict[prefix + k1]
-            k_w = state_dict[prefix + k2]
-            v_w = state_dict[prefix + k3]
-
-            device_mesh = self.helper_layout.device_mesh
-            sharding_spec = self.helper_layout.sharding_spec
-            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
-
-            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
-
-            input_param = nn.Parameter(
-                qkv_w
-            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-
-            param = local_state[key]
-
-            try:
-                with torch.no_grad():
-                    param.copy_(input_param)
-            except Exception as ex:
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "whose dimensions in the model are {} and "
-                    "whose dimensions in the checkpoint are {}, "
-                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-                )
-
-            strict = False  # to avoid unexpected_keys
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        device_mesh = self.helper_layout.device_mesh
+        sharding_spec = self.helper_layout.sharding_spec
+        if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"):
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            key = "qkv_weight"
+
+            # NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json
+            # Here we need the weight of q,k,v to stack the weights of q,k,v into one qkv weight.
+            # Unfortunately, it is highly like that all weights of q,k,v are not in the same sharded checkpoint file(like meta-llama/llama3-70B)
+            # so here we will stack them when we really collect all the three weights.
+            for weight_name in self.qkv_dict:
+                prefix_weight_name = prefix + weight_name
+                if prefix_weight_name in state_dict.keys():
+                    w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+                    self.qkv_dict[weight_name] = w.T
+
+            if None not in self.qkv_dict.values():
+                # we've got all the weights of q, k, v
+                qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0)
+
+                input_param = nn.Parameter(
+                    qkv_w
+                )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+                param = local_state[key]
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
+
+                del self.qkv_dict
+
+        else:
+
+            def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"):
+                if prefix + origin_weight_name in state_dict.keys():
+                    attn_qproj_w = state_dict[prefix + origin_weight_name]
+                    w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec)
+                    input_param = nn.Parameter(w.T)
+                    param = local_state[local_weight_name]
+                    try:
+                        with torch.no_grad():
+                            param.copy_(input_param)
+                    except Exception as ex:
+                        key = local_weight_name
+                        error_msgs.append(
+                            'While copying the parameter named "{}", '
+                            "whose dimensions in the model are {} and "
+                            "whose dimensions in the checkpoint are {}, "
+                            "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                        )
+
+            if prefix + "q_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight")
+            if prefix + "k_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight")
+            if prefix + "v_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight")
+
+        strict = False  # to avoid unexpected_keys
 
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
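
The NOTE(@lry89757) block above is the constraint that forces the qkv_dict design: with a sharded checkpoint, q_proj/k_proj/v_proj of one layer can sit in different files, each file is loaded in its own pass, and the fused qkv_weight can only be stacked once all three pieces have been seen. A hedged sketch of the driving loop such a loader implies, assuming a Hugging Face-style *.index.json weight map and safetensors shards (the file layout and the helper name are illustrative, not ColossalAI API):

import json

import torch
from safetensors.torch import load_file


def load_sharded_checkpoint(model: torch.nn.Module, index_json_path: str, shard_dir: str) -> None:
    # Feed each shard to the model separately; modules like NopadLlamaAttention buffer
    # their q/k/v pieces across calls and stack them once all three have arrived.
    with open(index_json_path) as f:
        index = json.load(f)

    # weight_map looks like:
    # {"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00030.safetensors", ...}
    shard_files = sorted(set(index["weight_map"].values()))

    for shard_name in shard_files:
        shard_state = load_file(f"{shard_dir}/{shard_name}")
        # strict=False: fused params (qkv_weight, gate_up_weight) are missing from every
        # individual shard, and raw q/k/v keys look "unexpected" to the fused modules.
        model.load_state_dict(shard_state, strict=False)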