mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 06:00:44 +00:00
fix
This commit is contained in:
parent
b124603c68
commit
d6f3508910
@ -21,6 +21,7 @@ from transformers.models.bloom.modeling_bloom import (
|
|||||||
BloomForSequenceClassification,
|
BloomForSequenceClassification,
|
||||||
BloomForTokenClassification,
|
BloomForTokenClassification,
|
||||||
BloomModel,
|
BloomModel,
|
||||||
|
dropout_add,
|
||||||
)
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
@ -856,6 +857,92 @@ def get_jit_fused_bloom_gelu_forward():
|
|||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
# Fixed the q_length args when doing the sequence parallelism in bloom model.
|
||||||
|
def get_bloom_sequence_parallel_attention_forward(shard_config: ShardConfig):
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self: BloomAttention,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
alibi: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_past: Optional[Cache] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
batch_size, q_length, _ = hidden_states.shape
|
||||||
|
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||||
|
# 3 x [batch_size, num_heads, seq_length, head_dim]
|
||||||
|
query_layer, key_layer, value_layer = self._reshape(fused_qkv)
|
||||||
|
|
||||||
|
if layer_past is not None:
|
||||||
|
cache_kwargs = {"cache_position": cache_position}
|
||||||
|
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
# reshape qkv for further computations
|
||||||
|
query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
|
||||||
|
key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
|
||||||
|
value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
|
||||||
|
|
||||||
|
# [batch_size * num_heads, q_length, kv_length]
|
||||||
|
attention_scores = alibi.baddbmm(
|
||||||
|
batch1=query_layer,
|
||||||
|
batch2=key_layer,
|
||||||
|
beta=self.beta,
|
||||||
|
alpha=self.inv_norm_factor,
|
||||||
|
)
|
||||||
|
if shard_config.enable_sequence_parallelism:
|
||||||
|
_, q_length, _ = query_layer.shape
|
||||||
|
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||||
|
attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
|
||||||
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
|
||||||
|
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
|
||||||
|
|
||||||
|
# [batch_size, num_heads, q_length, kv_length]
|
||||||
|
attention_probs = self.attention_dropout(attention_probs)
|
||||||
|
|
||||||
|
if head_mask is not None:
|
||||||
|
attention_probs = attention_probs * head_mask
|
||||||
|
|
||||||
|
# change view [batch_size x num_heads, q_length, kv_length]
|
||||||
|
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
|
||||||
|
|
||||||
|
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||||
|
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
|
||||||
|
|
||||||
|
# change view [batch_size, q_length, num_heads * head_dim]
|
||||||
|
context_layer = self._merge_heads(context_layer)
|
||||||
|
|
||||||
|
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
|
||||||
|
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||||
|
slices = self.hidden_size / self.pretraining_tp
|
||||||
|
output_tensor = torch.zeros_like(context_layer)
|
||||||
|
for i in range(self.pretraining_tp):
|
||||||
|
output_tensor = output_tensor + F.linear(
|
||||||
|
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||||
|
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_tensor = self.dense(context_layer)
|
||||||
|
|
||||||
|
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
||||||
|
|
||||||
|
outputs = (output_tensor, layer_past)
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attention_probs,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||||
from transformers import BloomModel
|
from transformers import BloomModel
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
|
|||||||
from ..modeling.bloom import (
|
from ..modeling.bloom import (
|
||||||
BloomPipelineForwards,
|
BloomPipelineForwards,
|
||||||
build_bloom_alibi_tensor_fn,
|
build_bloom_alibi_tensor_fn,
|
||||||
|
get_bloom_sequence_parallel_attention_forward,
|
||||||
get_bloom_sequence_parallel_forward_fn,
|
get_bloom_sequence_parallel_forward_fn,
|
||||||
get_jit_fused_bloom_attention_forward,
|
get_jit_fused_bloom_attention_forward,
|
||||||
get_jit_fused_bloom_gelu_forward,
|
get_jit_fused_bloom_gelu_forward,
|
||||||
@ -61,6 +62,15 @@ class BloomPolicy(Policy):
|
|||||||
|
|
||||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
|
|
||||||
|
if self.shard_config.enable_sequence_parallelism:
|
||||||
|
self.append_or_create_method_replacement(
|
||||||
|
description={
|
||||||
|
"forward": get_bloom_sequence_parallel_attention_forward(self.shard_config),
|
||||||
|
},
|
||||||
|
policy=policy,
|
||||||
|
target_key=BloomAttention,
|
||||||
|
)
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
assert (
|
assert (
|
||||||
self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
|
self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user