[llama] add flash attn patch for npu (#5362)

Hongxin Liu 2024-02-05 16:48:34 +08:00 committed by GitHub
parent 73f9f23fc6
commit a4cec1715b


@@ -1,15 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
from types import MethodType
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,

@@ -19,194 +19,334 @@ from transformers.models.llama.modeling_llama import (
    repeat_kv,
)

from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

if get_accelerator().name == "cuda":
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
    from flash_attn.ops.rms_norm import rms_norm

    def _prepare_decoder_attention_mask(
        self: LlamaModel,
        attention_mask: torch.BoolTensor,
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
    ) -> Optional[torch.Tensor]:
        """
        Decoder attention mask
        """
        if past_key_values_length > 0 and attention_mask is not None:
            attention_mask = torch.cat(
                tensors=(
                    torch.full(
                        size=(input_shape[0], past_key_values_length),
                        fill_value=True,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    ),
                    attention_mask,
                ),
                dim=-1,
            )  # (bsz, past_key_values_length + q_len)
        if attention_mask is not None and torch.all(attention_mask):
            return None  # Faster
        return attention_mask
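
For reference, a small illustration of the mask handling above, with hypothetical shapes and no flash-attention dependency: during decoding with a KV cache the (bsz, q_len) padding mask is left-extended with True entries for the cached positions, and a mask with no padding at all is collapsed to None so the dense path can run.

import torch

# Hypothetical decode step: batch of 1, 2 cached key/value positions, 3 new query tokens.
bsz, past_key_values_length, q_len = 1, 2, 3
attention_mask = torch.tensor([[True, True, False]])  # (bsz, q_len); last new token is padding

# Cached positions are prepended as True, mirroring the torch.cat above.
extended_mask = torch.cat(
    (
        torch.full((bsz, past_key_values_length), True, dtype=attention_mask.dtype),
        attention_mask,
    ),
    dim=-1,
)
print(extended_mask.shape)  # torch.Size([1, 5]) == (bsz, past_key_values_length + q_len)

# A mask with no padding is dropped entirely, signalling the faster unpadded branch.
full_mask = torch.ones(bsz, past_key_values_length + q_len, dtype=torch.bool)
print(torch.all(full_mask))  # tensor(True) -> the patched method returns None instead of a mask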

    def attention_forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
        """
        if output_attentions:
            logger.warning(
                "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
                "return `None` instead."
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            q_slicing, kv_slicing = (
                dim // self.config.pretraining_tp
                for dim in (
                    self.num_heads * self.head_dim,
                    self.num_key_value_heads * self.head_dim,
                )
            )  # `Tuple[int, int]`
            q_slices, k_slices, v_slices = (
                proj.weight.split(slicing, dim=0)
                for proj, slicing in (
                    (self.q_proj, q_slicing),
                    (self.k_proj, kv_slicing),
                    (self.v_proj, kv_slicing),
                )
            )  # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
            q, k, v = (
                torch.cat(
                    [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
                    dim=-1,
                )
                for slices in (q_slices, k_slices, v_slices)
            )
            # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
            # (bsz, q_len, num_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim)
        else:
            q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
            # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
            # (bsz, q_len, num_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim),
            # (bsz, q_len, num_key_value_heads * head_dim)

        # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
        # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
        # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
        q, k, v = (
            states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
            for states, num_heads in (
                (q, self.num_heads),
                (k, self.num_key_value_heads),
                (v, self.num_key_value_heads),
            )
        )

        kv_len = k.shape[-2]  # initially, `kv_len` == `q_len`
        past_kv_len = 0
        if past_key_value is not None:
            # if `past_key_value` is not None, `kv_len` > `q_len`.
            past_kv_len = past_key_value[0].shape[-2]
            kv_len += past_kv_len

        # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
        cos, sin = self.rotary_emb(v, seq_len=kv_len)
        # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
        q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
        if past_key_value is not None:
            # reuse k, v, self_attention
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        past_key_value = (k, v) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
        # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
        v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
        # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)

        key_padding_mask = attention_mask
        # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
        q, k, v = (states.transpose(1, 2) for states in (q, k, v))

        if past_kv_len > 0:
            q = torch.cat(
                tensors=(
                    torch.full(
                        size=(bsz, past_kv_len, self.num_heads, self.head_dim),
                        fill_value=0.0,
                        dtype=q.dtype,
                        device=q.device,
                    ),
                    q,
                ),
                dim=1,
            )  # (bsz, past_kv_len + q_len, num_heads, head_dim)

        if key_padding_mask is None:
            # (bsz, past_kv_len + q_len, num_heads, head_dim)
            output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)  # (bsz, )
            output = rearrange(
                output, pattern="... h d -> ... (h d)"
            )  # (bsz, past_kv_len + q_len, num_heads * head_dim)
        else:
            q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
            kv, _, cu_kv_lens, max_kv_len = unpad_input(
                hidden_states=torch.stack(tensors=(k, v), dim=2),
                attention_mask=key_padding_mask,
            )
            output_unpad = flash_attn_varlen_kvpacked_func(
                q=q,
                kv=kv,
                cu_seqlens_q=cu_q_lens,
                cu_seqlens_k=cu_kv_lens,
                max_seqlen_q=max_q_len,
                max_seqlen_k=max_kv_len,
                dropout_p=0.0,
                softmax_scale=None,
                causal=True,
            )
            output = pad_input(
                hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
                indices=indices,
                batch=bsz,
                seqlen=past_kv_len + q_len,
            )  # (bsz, past_kv_len + q_len, num_heads * head_dim)

        if past_kv_len > 0:
            # Strip off the zero query outputs.
            output = output[:, past_kv_len:, ...]  # (bsz, q_len, num_heads * head_dim)
        output = self.o_proj(output)  # (bsz, q_len, hidden_size)
        return output, None, past_key_value
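
As a sanity check on what the two flash-attention branches above are expected to compute, here is a hedged reference sketch with hypothetical sizes in plain PyTorch (no flash-attn or GPU required): causal scaled-dot-product attention over already-rotated q/k/v, i.e. softmax(QK^T / sqrt(d)) V with a causal mask, which both `flash_attn_func` and the varlen path should reproduce up to numerics.

import math
import torch

# Hypothetical shapes: batch 2, 4 heads, 5 tokens, head_dim 8, already in (bsz, heads, len, dim).
bsz, n_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

# Manual reference: scaled scores, causal mask, fp32 softmax (as in the non-flash NPU fallback).
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (bsz, n_heads, seq_len, seq_len)
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
weights = torch.softmax((scores + causal).float(), dim=-1).to(q.dtype)
reference = weights @ v  # (bsz, n_heads, seq_len, head_dim)

# PyTorch's fused kernel should agree closely with the manual computation.
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(reference, fused, atol=1e-5))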

    def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward function for RMS Norm
        """
        return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)

    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
        for name, module in model.named_modules():
            if isinstance(module, LlamaAttention):
                module.forward = MethodType(attention_forward, module)
            if isinstance(module, LlamaModel):
                module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
            if isinstance(module, LlamaRMSNorm):
                module.forward = MethodType(rms_norm_forward, module)

elif get_accelerator().name == "npu":
    import torch_npu

    class NPULlamaAttention(LlamaAttention):
        use_flash: bool = True

        def __init__(self, config: LlamaConfig):
            super().__init__(config)
            self.setup()

        def setup(self):
            self._softmax_scale = 1 / math.sqrt(self.head_dim)

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            bsz, q_len, _ = hidden_states.size()

            if self.config.pretraining_tp > 1:
                key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
                query_slices = self.q_proj.weight.split(
                    (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
                )
                key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
                value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

                query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
                query_states = torch.cat(query_states, dim=-1)

                key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
                key_states = torch.cat(key_states, dim=-1)

                value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
                value_states = torch.cat(value_states, dim=-1)
            else:
                query_states = self.q_proj(hidden_states)
                key_states = self.k_proj(hidden_states)
                value_states = self.v_proj(hidden_states)

            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

            kv_seq_len = key_states.shape[-2]
            if past_key_value is not None:
                kv_seq_len += past_key_value[0].shape[-2]
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

            if past_key_value is not None:
                # reuse k, v, self_attention
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

            past_key_value = (key_states, value_states) if use_cache else None

            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            if not self.use_flash:
                attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

                if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                        f" {attn_weights.size()}"
                    )

                if attention_mask is not None:
                    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                        raise ValueError(
                            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                        )
                    attn_weights = attn_weights + attention_mask

                # upcast attention to fp32
                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
                attn_output = torch.matmul(attn_weights, value_states)
            else:
                attn_output, *_ = torch_npu.npu_fusion_attention(
                    query_states,
                    key_states,
                    value_states,
                    self.num_heads,
                    "BNSD",
                    atten_mask=attention_mask.bool(),
                    scale=self._softmax_scale,
                    padding_mask=None,
                    pre_tockens=65535,
                    next_tockens=0,
                    keep_prob=1.0,
                    inner_precise=0,
                )

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

            if self.config.pretraining_tp > 1:
                attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
                o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
                attn_output = sum(
                    [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
                )
            else:
                attn_output = self.o_proj(attn_output)

            if not output_attentions:
                attn_weights = None
            return attn_output, attn_weights, past_key_value

    class NPURMSNorm(LlamaRMSNorm):
        def forward(self, hidden_states):
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]

    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
        for name, module in model.named_modules():
            if isinstance(module, LlamaAttention):
                module.__class__ = NPULlamaAttention
                module.setup()
            if isinstance(module, LlamaRMSNorm):
                module.__class__ = NPURMSNorm
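
A minimal usage sketch of the patch above. The checkpoint path, dtype, and import path are placeholders, not something the commit itself specifies: build a HuggingFace LlamaForCausalLM, then call replace_with_flash_attention once so the backend selected by get_accelerator() (flash-attn on CUDA, torch_npu fused kernels on NPU) is monkey-patched in before training or inference.

import torch
from transformers import LlamaForCausalLM

# Hypothetical import path; adjust to wherever this patch module lives in your tree.
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention

model = LlamaForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,  # flash-attention kernels expect fp16/bf16
)
replace_with_flash_attention(model)  # patches LlamaAttention / LlamaRMSNorm in place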