[inference] chatglm2 infer demo (#4724)

* add chatglm2

* add

* gather needed kernels

* fix some bugs

* finish context forward

* finish context stage

* fix

* add

* pause

* add

* fix bugs

* finish chatglm

* fix bug

* change some logic

* fix bugs

* change some logic

* add

* add

* add

* fix

* fix tests

* fix
Author: Jianghai
Date: 2023-09-22 11:12:50 +08:00
Committed by: GitHub
Parent: 946ab56c48
Commit: ce7ade3882
15 changed files with 1692 additions and 14 deletions


@@ -380,12 +380,10 @@ class SelfAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)
        self.projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads
        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
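
The hunk above ends just before the multi-query branch. For context only (not part of this diff), here is a minimal sketch of how the fused QKV width is typically sized when multi_query_attention is enabled, modeled on the upstream ChatGLM2-6B config; the multi_query_group_num field and the helper name are assumptions, not code from this PR.

# Sketch only: fused QKV sizing under multi-query attention (MQA).
# config.multi_query_group_num is assumed, following the upstream ChatGLM2 config.
def qkv_hidden_size(config):
    projection_size = config.kv_channels * config.num_attention_heads
    head_dim = projection_size // config.num_attention_heads
    if config.multi_query_attention:
        # Q keeps one slice per attention head; K and V are shared across heads
        # and only get multi_query_group_num groups each.
        return projection_size + 2 * head_dim * config.multi_query_group_num
    # Standard attention: Q, K, and V all have full width.
    return 3 * projection_size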
@@ -445,7 +443,6 @@ class SelfAttention(torch.nn.Module):
        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)
        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
            
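
The split that follows this hunk carves the fused projection into query, key, and value slices. A standalone sketch of that split under MQA, using hypothetical sizes instead of the module's attributes (assumed values, not from this PR):

import torch

# Sketch only: splitting a fused QKV projection for multi-query attention.
# All sizes below are hypothetical; the real module derives them from its config.
num_heads, num_kv_groups, head_dim = 32, 2, 128
sq, b = 16, 1  # sequence length, batch size

mixed_x_layer = torch.randn(sq, b, (num_heads + 2 * num_kv_groups) * head_dim)
query_layer, key_layer, value_layer = mixed_x_layer.split(
    [
        num_heads * head_dim,      # queries: one slice per attention head
        num_kv_groups * head_dim,  # keys: shared across head groups
        num_kv_groups * head_dim,  # values: shared across head groups
    ],
    dim=-1,
)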
@@ -541,7 +538,6 @@ class SelfAttention(torch.nn.Module):
        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)
        return output, kv_cache
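
The final hunk returns both the attention output and the kv_cache that the context and decode stages mentioned in the commit messages rely on. Below is a minimal, generic sketch of that cache pattern in plain PyTorch, not the fused kernels this PR gathers: the context (prefill) stage fills the cache from the prompt, and each decode step appends one token's key/value before attending over the full cached sequence. It assumes equal head counts for Q/K/V and standard tensor layouts; the function name is hypothetical.

import torch

def attend_with_cache(query, key, value, kv_cache=None, use_cache=True):
    # query/key/value: [sq, b, heads, head_dim]
    if kv_cache is not None:
        cache_k, cache_v = kv_cache
        key = torch.cat((cache_k, key), dim=0)      # prepend cached keys
        value = torch.cat((cache_v, value), dim=0)  # prepend cached values
    new_cache = (key, value) if use_cache else None

    # Plain scaled dot-product attention over the full (cached) sequence;
    # causal masking is only needed in the prefill stage, where q and k
    # cover the same positions.
    q = query.permute(1, 2, 0, 3)  # [b, heads, sq, head_dim]
    k = key.permute(1, 2, 0, 3)
    v = value.permute(1, 2, 0, 3)
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=kv_cache is None
    )
    return out.permute(2, 0, 1, 3), new_cache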