Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-09 21:09:18 +00:00
[inference] chatglm2 infer demo (#4724)
* add chatglm2
* gather needed kernels
* finish context forward and context stage
* finish chatglm
* fix bugs and adjust some logic
* fix tests
@@ -380,12 +380,10 @@ class SelfAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
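For context, a minimal sketch of how multi-query attention narrows the fused QKV projection relative to the 3 * projection_size default set above. The concrete sizes and the num_groups value are illustrative assumptions, not taken from this diff (ChatGLM2 reads the group count from its config):

    import torch

    hidden_size = 4096
    num_heads = 32
    head_dim = hidden_size // num_heads   # 128; mirrors hidden_size_per_attention_head
    num_groups = 2                        # assumed number of shared K/V head groups

    # Standard attention: Q, K, and V each project to hidden_size.
    qkv_standard = 3 * hidden_size                              # 12288

    # Multi-query attention: full-width Q, but K and V use only num_groups heads each.
    qkv_multi_query = hidden_size + 2 * num_groups * head_dim   # 4608

    # The fused projection layer then has the smaller output width.
    query_key_value = torch.nn.Linear(hidden_size, qkv_multi_query)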
@@ -445,7 +443,6 @@ class SelfAttention(torch.nn.Module):

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
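The hunk is truncated at the split-size list, so here is a self-contained sketch of the kind of split it performs under multi-query attention: the fused projection output is cut along the last dim into one wide Q slab and two narrow K/V slabs. Sizes are the same hypothetical values as in the sketch above:

    import torch

    sq, b = 16, 2                                   # sequence length, batch (hypothetical)
    hidden_size, head_dim, num_groups = 4096, 128, 2
    mixed_x_layer = torch.randn(sq, b, hidden_size + 2 * num_groups * head_dim)

    # Q keeps all heads; K and V keep only num_groups heads each.
    query_layer, key_layer, value_layer = mixed_x_layer.split(
        [hidden_size, num_groups * head_dim, num_groups * head_dim], dim=-1
    )
    print(query_layer.shape, key_layer.shape, value_layer.shape)
    # torch.Size([16, 2, 4096]) torch.Size([16, 2, 256]) torch.Size([16, 2, 256])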
@@ -541,7 +538,6 @@ class SelfAttention(torch.nn.Module):
        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache
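Since forward returns kv_cache alongside the output, here is a hedged sketch of how such a cache is typically grown during incremental decoding so that earlier tokens are not re-projected each step. append_kv is a hypothetical helper, assuming the [sq, b, ...] layout used in the comments above:

    import torch

    def append_kv(key_layer, value_layer, kv_cache=None):
        # On decode steps, concatenate the new K/V onto the cached K/V
        # along the sequence dim (dim 0 in an [sq, b, np, hn] layout).
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        # Return the grown cache so the caller can pass it to the next step.
        return key_layer, value_layer, (key_layer, value_layer)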