diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index df6b91be1..c73ee9df4 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -10,6 +10,7 @@ import torch from transformers.generation import GenerationConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.utils import can_use_flash_attn2 GibiByte = 1024**3 @@ -312,13 +313,14 @@ class InferenceConfig(RPC_PARAM): meta_config[type] = getattr(model_config, type) return GenerationConfig.from_dict(meta_config) - - def to_model_inference_config(self) -> "ModelInferenceConfig": - model_inference_config = ModelInferenceConfig( + + def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig": + use_flash_attn = can_use_flash_attn2(self.dtype) + model_inference_config = ModelShardInferenceConfig( dtype=self.dtype, use_cuda_kernel=self.use_cuda_kernel, use_spec_dec=self.use_spec_dec, - use_cuda_graph=self.use_cuda_graph, + use_flash_attn=use_flash_attn, ) return model_inference_config @@ -374,21 +376,20 @@ class InferenceConfig(RPC_PARAM): inference_config = cls(**inference_config_args) return inference_config + @dataclass -class ModelInferenceConfig(): +class ModelShardInferenceConfig: """ - Configurations used when initializing/sharding model for inference. - + Configurations used when initializing sharded modules of a model for inference. + Args: dtype (torch.dtype): The data type for weights and activations. use_cuda_kernel (bool): Whether to use cuda kernel, faster but lose some precision occasionally use_spec_dec (bool): Indicate whether to use speculative decoding. use_flash_attn (bool): Indicate whether to use flash attention. - use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
""" + dtype: torch.dtype = None use_cuda_kernel: bool = False use_spec_dec: bool = False use_flash_attn: bool = False - use_cuda_graph: bool = False - \ No newline at end of file diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index ae7184b81..d0d46d81b 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.batch_bucket import BatchBucket -from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig from colossalai.inference.graph_runner import CUDAGraphRunner from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.sampler import search_tokens @@ -72,8 +72,9 @@ class InferenceEngine: self.verbose = verbose self.logger = get_dist_logger(__name__) + self.model_shard_infer_config = inference_config.to_model_shard_inference_config() - self.init_model(model_or_path, model_policy) + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) self.generation_config = inference_config.to_generation_config(self.model_config) self.generation_config_dict = self.generation_config.to_dict() @@ -98,9 +99,7 @@ class InferenceEngine: # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` self.use_spec_dec = self.inference_config.use_spec_dec - - # TODO: when use_spec_dec set to True, users should pass drafter_model configs into InferenceEngine - # We can add a SpecDecConfig class to store these configs. + self.drafter_model = None self.drafter = None self.use_glide = False @@ -109,9 +108,10 @@ class InferenceEngine: self._verify_args() def init_model( - self, - model_or_path: Union[nn.Module, str], + self, + model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, ): """ Shard model or/and Load weight @@ -120,6 +120,7 @@ class InferenceEngine: model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. model_policy (Policy): the policy to replace the model. model_inference_config: the configuration for modeling initialization when inference. + model_shard_infer_config (ModelShardInferenceConfig): the configuration for initializing sharded modules during inference.
""" if isinstance(model_or_path, str): @@ -133,7 +134,7 @@ class InferenceEngine: model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True) else: # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate - raise ValueError(f"Model {arch} is not supported.") + raise ValueError(f"Model {arch} is not supported.") except Exception as e: self.logger.error( @@ -176,6 +177,7 @@ class InferenceEngine: self.model = self._shardformer( model, model_policy, + model_shard_infer_config, None, tp_group=tp_group, ) @@ -296,6 +298,7 @@ class InferenceEngine: self, model: nn.Module, model_policy: Policy, + model_shard_infer_config: ModelShardInferenceConfig = None, stage_manager: PipelineStageManager = None, tp_group: ProcessGroupMesh = None, ) -> nn.Module: @@ -321,6 +324,7 @@ class InferenceEngine: enable_flash_attention=False, enable_jit_fused=False, enable_sequence_parallelism=False, + extra_kwargs={"model_shard_infer_config": model_shard_infer_config}, ) shardformer = ShardFormer(shard_config=shardconfig) shard_model, _ = shardformer.optimize(model, model_policy) @@ -357,8 +361,7 @@ class InferenceEngine: engine.clear_spec_dec() ``` """ - self.logger.warning(f"Current method will be deprecated soon. To use speculative decoding, please set `use_spec_dec` in `InferenceConfig` instead.") - + if drafter_model is None and self.drafter is None: raise ValueError("Drafter not initialized. Please provide a Drafter Model") if n_spec_tokens is not None: diff --git a/colossalai/inference/modeling/backends/attention_backend.py b/colossalai/inference/modeling/backends/attention_backend.py index ecdd9d4c4..ed0ccda8a 100644 --- a/colossalai/inference/modeling/backends/attention_backend.py +++ b/colossalai/inference/modeling/backends/attention_backend.py @@ -1,19 +1,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from flash_attn import flash_attn_varlen_func + import torch +from flash_attn import flash_attn_varlen_func -from colossalai.inference.config import InputMetaData -from colossalai.inference.utils import can_use_flash_attn2 -from colossalai.logging import get_dist_logger +from colossalai.inference.config import ModelShardInferenceConfig from colossalai.kernel.kernel_loader import InferenceOpsLoader -from colossalai.kernel.triton import ( - context_attention_unpadded, - flash_decoding_attention, -) - -logger = get_dist_logger(__name__) -inference_ops = InferenceOpsLoader().load() +from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention @dataclass @@ -33,7 +26,6 @@ class AttentionMetaData: output_tensor: torch.Tensor = None use_spec_dec: bool = False use_alibi_attn: bool = False - use_cuda_kernel: bool = False class AttentionBackend(ABC): @@ -46,7 +38,16 @@ class AttentionBackend(ABC): raise NotImplementedError -class CudaAttentionBackend(AttentionBackend): +class FlashAttentionBackend(AttentionBackend): + """ + Attention backend when use_cuda_kernel is True and flash-attn is installed. It uses + `flash_attn_varlen_func` for prefilling and our cuda op `flash_decoding_attention` for decoding. 
+ """ + + def __init__(self): + super().__init__() + self.inference_ops = InferenceOpsLoader().load() + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): token_nums = kwargs.get("token_nums", -1) @@ -69,7 +70,55 @@ class CudaAttentionBackend(AttentionBackend): def decode(self, attn_metadata: AttentionMetaData, **kwargs): fd_inter_tensor = kwargs.get("fd_inter_tensor", None) output_tensor = attn_metadata.output_tensor - inference_ops.flash_decoding_attention( + self.inference_ops.flash_decoding_attention( + output_tensor, + attn_metadata.query_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + attn_metadata.block_size, + attn_metadata.kv_seq_len, + fd_inter_tensor.mid_output, + fd_inter_tensor.exp_sums, + fd_inter_tensor.max_logits, + attn_metadata.alibi_slopes, + attn_metadata.sm_scale, + ) + return output_tensor + + +class CudaAttentionBackend(AttentionBackend): + """ + Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found, + it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding. + """ + + def __init__(self): + super().__init__() + self.inference_ops = InferenceOpsLoader().load() + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + return context_attention_unpadded( + q=attn_metadata.query_states, + k=attn_metadata.key_states, + v=attn_metadata.value_states, + k_cache=attn_metadata.k_cache, + v_cache=attn_metadata.v_cache, + context_lengths=attn_metadata.sequence_lengths, + block_tables=attn_metadata.block_tables, + block_size=attn_metadata.block_size, + output=attn_metadata.output_tensor, + alibi_slopes=attn_metadata.alibi_slopes, + max_seq_len=attn_metadata.kv_seq_len, + sm_scale=attn_metadata.sm_scale, + use_new_kcache_layout=True, # use new k cache layout for cuda kernels in this triton op + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + fd_inter_tensor = kwargs.get("fd_inter_tensor", None) + output_tensor = attn_metadata.output_tensor + self.inference_ops.flash_decoding_attention( output_tensor, attn_metadata.query_states, attn_metadata.k_cache, @@ -88,6 +137,10 @@ class CudaAttentionBackend(AttentionBackend): class TritonAttentionBackend(AttentionBackend): + """ + Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding. + """ + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): return context_attention_unpadded( q=attn_metadata.query_states, @@ -102,7 +155,7 @@ class TritonAttentionBackend(AttentionBackend): alibi_slopes=attn_metadata.alibi_slopes, max_seq_len=attn_metadata.kv_seq_len, sm_scale=attn_metadata.sm_scale, - use_new_kcache_layout=attn_metadata.use_cuda_kernel, + use_new_kcache_layout=False, ) def decode(self, attn_metadata: AttentionMetaData, **kwargs): @@ -126,17 +179,24 @@ class TritonAttentionBackend(AttentionBackend): def get_attention_backend( - use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype + model_shard_infer_config: ModelShardInferenceConfig, ) -> AttentionBackend: """ - Get the attention backend based on the inference configurations. Only when: + Get the attention backend based on the inference configurations. The modeling will use CUDA-kernel-based backend + for attention module calculation only when: 1. using CUDA kernel (use_cuda_kernel=True) 2. can use flash attention (flash-attn installed and dtype is fp16 or bf16) 3. 
not using speculative decoding (currently the CUDA kernels do not support speculative decoding) - will the CUDA-kernel-based backend be used for attention layer computations. Otherwise, use Triton attention backend. + Otherwise, the Triton attention backend is used. If flash-attn is not installed while `use_cuda_kernel` is True, + the CUDA backend falls back to the Triton prefill kernel with the new KV cache layout. """ - use_flash_attn = can_use_flash_attn2(dtype) - if use_cuda_kernel and use_flash_attn and not use_spec_dec: - return CudaAttentionBackend() - else: + # Currently only Triton kernels support speculative decoding + if model_shard_infer_config.use_spec_dec: return TritonAttentionBackend() + + if model_shard_infer_config.use_cuda_kernel: + if model_shard_infer_config.use_flash_attn: + return FlashAttentionBackend() + return CudaAttentionBackend() + + return TritonAttentionBackend() diff --git a/colossalai/inference/modeling/backends/pre_attention_backend.py b/colossalai/inference/modeling/backends/pre_attention_backend.py index 73cf32592..d8911cb23 100644 --- a/colossalai/inference/modeling/backends/pre_attention_backend.py +++ b/colossalai/inference/modeling/backends/pre_attention_backend.py @@ -1,18 +1,9 @@ from abc import ABC, abstractmethod -import torch -from colossalai.inference.utils import can_use_flash_attn2 -from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.inference.config import ModelShardInferenceConfig from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData -from colossalai.logging import get_dist_logger -from colossalai.kernel.triton import ( - copy_k_to_blocked_cache, - decoding_fused_rotary_embedding, - rotary_embedding, -) - -logger = get_dist_logger(__name__) -inference_ops = InferenceOpsLoader().load() +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding class PreAttentionBackend(ABC): @@ -25,17 +16,25 @@ class PreAttentionBackend(ABC): raise NotImplementedError -class CudaPreAttentionBackend(PreAttentionBackend): +class FlashPreAttentionBackend(PreAttentionBackend): + """ + FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend. 
+ """ + + def __init__(self): + super().__init__() + self.inference_ops = InferenceOpsLoader().load() + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): if not attn_metadata.use_alibi_attn: - inference_ops.rotary_embedding( + self.inference_ops.rotary_embedding( attn_metadata.query_states, attn_metadata.key_states, kwargs.get("cos", None), kwargs.get("sin", None), kwargs.get("high_precision", False), ) - inference_ops.context_kv_cache_memcpy( + self.inference_ops.context_kv_cache_memcpy( attn_metadata.key_states, attn_metadata.value_states, attn_metadata.k_cache, @@ -48,7 +47,7 @@ class CudaPreAttentionBackend(PreAttentionBackend): def decode(self, attn_metadata: AttentionMetaData, **kwargs): if not attn_metadata.use_alibi_attn: - inference_ops.rotary_embedding_and_cache_copy( + self.inference_ops.rotary_embedding_and_cache_copy( attn_metadata.query_states, attn_metadata.key_states, attn_metadata.value_states, @@ -61,7 +60,50 @@ class CudaPreAttentionBackend(PreAttentionBackend): kwargs.get("high_precision", None), ) else: - inference_ops.decode_kv_cache_memcpy( + self.inference_ops.decode_kv_cache_memcpy( + attn_metadata.key_states, + attn_metadata.value_states, + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + ) + + +class CudaPreAttentionBackend(PreAttentionBackend): + """ + CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend. + """ + + def __init__(self): + super().__init__() + self.inference_ops = InferenceOpsLoader().load() + + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_alibi_attn: + rotary_embedding( + attn_metadata.query_states, + attn_metadata.key_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + ) + + def decode(self, attn_metadata: AttentionMetaData, **kwargs): + if not attn_metadata.use_alibi_attn: + self.inference_ops.rotary_embedding_and_cache_copy( + attn_metadata.query_states, + attn_metadata.key_states, + attn_metadata.value_states, + kwargs.get("cos", None), + kwargs.get("sin", None), + attn_metadata.k_cache, + attn_metadata.v_cache, + attn_metadata.sequence_lengths, + attn_metadata.block_tables, + kwargs.get("high_precision", None), + ) + else: + self.inference_ops.decode_kv_cache_memcpy( attn_metadata.key_states, attn_metadata.value_states, attn_metadata.k_cache, @@ -72,6 +114,10 @@ class CudaPreAttentionBackend(PreAttentionBackend): class TritonPreAttentionBackend(PreAttentionBackend): + """ + TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend. + """ + def prefill(self, attn_metadata: AttentionMetaData, **kwargs): if not attn_metadata.use_alibi_attn: rotary_embedding( @@ -94,7 +140,7 @@ class TritonPreAttentionBackend(PreAttentionBackend): attn_metadata.block_tables, attn_metadata.sequence_lengths, ) - else: # else if using speculative decoding + else: # else if using speculative decoding if not attn_metadata.use_alibi_attn: rotary_embedding( attn_metadata.query_states, @@ -119,13 +165,18 @@ class TritonPreAttentionBackend(PreAttentionBackend): def get_pre_attention_backend( - use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype + model_shard_infer_config: ModelShardInferenceConfig, ) -> PreAttentionBackend: """ - Get the backend for pre-attention computations, including potisional encoding like RoPE and KV cache initialization. 
+    Get the backend for pre-attention computations, including positional encoding like + RoPE and KV cache initialization. It adopts the same selection logic as `get_attention_backend`. """ - use_flash_attn = can_use_flash_attn2(dtype) - if use_cuda_kernel and use_flash_attn and not use_spec_dec: - return CudaPreAttentionBackend() - else: + if model_shard_infer_config.use_spec_dec: return TritonPreAttentionBackend() + + if model_shard_infer_config.use_cuda_kernel: + if model_shard_infer_config.use_flash_attn: + return FlashPreAttentionBackend() + return CudaPreAttentionBackend() + + return TritonPreAttentionBackend() diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py index d722c80ea..f10ef6e3c 100644 --- a/colossalai/inference/modeling/models/nopadding_baichuan.py +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -1,31 +1,23 @@ # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py import itertools -import math from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.distributed import ProcessGroup +from colossalai.inference.config import ModelShardInferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors -from colossalai.inference.utils import get_alibi_slopes -from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend +from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.kernel_loader import InferenceOpsLoader -from colossalai.kernel.triton import ( - context_attention_unpadded, - copy_k_to_blocked_cache, - decoding_fused_rotary_embedding, - flash_decoding_attention, - rms_layernorm, - rotary_embedding, -) +from colossalai.kernel.triton import rms_layernorm from colossalai.logging import get_dist_logger from colossalai.shardformer.layer.parallel_module import ParallelModule from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor - inference_ops = InferenceOpsLoader().load() logger = get_dist_logger(__name__) @@ -69,6 +61,7 @@ class NopadBaichuanAttention(ParallelModule): attn_oproj: ParallelModule = None, num_heads: int = None, hidden_size: int = None, + model_shard_infer_config: ModelShardInferenceConfig = None, process_group: ProcessGroup = None, helper_layout: Layout = None, ): @@ -93,6 +86,9 @@ class NopadBaichuanAttention(ParallelModule): self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) self.helper_layout = helper_layout + self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel + self.attention_backend = get_attention_backend(model_shard_infer_config) + self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config) self.alibi_slopes = None self.use_alibi_attn = False @@ -122,6 +118,7 @@ class NopadBaichuanAttention(ParallelModule): attn_kproj_w = k_proj_w attn_vproj_w = v_proj_w attn_oproj = module.o_proj + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) helper_layout = ( module.W_pack.weight.dist_layout @@ -133,6 +130,7 @@ 
attn_kproj_w=attn_kproj_w, attn_vproj_w=attn_vproj_w, attn_oproj=attn_oproj, + model_shard_infer_config=model_shard_infer_config, num_heads=module.num_heads, hidden_size=module.hidden_size, process_group=process_group, @@ -201,7 +199,6 @@ class NopadBaichuanAttention(ParallelModule): kv_seq_len: int = 0, output_tensor: torch.Tensor = None, sm_scale: int = None, - use_cuda_kernel: bool = True, cu_seqlens: torch.Tensor = None, high_precision: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -220,7 +217,6 @@ class NopadBaichuanAttention(ParallelModule): kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. - use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ @@ -233,7 +229,7 @@ class NopadBaichuanAttention(ParallelModule): ) block_size = k_cache.size(-2) - + attn_metadata = AttentionMetaData( query_states=query_states, key_states=key_states, @@ -250,35 +246,31 @@ class NopadBaichuanAttention(ParallelModule): output_tensor=output_tensor, use_spec_dec=is_verifier, use_alibi_attn=self.use_alibi_attn, - use_cuda_kernel=use_cuda_kernel, ) - - attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype) - pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype) - + if is_prompts: # prefilling stage - pre_attention_backend.prefill( + self.pre_attention_backend.prefill( attn_metadata, cos=cos_sin[0], sin=cos_sin[1], high_precision=high_precision, ) - attn_output = attention_backend.prefill( + attn_output = self.attention_backend.prefill( attn_metadata, token_nums=token_nums, - ) - else: # decoding stage + ) + else: # decoding stage q_len = tokens_to_verify + 1 if is_verifier else 1 - - pre_attention_backend.decode( + + self.pre_attention_backend.decode( attn_metadata, cos=cos_sin[0], sin=cos_sin[1], q_len=q_len, ) - attn_output = attention_backend.decode( - attn_metadata, - fd_inter_tensor=fd_inter_tensor, + attn_output = self.attention_backend.decode( + attn_metadata, + fd_inter_tensor=fd_inter_tensor, q_len=q_len, ) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index e9346017a..e274e7b7c 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -16,21 +16,13 @@ from transformers.models.llama.modeling_llama import ( LlamaRMSNorm, ) -from colossalai.inference.config import InputMetaData +from colossalai.inference.config import InputMetaData, ModelShardInferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors -from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData +from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend from colossalai.inference.utils import 
can_use_flash_attn2 from colossalai.kernel.kernel_loader import InferenceOpsLoader -from colossalai.kernel.triton import ( - context_attention_unpadded, - copy_k_to_blocked_cache, - decoding_fused_rotary_embedding, - flash_decoding_attention, - get_xine_cache, - rms_layernorm, - rotary_embedding, -) +from colossalai.kernel.triton import get_xine_cache, rms_layernorm from colossalai.logging import get_dist_logger from colossalai.shardformer.layer.parallel_module import ParallelModule from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor @@ -233,7 +225,6 @@ def llama_decoder_layer_forward( kv_seq_len=kv_seq_len, output_tensor=output_tensor, sm_scale=sm_scale, - use_cuda_kernel=use_cuda_kernel, cu_seqlens=cu_seqlens, high_precision=high_precision, ) @@ -397,6 +388,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w: torch.Tensor = None, attn_oproj: ParallelModule = None, process_group: ProcessGroup = None, + model_shard_infer_config: ModelShardInferenceConfig = None, num_heads: int = None, hidden_size: int = None, num_key_value_heads: int = None, @@ -428,6 +420,9 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): self.rope_theta = config.rope_theta self.is_causal = True + self.attention_backend = get_attention_backend(model_shard_infer_config) + self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config) + if self.num_heads == self.num_key_value_heads: qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)] self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) @@ -457,6 +452,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w = module.v_proj.weight assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor" attn_oproj = module.o_proj + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) attn_layer = NopadLlamaAttention( config=config, @@ -466,6 +462,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): attn_vproj_w=attn_vproj_w, attn_oproj=attn_oproj, process_group=process_group, + model_shard_infer_config=model_shard_infer_config, num_heads=module.num_heads, hidden_size=module.hidden_size, num_key_value_heads=module.num_key_value_heads, @@ -527,7 +524,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): ) block_size = k_cache.size(-2) - + attn_metadata = AttentionMetaData( query_states=query_states, key_states=key_states, @@ -544,38 +541,34 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): output_tensor=output_tensor, use_spec_dec=is_verifier, use_alibi_attn=False, - use_cuda_kernel=use_cuda_kernel, ) - - attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype) - pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype) - + if is_prompts: # prefilling stage - pre_attention_backend.prefill( + self.pre_attention_backend.prefill( attn_metadata, cos=cos_sin[0], sin=cos_sin[1], high_precision=high_precision, ) - attn_output = attention_backend.prefill( + attn_output = self.attention_backend.prefill( attn_metadata, token_nums=token_nums, - ) - else: # decoding stage + ) + else: # decoding stage q_len = tokens_to_verify + 1 if is_verifier else 1 - - pre_attention_backend.decode( + + self.pre_attention_backend.decode( attn_metadata, cos=cos_sin[0], sin=cos_sin[1], q_len=q_len, ) - attn_output = 
attention_backend.decode( - attn_metadata, - fd_inter_tensor=fd_inter_tensor, + attn_output = self.attention_backend.decode( + attn_metadata, + fd_inter_tensor=fd_inter_tensor, num_key_value_groups=self.num_key_value_groups, q_len=q_len, - ) + ) attn_output = attn_output.view(-1, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -633,4 +626,3 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): def extra_repr(self) -> str: return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False" - \ No newline at end of file diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py index 78268d6e7..b28c2fce8 100644 --- a/colossalai/inference/modeling/policy/nopadding_baichuan.py +++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py @@ -70,6 +70,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): SubModuleReplacementDescription( suffix="self_attn", target_module=NopadBaichuanAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, ), ], ) diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index 24cf7c740..0b6797560 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -72,6 +72,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): SubModuleReplacementDescription( suffix="self_attn", target_module=NopadLlamaAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, ), ], ) diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index c3f5b4940..1374103a9 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -1,17 +1,17 @@ """ Utils for model inference """ +import math import os import re -import math from pathlib import Path from typing import Optional, Tuple import torch from torch import nn -from colossalai.testing import free_port from colossalai.logging import get_dist_logger +from colossalai.testing import free_port logger = get_dist_logger(__name__) @@ -122,11 +122,11 @@ def find_available_ports(num: int): def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: """ Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 - + Args: num_heads (int): The number of attention heads. device (torch.device): The device to use. - + Returns: torch.Tensor: The Alibi slopes. """ @@ -142,20 +142,17 @@ def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes - - + + def can_use_flash_attn2(dtype: torch.dtype) -> bool: """ Check flash attention2 availability. 
""" if dtype not in (torch.float16, torch.bfloat16): - logger.warning(f"Flash attn2 currently only supports float16 and bfloat16.") return False - + try: - from flash_attn import __version__ - logger.info(f"flash_attn2 version {__version__}.") return True except ImportError: logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") - return False \ No newline at end of file + return False diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py index 736fab5ff..f24e1bb3f 100644 --- a/tests/test_infer/test_models/test_baichuan.py +++ b/tests/test_infer/test_models/test_baichuan.py @@ -55,7 +55,7 @@ def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=Fa assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len) outputs = inference_engine.generate(generation_config=generation_config) else: if prompt_template: