mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[shardformer]delete xformers (#5859)
* delete xformers
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -714,93 +714,6 @@ class BloomPipelineForwards:
        return {"hidden_states": hidden_states}


def get_bloom_flash_attention_forward(enable_jit_fused=False):
    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except:
        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
    from transformers.models.bloom.modeling_bloom import BloomAttention

    def forward(
        self: BloomAttention,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
        batch_size, tgt_len, _, _ = query_layer.size()

        _, kv_length, _, _ = key_layer.size()

        proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
        query_layer = query_layer.contiguous().view(*proj_shape)
        key_layer = key_layer.contiguous().view(*proj_shape)
        value_layer = value_layer.contiguous().view(*proj_shape)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            # - key: [batch_size * self.num_heads, head_dim, kv_length]
            # - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        tgt_len = key_layer.size()[1]

        attention_numerical_mask = torch.zeros(
            (batch_size, self.num_heads, tgt_len, kv_length),
            dtype=torch.float32,
            device=query_layer.device,
            requires_grad=True,
        )
        attention_numerical_mask = (
            attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
        )
        attention_numerical_mask = torch.masked_fill(
            attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
        )
        attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)

        context_layer = me_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_bias=attention_numerical_mask,
            scale=self.inv_norm_factor,
            p=self.attention_dropout.p,
        )
        context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        # TODO to replace with the bias_dropout_add function in jit
        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        outputs = (output_tensor, present, None)

        return outputs

    return forward


def get_jit_fused_bloom_attention_forward():
    from transformers.models.bloom.modeling_bloom import BloomAttention
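The helper deleted above routed BLOOM attention through xformers' memory_efficient_attention, which is why removing xformers removes the whole function. For reference only, and not the replacement this commit actually ships, here is a minimal sketch of the same call expressed with PyTorch's built-in torch.nn.functional.scaled_dot_product_attention. The helper name sdpa_like_me_attention is hypothetical, the tensor layout (batch, seq, heads, head_dim) follows the deleted code, and the explicit scale keyword assumes PyTorch >= 2.1.

import torch.nn.functional as F


def sdpa_like_me_attention(query_layer, key_layer, value_layer, attn_bias, scale, p):
    # Hypothetical stand-in for the removed call:
    #   me_attention(q, k, v, attn_bias=..., scale=..., p=...)
    # xformers takes tensors as (batch, seq, heads, head_dim); SDPA expects
    # (batch, heads, seq, head_dim), so transpose before and after the call.
    q = query_layer.transpose(1, 2)
    k = key_layer.transpose(1, 2)
    v = value_layer.transpose(1, 2)
    out = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_bias,  # additive float bias: ALiBi slopes plus large negatives on masked positions
        dropout_p=p,          # pass 0.0 at inference time
        scale=scale,          # explicit scaling factor; kwarg available since PyTorch 2.1
    )
    return out.transpose(1, 2)  # back to (batch, seq, heads, head_dim)


if __name__ == "__main__":
    import torch

    # Toy shapes mirroring the deleted code: (batch, seq, heads, head_dim) inputs
    # and a (batch, heads, tgt_len, kv_length) additive bias.
    b, s, h, d = 2, 16, 8, 64
    q = torch.randn(b, s, h, d)
    k = torch.randn(b, s, h, d)
    v = torch.randn(b, s, h, d)
    bias = torch.zeros(b, h, s, s)
    out = sdpa_like_me_attention(q, k, v, bias, scale=d ** -0.5, p=0.0)
    print(out.shape)  # torch.Size([2, 16, 8, 64])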