mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)
* fix bug * fix * fix multiquery * fix multiquery --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
This commit is contained in:
@@ -77,14 +77,15 @@ class TPInferEngine:
|
||||
)
|
||||
self.layer_num = num_hidden_layers
|
||||
|
||||
self.multi_query_group_num = 0
|
||||
self.multi_query_group_num = model.config.num_attention_heads
|
||||
# default to attention_heads
|
||||
self.multi_query_attention = model.config.multi_query_attention
|
||||
|
||||
if hasattr(model.config, "multi_query_group_num"):
|
||||
self.multi_query_group_num = model.config.multi_query_group_num
|
||||
|
||||
if hasattr(model.config, "num_key_value_heads"):
|
||||
self.multi_query_group_num = model.config.num_key_value_heads
|
||||
|
||||
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
|
||||
self.cache_manager = None
|
||||
|
||||
@@ -107,7 +108,7 @@ class TPInferEngine:
|
||||
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
|
||||
self.head_num //= self.tp_size # update sharded number of heads
|
||||
|
||||
if self.multi_query_group_num:
|
||||
if self.multi_query_attention:
|
||||
# NOTE the logic of MQA tensor parallelism should be specified.
|
||||
assert (
|
||||
self.multi_query_group_num % self.tp_size == 0
|
||||
|
@@ -395,9 +395,9 @@ class ChatGLM2InferenceForwards:
|
||||
assert use_cache is True, "use_cache should be set to True using this chatglm attention"
|
||||
# hidden_states: original :[sq, b, h] --> this [b, sq, h]
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_size = hidden_states.shape[-1]
|
||||
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
||||
mixed_x_layer = self.query_key_value(hidden_states)
|
||||
|
||||
if self.multi_query_attention:
|
||||
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
||||
[
|
||||
@@ -437,7 +437,6 @@ class ChatGLM2InferenceForwards:
|
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
||||
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
||||
|
||||
cos, sin = infer_state.position_cos, infer_state.position_sin
|
||||
|
||||
chatglm2_rotary_emb_fwd(
|
||||
@@ -466,10 +465,10 @@ class ChatGLM2InferenceForwards:
|
||||
value_layer = value_layer.reshape(
|
||||
-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
|
||||
)
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
# first token generation:
|
||||
# copy key and value calculated in current step to memory manager
|
||||
|
||||
copy_kv_to_mem_cache(
|
||||
infer_state.decode_layer_id,
|
||||
key_layer,
|
||||
@@ -477,8 +476,7 @@ class ChatGLM2InferenceForwards:
|
||||
infer_state.context_mem_index,
|
||||
infer_state.cache_manager,
|
||||
)
|
||||
|
||||
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
|
||||
attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
|
||||
|
||||
# NOTE: no bug in context attn fwd (del it )
|
||||
lightllm_llama2_context_attention_fwd(
|
||||
@@ -514,7 +512,7 @@ class ChatGLM2InferenceForwards:
|
||||
)
|
||||
|
||||
# second token and follows
|
||||
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
|
||||
attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
|
||||
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
|
||||
: infer_state.decode_mem_end, :, :
|
||||
]
|
||||
@@ -542,6 +540,6 @@ class ChatGLM2InferenceForwards:
|
||||
# =================
|
||||
# Output:[b,sq, h]
|
||||
# =================
|
||||
output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)
|
||||
|
||||
output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)
|
||||
return output, kv_cache
|
||||
|
@@ -48,7 +48,10 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=SelfAttention
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
|
||||
self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
|
||||
)
|
||||
# for rmsnorm and others, we need to check the shape
|
||||
return policy
|
||||
|
||||
|
Reference in New Issue
Block a user