mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)
* fix bug
* fix
* fix multiquery
* fix multiquery

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
@@ -77,14 +77,15 @@ class TPInferEngine:
        )
        self.layer_num = num_hidden_layers

        self.multi_query_group_num = 0
        self.multi_query_group_num = model.config.num_attention_heads
        # default to attention_heads
        self.multi_query_attention = model.config.multi_query_attention

        if hasattr(model.config, "multi_query_group_num"):
            self.multi_query_group_num = model.config.multi_query_group_num

        if hasattr(model.config, "num_key_value_heads"):
            self.multi_query_group_num = model.config.num_key_value_heads

        self.tp_size = -1  # to be set with given shard config in self.prepare_shard_config
        self.cache_manager = None
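This hunk replaces the hard-coded `self.multi_query_group_num = 0` default with values read from the model config and records whether the model uses multi-query attention. A minimal sketch of the resulting resolution order, written as a hypothetical standalone helper (not part of the engine):

def resolve_kv_group_num(config) -> int:
    # Default to the number of attention heads (plain multi-head attention),
    # then override with model-specific fields when they exist.
    group_num = config.num_attention_heads
    if hasattr(config, "multi_query_group_num"):  # ChatGLM2-style grouped KV heads
        group_num = config.multi_query_group_num
    if hasattr(config, "num_key_value_heads"):    # LLaMA-2-style grouped KV heads
        group_num = config.num_key_value_heads
    return group_num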
@@ -107,7 +108,7 @@ class TPInferEngine:
        assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
        self.head_num //= self.tp_size  # update sharded number of heads

        if self.multi_query_group_num:
        if self.multi_query_attention:
            # NOTE the logic of MQA tensor parallelism should be specified.
            assert (
                self.multi_query_group_num % self.tp_size == 0
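With `multi_query_group_num` now defaulting to the attention-head count (always non-zero), the guard in the sharding path switches to the explicit `multi_query_attention` flag. A rough sketch of the intended check; the second assertion message and the final division of the group count are assumptions, not shown in this hunk:

def shard_head_counts(head_num: int, kv_group_num: int, tp_size: int, multi_query_attention: bool):
    # Query heads are always split across tensor-parallel ranks.
    assert head_num % tp_size == 0, f"Cannot shard {head_num} heads with tp size {tp_size}"
    head_num //= tp_size
    # KV head groups are only split when the model actually uses multi-query attention.
    if multi_query_attention:
        assert kv_group_num % tp_size == 0, f"Cannot shard {kv_group_num} KV groups with tp size {tp_size}"
        kv_group_num //= tp_size
    return head_num, kv_group_num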
@@ -395,9 +395,9 @@ class ChatGLM2InferenceForwards:
        assert use_cache is True, "use_cache should be set to True using this chatglm attention"
        # hidden_states: original :[sq, b, h] --> this [b, sq, h]
        batch_size = hidden_states.shape[0]
        hidden_size = hidden_states.shape[-1]
        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
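For multi-query attention the fused `query_key_value` projection packs the query heads and a smaller set of key/value groups into one last dimension, so the split sizes differ per component. A minimal sketch of that split, using hypothetical per-partition attribute names modeled on the ones visible later in the diff:

import torch

def split_qkv_mqa(mixed_x_layer: torch.Tensor, num_heads_per_partition: int,
                  num_kv_groups_per_partition: int, head_dim: int):
    # mixed_x_layer: [b, sq, (np + 2 * ng) * hn]  ->  query [.., np*hn], key/value [.., ng*hn]
    q_size = num_heads_per_partition * head_dim
    kv_size = num_kv_groups_per_partition * head_dim
    query, key, value = mixed_x_layer.split([q_size, kv_size, kv_size], dim=-1)
    return query, key, value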
@@ -437,7 +437,6 @@ class ChatGLM2InferenceForwards:
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        cos, sin = infer_state.position_cos, infer_state.position_sin

        chatglm2_rotary_emb_fwd(
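`chatglm2_rotary_emb_fwd` is a fused kernel whose signature is not shown in this hunk. For orientation only, a plain-PyTorch sketch of applying precomputed `cos`/`sin` tables in the common rotate-half formulation; the actual ChatGLM2 kernel uses its own channel layout:

import torch

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [tokens, heads, head_dim]; cos/sin: [tokens, head_dim // 2]
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)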
@@ -466,10 +465,10 @@ class ChatGLM2InferenceForwards:
            value_layer = value_layer.reshape(
                -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
            )

        if infer_state.is_context_stage:
            # first token generation:
            # copy key and value calculated in current step to memory manager

            copy_kv_to_mem_cache(
                infer_state.decode_layer_id,
                key_layer,
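At the prefill (context) stage the freshly computed key/value tensors are written into the slots reserved by the cache manager via `copy_kv_to_mem_cache`. A schematic of what such a copy amounts to, assuming a per-layer buffer of shape `[num_slots, kv_heads_per_partition, head_dim]` (the real buffers live in `infer_state.cache_manager`):

import torch

def copy_kv_to_cache(layer_id: int, key: torch.Tensor, value: torch.Tensor,
                     mem_index: torch.Tensor, key_buffer: list, value_buffer: list) -> None:
    # key/value: [num_tokens, kv_heads_per_partition, head_dim]
    # mem_index: [num_tokens] slot ids handed out by the cache manager
    key_buffer[layer_id][mem_index] = key
    value_buffer[layer_id][mem_index] = value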
@@ -477,8 +476,7 @@ class ChatGLM2InferenceForwards:
                infer_state.context_mem_index,
                infer_state.cache_manager,
            )

            attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
            attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))

            # NOTE: no bug in context attn fwd (del it )
            lightllm_llama2_context_attention_fwd(
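The prefill and decode paths now call `.contiguous()` before `.view()`: after the tensor-parallel splits and reshapes, `query_layer` may no longer be contiguous in memory, and `Tensor.view` refuses non-contiguous inputs. A tiny self-contained illustration of the failure mode the extra call avoids:

import torch

x = torch.randn(4, 2, 8).transpose(0, 1)  # transposing makes the tensor non-contiguous
try:
    x.view(-1, 8)                         # view() on a non-contiguous tensor raises
except RuntimeError as err:
    print("view failed:", err)
y = x.contiguous().view(-1, 8)            # contiguous() materializes a flat copy first
print(y.shape)                            # torch.Size([8, 8])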
@@ -514,7 +512,7 @@ class ChatGLM2InferenceForwards:
            )

            # second token and follows
            attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
            attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
            cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
                : infer_state.decode_mem_end, :, :
            ]
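In the decode path (second token onwards) attention runs against the keys and values accumulated in the cache buffers, sliced up to `infer_state.decode_mem_end`. A naive single-sequence reference of that step, ignoring batching and grouped KV heads; the real path dispatches to fused token-attention kernels:

import torch

def decode_attention(query: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor) -> torch.Tensor:
    # query: [heads, head_dim]; cache_k/cache_v: [seq_len, heads, head_dim]
    head_dim = query.shape[-1]
    scores = torch.einsum("hd,shd->hs", query, cache_k) / head_dim ** 0.5
    probs = scores.softmax(dim=-1)
    return torch.einsum("hs,shd->hd", probs, cache_v)  # [heads, head_dim]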
@@ -542,6 +540,6 @@ class ChatGLM2InferenceForwards:
        # =================
        # Output:[b,sq, h]
        # =================
        output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)
        output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)

        return output, kv_cache
@@ -48,7 +48,10 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=SelfAttention
        )

        if self.shard_config.enable_tensor_parallelism:
            policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
                self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
            )
        # for rmsnorm and others, we need to check the shape
        return policy
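The new tensor-parallel branch divides `multi_query_group_num` by the tensor-parallel size, so each rank's `SelfAttention` only sees its local slice of the KV head groups. With ChatGLM2-6B's config (32 attention heads, 2 KV groups) and a hypothetical tensor-parallel size of 2, the replaced attribute works out as:

# Illustrative numbers only; the policy computes these from self.model.config and the shard config.
tensor_parallel_size = 2
num_attention_heads = 32
multi_query_group_num = 2

num_attention_heads_per_partition = num_attention_heads // tensor_parallel_size        # 16
num_multi_query_groups_per_partition = multi_query_group_num // tensor_parallel_size   # 1
# With only 2 KV groups, a tensor-parallel size above 2 would not divide evenly,
# which is exactly what the engine's MQA assertion guards against.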
@@ -149,7 +149,6 @@ class Linear1D_Col(ParallelModule):
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
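The check above accepts either a single process group or a one-element list/tuple. A small sketch of normalizing such an argument; unwrapping to `process_group[0]` is an assumption about the lines that follow the hunk:

from typing import Optional, Sequence, Union
import torch.distributed as dist

def normalize_process_group(
    process_group: Union[dist.ProcessGroup, Sequence[dist.ProcessGroup], None]
) -> Optional[dist.ProcessGroup]:
    # Mirror the guard in Linear1D_Col.from_native_module: lists must hold exactly one group.
    if isinstance(process_group, (list, tuple)):
        assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
        process_group = process_group[0]
    return process_group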
@@ -400,7 +400,6 @@ class SelfAttention(torch.nn.Module):
        )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(
            self.projection_size,
@@ -104,7 +104,6 @@ class ChatGLMPolicy(Policy):
                ),
            ],
        )

        # optimization configuration
        self.append_or_create_submodule_replacement(
            description=[
@@ -180,7 +180,6 @@ class ModelSharder(object):
            assert target_module is not None, "target_module should not be None"

            native_sub_module = getattr_(org_layer, suffix, ignore=True)

            # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
            if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
                continue
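The condition above skips a replacement when pipeline parallelism keeps only a subset of submodules on the current device and the target submodule is not in that subset. Condensed into a hypothetical predicate for clarity:

def should_skip_replacement(native_sub_module, include) -> bool:
    # `include` is the set of submodules held by this pipeline stage; None means "keep everything".
    return (
        include is not None
        and native_sub_module is not None
        and native_sub_module not in include
    )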