Fix tests and naming

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Author: char-1ee
Date: 2024-06-03 05:41:32 +00:00
parent 04386d9eff
commit eec77e5702
5 changed files with 154 additions and 250 deletions


@@ -10,6 +10,8 @@ from torch.distributed import ProcessGroup
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import get_alibi_slopes
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
    context_attention_unpadded,
@@ -23,28 +25,8 @@ from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor
logger = get_dist_logger(__name__)

try:
    from flash_attn import flash_attn_varlen_func

    use_flash_attn2 = True
except ImportError:
    use_flash_attn2 = False
    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")

inference_ops = InferenceOpsLoader().load()
@@ -251,122 +233,54 @@ class NopadBaichuanAttention(ParallelModule):
        )
        block_size = k_cache.size(-2)
        if is_prompts:
            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                # flash attn 2 currently only supports FP16/BF16.
                if not self.use_alibi_attn:
                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                inference_ops.context_kv_cache_memcpy(
                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                )
                attn_output = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=kv_seq_len,
                    max_seqlen_k=kv_seq_len,
                    dropout_p=0.0,
                    softmax_scale=sm_scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                )
                attn_output = attn_output.view(token_nums, -1)
            else:
                if not self.use_alibi_attn:
                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                attn_output = context_attention_unpadded(
                    q=query_states,
                    k=key_states,
                    v=value_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    context_lengths=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    output=output_tensor,
                    alibi_slopes=self.alibi_slopes,
                    max_seq_len=kv_seq_len,
                    sm_scale=sm_scale,
                    use_new_kcache_layout=use_cuda_kernel,
                )
        else:
        attn_metadata = AttentionMetaData(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            k_cache=k_cache,
            v_cache=v_cache,
            block_tables=block_tables,
            block_size=block_size,
            kv_seq_len=kv_seq_len,
            sequence_lengths=sequence_lengths,
            sm_scale=sm_scale,
            alibi_slopes=self.alibi_slopes,
            cu_seqlens=cu_seqlens,
            output_tensor=output_tensor,
            use_spec_dec=is_verifier,
            use_alibi_attn=self.use_alibi_attn,
            use_cuda_kernel=use_cuda_kernel,
        )
        attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
        pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
        if is_prompts:  # prefilling stage
            pre_attention_backend.prefill(
                attn_metadata,
                cos=cos_sin[0],
                sin=cos_sin[1],
                high_precision=high_precision,
            )
            attn_output = attention_backend.prefill(
                attn_metadata,
                token_nums=token_nums,
            )
        else:  # decoding stage
            q_len = tokens_to_verify + 1 if is_verifier else 1
            if use_cuda_kernel:
                if not self.use_alibi_attn:
                    inference_ops.rotary_embedding_and_cache_copy(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        sequence_lengths,
                        block_tables,
                        high_precision,
                    )
                else:
                    inference_ops.decode_kv_cache_memcpy(
                        key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                    )
                inference_ops.flash_decoding_attention(
                    output_tensor,
                    query_states,
                    k_cache,
                    v_cache,
                    sequence_lengths,
                    block_tables,
                    block_size,
                    kv_seq_len,
                    fd_inter_tensor.mid_output,
                    fd_inter_tensor.exp_sums,
                    fd_inter_tensor.max_logits,
                    self.alibi_slopes,
                    sm_scale,
                )
                attn_output = output_tensor
            else:
                if not is_verifier and not self.use_alibi_attn:
                    decoding_fused_rotary_embedding(
                        query_states,
                        key_states,
                        value_states,
                        cos_sin[0],
                        cos_sin[1],
                        k_cache,
                        v_cache,
                        block_tables,
                        sequence_lengths,
                    )
                else:
                    if not self.use_alibi_attn:
                        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
                    copy_k_to_blocked_cache(
                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                    copy_k_to_blocked_cache(
                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )
                attn_output = flash_decoding_attention(
                    q=query_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    kv_seq_len=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    max_seq_len_in_batch=kv_seq_len,
                    output=output_tensor,
                    mid_output=fd_inter_tensor.mid_output,
                    mid_output_lse=fd_inter_tensor.mid_output_lse,
                    alibi_slopes=self.alibi_slopes,
                    sm_scale=sm_scale,
                    q_len=q_len,
                )
            pre_attention_backend.decode(
                attn_metadata,
                cos=cos_sin[0],
                sin=cos_sin[1],
                q_len=q_len,
            )
            attn_output = attention_backend.decode(
                attn_metadata,
                fd_inter_tensor=fd_inter_tensor,
                q_len=q_len,
            )
        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = self.o_proj(attn_output)
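
Note on the refactor shown above: the per-branch logic that previously lived inline in NopadBaichuanAttention.forward (flash-attn 2 vs. Triton kernels, CUDA kernels on or off, prefill vs. decode, speculative-decoding verifier steps) is now selected once via get_attention_backend / get_pre_attention_backend and driven through a single AttentionMetaData bundle. The sketch below only illustrates that dispatch pattern; the names _CudaAttentionBackend, _TritonAttentionBackend, AttnMeta, and get_backend are stand-ins, and the selection rule is an assumption modeled on the removed inline condition, not colossalai's actual implementation.

# Hypothetical sketch of the backend-dispatch pattern, not colossalai's real API.
from dataclasses import dataclass

import torch


@dataclass
class AttnMeta:
    """Simplified stand-in for AttentionMetaData."""

    query_states: torch.Tensor
    sm_scale: float


class _TritonAttentionBackend:
    def prefill(self, meta: AttnMeta, **kwargs) -> str:
        return f"triton prefill, scale={meta.sm_scale}"

    def decode(self, meta: AttnMeta, **kwargs) -> str:
        return f"triton decode, scale={meta.sm_scale}"


class _CudaAttentionBackend(_TritonAttentionBackend):
    def prefill(self, meta: AttnMeta, **kwargs) -> str:
        return f"flash-attn/CUDA prefill, scale={meta.sm_scale}"


def get_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype):
    # Mirrors the condition the old inline code checked: the flash-attn/CUDA
    # path only applies to fp16/bf16, non-speculative runs; everything else
    # falls back to the Triton kernels.
    if use_cuda_kernel and not use_spec_dec and dtype != torch.float32:
        return _CudaAttentionBackend()
    return _TritonAttentionBackend()


meta = AttnMeta(query_states=torch.randn(4, 8, dtype=torch.float16), sm_scale=0.125)
backend = get_backend(use_spec_dec=False, use_cuda_kernel=True, dtype=meta.query_states.dtype)
print(backend.prefill(meta))  # -> flash-attn/CUDA prefill, scale=0.125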