diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index ec4044127..0a9b5293d 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co - POST '/chat': Chat api is used for conversation-style request, which often includes dialogue participants(i.e. roles) and corresponding words. Considering the input data are very different from normal inputs, we introduce Chat-Template to match the data format in chat models. #### chat-template -Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported. +Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template bellow. Both str or file style chat template are supported. ### Usage #### Args for customizing your server The configuration for api server contains both serving interface and engine backend. diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 9cf9a65e6..df6b91be1 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -169,7 +169,8 @@ class InferenceConfig(RPC_PARAM): no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences. repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0. ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. - n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. + use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False. + max_n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False. block_size (int): The number of blocks in a logical block, defaults to 16. tp_size (int): Tensor parallel size, defaults to 1. @@ -214,6 +215,7 @@ class InferenceConfig(RPC_PARAM): ignore_eos: bool = False # speculative decoding configs + use_spec_dec: bool = False max_n_spec_tokens: int = 5 glimpse_large_kv: bool = False @@ -310,6 +312,15 @@ class InferenceConfig(RPC_PARAM): meta_config[type] = getattr(model_config, type) return GenerationConfig.from_dict(meta_config) + + def to_model_inference_config(self) -> "ModelInferenceConfig": + model_inference_config = ModelInferenceConfig( + dtype=self.dtype, + use_cuda_kernel=self.use_cuda_kernel, + use_spec_dec=self.use_spec_dec, + use_cuda_graph=self.use_cuda_graph, + ) + return model_inference_config def to_rpc_param(self) -> dict: kwargs = { @@ -362,3 +373,22 @@ class InferenceConfig(RPC_PARAM): # Set the attributes from the parsed arguments. inference_config = cls(**inference_config_args) return inference_config + +@dataclass +class ModelInferenceConfig(): + """ + Configurations used when initializing/sharding model for inference. + + Args: + dtype (torch.dtype): The data type for weights and activations. + use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally + use_spec_dec (bool): Indicate whether to use speculative decoding. + use_flash_attn (bool): Indicate whether to use flash attention. + use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid. + """ + dtype: torch.dtype = None + use_cuda_kernel: bool = False + use_spec_dec: bool = False + use_flash_attn: bool = False + use_cuda_graph: bool = False + \ No newline at end of file diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 1b6e62553..7c78223d2 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.batch_bucket import BatchBucket -from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig from colossalai.inference.graph_runner import CUDAGraphRunner from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.sampler import search_tokens @@ -72,8 +72,9 @@ class InferenceEngine: self.verbose = verbose self.logger = get_dist_logger(__name__) + self.model_inference_config = inference_config.to_model_inference_config() - self.init_model(model_or_path, model_policy) + self.init_model(model_or_path, model_policy, self.model_inference_config) self.generation_config = inference_config.to_generation_config(self.model_config) self.generation_config_dict = self.generation_config.to_dict() @@ -97,7 +98,10 @@ class InferenceEngine: self.capture_model(self.k_cache, self.v_cache) # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` - self.use_spec_dec = False + self.use_spec_dec = self.inference_config.use_spec_dec + + # TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine + # We can add a SpecDecConfig class to store these configs. self.drafter_model = None self.drafter = None self.use_glide = False @@ -105,13 +109,19 @@ class InferenceEngine: self._verify_args() - def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None): + def init_model( + self, + model_or_path: Union[nn.Module, str], + model_policy: Union[Policy, Type[Policy]] = None, + model_inference_config: ModelInferenceConfig = None, + ): """ Shard model or/and Load weight Args: model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. - model_policy (Policy): the policy to replace the model + model_policy (Policy): the policy to replace the model. + model_inference_config: the configuration for modeling initialization when inference. """ if isinstance(model_or_path, str): @@ -124,7 +134,8 @@ class InferenceEngine: # the model load process in the future. model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True) else: - raise ValueError(f"Model {arch} is not supported.") + # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate + raise ValueError(f"Model {arch} is not supported.") except Exception as e: self.logger.error( @@ -167,6 +178,7 @@ class InferenceEngine: self.model = self._shardformer( model, model_policy, + model_inference_config, None, tp_group=tp_group, ) @@ -187,7 +199,7 @@ class InferenceEngine: # assert if_has_index_file, "the model path is invalid" # cpt_io.load_model(self.model, model_index_file) - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + free_gpu_memory, _ = torch.cuda.mem_get_info() peak_memory = init_gpu_memory - free_gpu_memory if self.verbose: self.logger.info( @@ -287,6 +299,7 @@ class InferenceEngine: self, model: nn.Module, model_policy: Policy, + model_inference_config: ModelInferenceConfig, stage_manager: PipelineStageManager = None, tp_group: ProcessGroupMesh = None, ) -> nn.Module: @@ -348,6 +361,8 @@ class InferenceEngine: engine.clear_spec_dec() ``` """ + self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.") + if drafter_model is None and self.drafter is None: raise ValueError("Drafter not initialized. Please provide a Drafter Model") if n_spec_tokens is not None: @@ -517,19 +532,19 @@ class InferenceEngine: prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, return_token_ids: bool = False, generation_config: Optional[GenerationConfig] = None, - ) -> List[str]: + ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: """ Executing the inference step. Args: - prompts (Union[List[str], optional): Input prompts. Defaults to None. - prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. request_ids (List[int], optional): The request ID. Defaults to None. - return_token_ids (bool): Whether to return output token ids. Defaults to False. - generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. + return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. + generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. Returns: - List[str]: Inference result returned by one generation. + Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. """ gen_config_dict = generation_config.to_dict() if generation_config is not None else {} diff --git a/colossalai/inference/modeling/backends/__init__.py b/colossalai/inference/modeling/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/inference/modeling/backends/attention_backend.py b/colossalai/inference/modeling/backends/attention_backend.py new file mode 100644 index 000000000..4d8216131 --- /dev/null +++ b/colossalai/inference/modeling/backends/attention_backend.py @@ -0,0 +1,146 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from flash_attn import flash_attn_varlen_func +import torch + +from colossalai.inference.config import InputMetaData +from colossalai.inference.utils import can_use_flash_attn2 +from colossalai.logging import get_dist_logger +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import ( + context_attention_unpadded, + flash_decoding_attention, +) + +logger = get_dist_logger(__name__) +inference_ops = InferenceOpsLoader().load() + +@dataclass +class AttentionMetaData: + query_states: torch.Tensor + key_states: torch.Tensor + value_states: torch.Tensor + k_cache: torch.Tensor + v_cache: torch.Tensor + block_tables: torch.Tensor + block_size: int + kv_seq_len: int = None + sequence_lengths: torch.Tensor = None + cu_seqlens: torch.Tensor = None + sm_scale: int = None + alibi_slopes: torch.Tensor = None + output_tensor: torch.Tensor = None + use_spec_dec: bool = False + use_alibi_attn: bool = False + + +class AttentionBackend(ABC): + @abstractmethod + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + @abstractmethod + def decode(self, attn_metadatas: AttentionMetaData, **kwargs): + raise NotImplementedError + + +class CudaAttentionBackend(AttentionBackend): + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_spec_dec: + token_nums = kwargs.get('token_nums', -1) + + attn_output = flash_attn_varlen_func( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + cu_seqlens_q=attn_metadata.cu_seqlens, + cu_seqlens_k=attn_metadata.cu_seqlens, + max_seqlen_k=attn_metadata.kv_seq_len, + max_seqlen_v=attn_metadata.kv_seq_len, + dropout_p=0.0, + softmax_scale=attn_metadata.sm_scale, + causal=True, + ) + attn_output = attn_output.view(token_nums, -1) + else: + attn_output = context_attention_unpadded( + q=attn_metadata.query_states, + k=attn_metadata.key_states, + v=attn_metadata.value_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + context_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + output=attn_metadata.output_tensor, + max_seq_len=attn_metadata.kv_seq_len, + sm_scale=attn_metadata.sm_scale, + use_new_kcache_layout=True, + ) + return attn_output + + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + fd_inter_tensor = kwargs.get('fd_inter_tensor', None) + output_tensor = attn_metadata.output_tensor + inference_ops.flash_decoding_attention( + output_tensor, + attn_metadata.query_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + attn_metadata.block_size, + attn_metadata.kv_seq_len, + fd_inter_tensor.mid_output, + fd_inter_tensor.exp_sums, + fd_inter_tensor.max_logits, + attn_metadata.alibi_slopes, + attn_metadata.sm_scale, + ) + return output_tensor + + +class TritonAttentionBackend(AttentionBackend): + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + return context_attention_unpadded( + q=attn_metadata.query_states, + k=attn_metadata.key_states, + v=attn_metadata.value_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + context_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + output=attn_metadata.output_tensor, + max_seq_len=attn_metadata.kv_seq_len, + sm_scale=attn_metadata.sm_scale, + use_new_kcache_layout=False, + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + fd_inter_tensor = kwargs.get('fd_inter_tensor', None) + return flash_decoding_attention( + q=attn_metadata.query_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + kv_seq_len=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + max_seq_len_in_batch=attn_metadata.kv_seq_len, + output=attn_metadata.output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + sm_scale=attn_metadata.sm_scale, + kv_group_num=kwargs.get('num_key_value_groups', 0), + q_len=kwargs.get('q_len', 1), + ) + + +def get_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionBackend: + use_flash_attn = can_use_flash_attn2(dtype) + if use_cuda_kernel and use_flash_attn and not use_spec_dec: + return CudaAttentionBackend() + else: + return TritonAttentionBackend() + \ No newline at end of file diff --git a/colossalai/inference/modeling/backends/attention_context.py b/colossalai/inference/modeling/backends/attention_context.py new file mode 100644 index 000000000..909660121 --- /dev/null +++ b/colossalai/inference/modeling/backends/attention_context.py @@ -0,0 +1,134 @@ +from abc import ABC, abstractmethod +import torch + +from colossalai.inference.utils import can_use_flash_attn2 +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData +from colossalai.logging import get_dist_logger +from colossalai.kernel.triton import ( + copy_k_to_blocked_cache, + decoding_fused_rotary_embedding, + rotary_embedding, +) + +logger = get_dist_logger(__name__) +inference_ops = InferenceOpsLoader().load() + + +class AttentionContext(ABC): + @abstractmethod + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + @abstractmethod + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + raise NotImplementedError + + +class CudaAttentionContext(AttentionContext): + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_spec_dec: + if not attn_metadata.use_alibi_attn: + inference_ops.rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get('cos', None), + kwargs.get('sin', None), + kwargs.get('high_precision', False), + ) + inference_ops.context_kv_cache_memcpy( + attn_metadata.key_states, + attn_metadata.value_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.cu_seqlens, + attn_metadata.block_tables, + attn_metadata.kv_seq_len, + ) + else: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get('cos', None), + kwargs.get('sin', None), + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + if attn_metadata.use_alibi_attn: + inference_ops.rotary_embedding_and_cache_copy( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + kwargs.get('cos', None), + kwargs.get('sin', None), + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + attn_metadata.high_precision, + ) + else: + inference_ops.decode_kv_cache_memcpy( + attn_metadata.key_states, + attn_metadata.value_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + ) + + +class TritonAttentionContext(AttentionContext): + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get('cos', None), + kwargs.get('sin', None), + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn: + decoding_fused_rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + kwargs.get('cos', None), + kwargs.get('sin', None), + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.block_tables, + attn_metadata.sequence_lengths, + ) + else: + if not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get('cos', None), + kwargs.get('sin', None), + ) + copy_k_to_blocked_cache( + attn_metadata.key_states, + attn_metadata.k_cache, + kv_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + n=kwargs.get('q_len', 1) + ) + copy_k_to_blocked_cache( + attn_metadata.value_states, + attn_metadata.v_cache, + kv_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + n=kwargs.get('q_len', 1) + ) + + +def get_attention_context(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionContext: + use_flash_attn = can_use_flash_attn2(dtype) + if use_cuda_kernel and use_flash_attn and not use_spec_dec: + return CudaAttentionContext() + else: + return TritonAttentionContext() \ No newline at end of file diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py index b50e73d6f..920de0d8a 100644 --- a/colossalai/inference/modeling/models/nopadding_baichuan.py +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -8,6 +8,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.utils import get_alibi_slopes from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import ( @@ -47,22 +48,6 @@ inference_ops = InferenceOpsLoader().load() logger = get_dist_logger(__name__) -# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 -def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: - closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) - powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) - slopes = torch.pow(base, powers) - if closest_power_of_2 != num_heads: - extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device - ) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) - slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - return slopes - - def baichuan_rmsnorm_forward( self, hidden_states: torch.Tensor, diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index f6f160eb7..459c6e040 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -18,6 +18,9 @@ from transformers.models.llama.modeling_llama import ( from colossalai.inference.config import InputMetaData from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData +from colossalai.inference.modeling.backends.attention_context import get_attention_context +from colossalai.inference.utils import can_use_flash_attn2 from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import ( context_attention_unpadded, @@ -36,14 +39,6 @@ inference_ops = InferenceOpsLoader().load() logger = get_dist_logger(__name__) -try: - from flash_attn import flash_attn_varlen_func - - use_flash_attn2 = True -except ImportError: - use_flash_attn2 = False - logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") - def llama_causal_lm_forward( self: LlamaForCausalLM, @@ -126,7 +121,7 @@ def llama_model_forward( cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes]) elif use_cuda_kernel: - if inputmetadata.dtype != torch.float32 and use_flash_attn2: + if inputmetadata.dtype != torch.float32 and can_use_flash_attn2(): cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) hidden_dim = self._cos_cached.size(-1) @@ -532,112 +527,54 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): ) block_size = k_cache.size(-2) - - if is_prompts: - if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: - # flash attn 2 currently only supports FP16/BF16. - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) - inference_ops.context_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len - ) - - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=kv_seq_len, - max_seqlen_k=kv_seq_len, - dropout_p=0.0, - softmax_scale=sm_scale, - causal=True, - ) - attn_output = attn_output.view(token_nums, -1) - else: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - use_new_kcache_layout=use_cuda_kernel, - ) - else: + + attn_metadata = AttentionMetaData( + query_states=query_states, + key_states=key_states, + value_states=value_states, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + block_size=block_size, + kv_seq_len=kv_seq_len, + sequence_lengths=sequence_lengths, + sm_scale=sm_scale, + alibi_slopes=None, + cu_seqlens=cu_seqlens, + output_tensor=output_tensor, + use_spec_dec=is_verifier, + use_alibi_attn=False, + ) + + attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype) + attention_context = get_attention_context(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype) + + if is_prompts: # prefilling stage + attention_context.prefill( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + high_precision=high_precision, + ) + attn_output = attention_backend.prefill( + attn_metadata, + token_nums=token_nums, + ) + else: # decoding stage q_len = tokens_to_verify + 1 if is_verifier else 1 - - if use_cuda_kernel: - inference_ops.rotary_embedding_and_cache_copy( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - sequence_lengths, - block_tables, - high_precision, - ) - inference_ops.flash_decoding_attention( - output_tensor, - query_states, - k_cache, - v_cache, - sequence_lengths, - block_tables, - block_size, - kv_seq_len, - fd_inter_tensor.mid_output, - fd_inter_tensor.exp_sums, - fd_inter_tensor.max_logits, - None, - sm_scale, - ) - attn_output = output_tensor - else: - if is_verifier: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - copy_k_to_blocked_cache( - key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) - copy_k_to_blocked_cache( - value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len - ) - else: - decoding_fused_rotary_embedding( - query_states, - key_states, - value_states, - cos_sin[0], - cos_sin[1], - k_cache, - v_cache, - block_tables, - sequence_lengths, - ) - attn_output = flash_decoding_attention( - q=query_states, - k_cache=k_cache, - v_cache=v_cache, - kv_seq_len=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - max_seq_len_in_batch=kv_seq_len, - output=output_tensor, - mid_output=fd_inter_tensor.mid_output, - mid_output_lse=fd_inter_tensor.mid_output_lse, - sm_scale=sm_scale, - kv_group_num=self.num_key_value_groups, - q_len=q_len, - ) + + attention_context.decode( + attn_metadata, + cos=cos_sin[0], + sin=cos_sin[1], + q_len=q_len, + ) + attn_output = attention_backend.decode( + attn_metadata, + fd_inter_tensor=fd_inter_tensor, + num_key_value_groups=self.num_key_value_groups, + q_len=q_len, + ) attn_output = attn_output.view(-1, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -695,3 +632,4 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): def extra_repr(self) -> str: return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False" + \ No newline at end of file diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 072bedec3..c3f5b4940 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -3,6 +3,7 @@ Utils for model inference """ import os import re +import math from pathlib import Path from typing import Optional, Tuple @@ -10,6 +11,9 @@ import torch from torch import nn from colossalai.testing import free_port +from colossalai.logging import get_dist_logger + +logger = get_dist_logger(__name__) def init_to_get_rotary(self, base=10000, use_elem=False): @@ -113,3 +117,45 @@ def find_available_ports(num: int): print(f"An OS error occurred: {e}") raise RuntimeError("Error finding available ports") return free_ports + + +def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: + """ + Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 + + Args: + num_heads (int): The number of attention heads. + device (torch.device): The device to use. + + Returns: + torch.Tensor: The Alibi slopes. + """ + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) + slopes = torch.pow(base, powers) + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +def can_use_flash_attn2(dtype: torch.dtype) -> bool: + """ + Check flash attention2 availability. + """ + if dtype not in (torch.float16, torch.bfloat16): + logger.warning(f"Flash attn2 currently only supports float16 and bfloat16.") + return False + + try: + from flash_attn import __version__ + logger.info(f"flash_attn2 version {__version__}.") + return True + except ImportError: + logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") + return False \ No newline at end of file