mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-09 20:07:41 +00:00)

commit 4fbbf4737a
parent d6f3508910

    fix
@@ -736,35 +736,24 @@ def get_jit_fused_bloom_attention_forward():
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ):
+        batch_size, q_length, _ = hidden_states.shape
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-
-        batch_size, q_length, _, _ = query_layer.shape
-
-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
-        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+        # 3 x [batch_size, num_heads, seq_length, head_dim]
+        query_layer, key_layer, value_layer = self._reshape(fused_qkv)
+
         if layer_past is not None:
-            past_key, past_value = layer_past
-            # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=2)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
-
-        _, _, kv_length = key_layer.shape
-
-        if use_cache is True:
-            present = (key_layer, value_layer)
-        else:
-            present = None
+            cache_kwargs = {"cache_position": cache_position}
+            key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
+
+        # reshape qkv for further computations
+        query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
+        key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
+        value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
 
         # [batch_size * num_heads, q_length, kv_length]
-        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
-        matmul_result = alibi.baddbmm(
+        attention_scores = alibi.baddbmm(
             batch1=query_layer,
             batch2=key_layer,
             beta=self.beta,
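Note on this hunk: the hand-rolled (past_key, past_value) tuple concatenation is replaced by the transformers Cache API, i.e. layer_past.update(key, value, layer_idx, cache_kwargs). A minimal sketch of how that API behaves, assuming a DynamicCache and made-up tensor sizes (not code from the patch):

import torch
from transformers import DynamicCache

batch_size, num_heads, head_dim = 2, 4, 16
layer_idx = 0
cache = DynamicCache()

# prefill step: 8 tokens are written into the cache for this layer
k = torch.randn(batch_size, num_heads, 8, head_dim)
v = torch.randn(batch_size, num_heads, 8, head_dim)
k_all, v_all = cache.update(k, v, layer_idx, {"cache_position": torch.arange(8)})
print(k_all.shape)  # torch.Size([2, 4, 8, 16])

# decode step: one new token is appended along the sequence dimension
k = torch.randn(batch_size, num_heads, 1, head_dim)
v = torch.randn(batch_size, num_heads, 1, head_dim)
k_all, v_all = cache.update(k, v, layer_idx, {"cache_position": torch.tensor([8])})
print(k_all.shape)  # torch.Size([2, 4, 9, 16])

The cache now owns the concatenation, which is why the explicit kv_length bookkeeping and the use_cache/present branch disappear from the forward.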
@@ -772,15 +761,13 @@ def get_jit_fused_bloom_attention_forward():
         )
 
         # change view to [batch_size, num_heads, q_length, kv_length]
-        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
+        attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
+            attn_weights = attn_weights + causal_mask
 
-        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
-        input_dtype = attention_scores.dtype
-        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
-        if input_dtype == torch.float16:
-            attention_scores = attention_scores.to(torch.float)
-        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
-        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
+        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
 
         # [batch_size, num_heads, q_length, kv_length]
         attention_probs = self.attention_dropout(attention_probs)
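Note on this hunk: the old path overwrote masked positions with the dtype minimum via masked_fill on a boolean mask; the new path slices a precomputed additive causal mask and adds it to the scores before the fp32 softmax. A small standalone sketch (arbitrary shapes, not code from the patch) showing the two masking styles give the same probabilities:

import torch
import torch.nn.functional as F

scores = torch.randn(1, 1, 1, 4)                             # [batch, heads, q_len, kv_len]
bool_mask = torch.tensor([[[[False, False, True, True]]]])   # True = masked (old style)

# old style: overwrite masked positions with the dtype minimum, then softmax in fp32
old = F.softmax(scores.masked_fill(bool_mask, torch.finfo(scores.dtype).min),
                dim=-1, dtype=torch.float32)

# new style: add an additive mask (0 where visible, a huge negative bias where masked)
additive_mask = torch.zeros_like(scores).masked_fill(bool_mask, torch.finfo(scores.dtype).min)
new = F.softmax(scores + additive_mask, dim=-1, dtype=torch.float32)

print(torch.allclose(old, new, atol=1e-6))  # True: masked positions get ~0 probability either way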
@@ -789,12 +776,12 @@ def get_jit_fused_bloom_attention_forward():
             attention_probs = attention_probs * head_mask
 
         # change view [batch_size x num_heads, q_length, kv_length]
-        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
+        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
 
         # matmul: [batch_size * num_heads, q_length, head_dim]
         context_layer = torch.bmm(attention_probs_reshaped, value_layer)
 
-        # change view [batch_size, num_heads, q_length, head_dim]
+        # change view [batch_size, q_length, num_heads * head_dim]
         context_layer = self._merge_heads(context_layer)
 
         # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
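Note on this hunk: kv_length is no longer computed explicitly (with a cache it can exceed q_length), so the view lets PyTorch infer that dimension with -1. A tiny illustrative sketch (made-up sizes, not code from the patch):

import torch

batch_size, num_heads, q_length, kv_length = 2, 4, 1, 9
attention_probs = torch.rand(batch_size, num_heads, q_length, kv_length)
reshaped = attention_probs.view(batch_size * num_heads, q_length, -1)
print(reshaped.shape)  # torch.Size([8, 1, 9]); the -1 resolves to kv_length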
@@ -809,9 +796,9 @@ def get_jit_fused_bloom_attention_forward():
         else:
             output_tensor = self.dense(context_layer)
 
-        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
 
-        outputs = (output_tensor, present)
+        outputs = (output_tensor, layer_past)
         if output_attentions:
             outputs += (attention_probs,)
 
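Note on this hunk: the call switches from a self.dropout_add method to a module-level dropout_add function, and the returned tuple carries the Cache object (layer_past) instead of the old present tuple. A sketch of what a module-level helper with this call signature typically does, dropout on the attention output followed by a residual add (an assumption for illustration, not copied from the repo):

import torch
import torch.nn.functional as F

def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # apply dropout only in training mode, then add the residual branch
    return residual + F.dropout(x, p=prob, training=training)

out = dropout_add(torch.randn(2, 8, 64), torch.randn(2, 8, 64), prob=0.1, training=False)
print(out.shape)  # torch.Size([2, 8, 64])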
@@ -1072,6 +1059,14 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
 
+        # When sequence parallelism done, gather the output tensor in forward and split it in backward
+        hidden_states = gather_forward_split_backward(
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
+        )
+
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
 
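Note on this hunk: before the final LayerNorm, the sequence-parallel forward now gathers the sequence shards along dim=1. A minimal sketch of the gather-in-forward / split-in-backward pattern that gather_forward_split_backward implements, assuming torch.distributed is already initialized and ignoring the fp8_communication path handled by the real ColossalAI helper:

import torch
import torch.distributed as dist

class _GatherForwardSplitBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim, process_group):
        ctx.dim = dim
        ctx.process_group = process_group
        world_size = dist.get_world_size(process_group)
        # all-gather the sequence-parallel shards and stitch them along `dim`
        shards = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(shards, x.contiguous(), group=process_group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        rank = dist.get_rank(ctx.process_group)
        # each rank keeps only the gradient slice for the shard it contributed
        grad_chunk = grad_output.chunk(world_size, dim=ctx.dim)[rank]
        return grad_chunk, None, None

def gather_forward_split_backward(x, dim, process_group, fp8_communication=False):
    # fp8_communication is accepted for signature parity but not modeled in this sketch
    return _GatherForwardSplitBackward.apply(x, dim, process_group)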