[feat] cuda graph support and refactor non-functional api

Runyu Lu
2024-03-08 14:19:35 +08:00
parent 593a72e4d5
commit cefaeb5fdd
5 changed files with 281 additions and 43 deletions


@@ -14,7 +14,6 @@ GibiByte = 1024**3
logger = logging.getLogger(__name__)

_DTYPE_MAPPING = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
@@ -23,13 +22,37 @@ _DTYPE_MAPPING = {
_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]

_DEFAULT_PROMPT_TEMPLATES = {
    "llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
    "vicuna": "USER: {input_text}\n\nASSISTANT: ",
}
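Each template above is a plain Python format string with an `{input_text}` placeholder. A minimal usage sketch (the caller code below is illustrative, not part of this diff):

```python
# Illustrative only: format a user query with the vicuna template.
prompt = _DEFAULT_PROMPT_TEMPLATES["vicuna"].format(
    input_text="Summarize CUDA graphs in one sentence."
)
# prompt == "USER: Summarize CUDA graphs in one sentence.\n\nASSISTANT: "
```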

@dataclass
class InputMetaData:
    """The input metadata for a single inference step.

    Args:
        block_tables (torch.Tensor, optional): The block tables of the sequences in the batch. Defaults to None.
        sequence_lengths (torch.Tensor): A tensor containing the current length of each sequence.
        fd_inter_tensor (torch.Tensor, optional): Intermediate tensors used by flash decoding. Defaults to None.
        batch_size (int, optional): The current batch size. Defaults to 64.
        is_prompts (bool, optional): Whether this step is prefill (True) or decoding (False). Defaults to False (decoding).
        use_cuda_graph (bool, optional): Whether to execute this step by replaying a captured CUDA graph. Defaults to False.
        kv_seq_len (int, optional): The key-value sequence length. Defaults to 512.
        head_dim (int, optional): The head dimension. Defaults to 32.
    """

    block_tables: torch.Tensor = None
    sequence_lengths: torch.Tensor = None
    fd_inter_tensor: torch.Tensor = None
    batch_size: int = 64  # current batch size
    is_prompts: bool = False
    use_cuda_graph: bool = False
    kv_seq_len: int = 512
    head_dim: int = 32
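A construction sketch for the dataclass above (the tensor shapes and values are illustrative assumptions, not mandated by this commit):

```python
import torch

# Hypothetical metadata for one decoding step over a batch of 4 sequences.
input_meta = InputMetaData(
    block_tables=torch.zeros(4, 8, dtype=torch.int32),  # 4 sequences x 8 KV-cache blocks each
    sequence_lengths=torch.tensor([17, 33, 65, 129]),   # current length of each sequence
    batch_size=4,
    is_prompts=False,      # decoding step, not prefill
    use_cuda_graph=True,   # replay a captured CUDA graph for this step
    kv_seq_len=512,
    head_dim=32,
)
```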

@dataclass
class InferenceConfig:
    """The inference configuration.
@@ -55,6 +78,8 @@ class InferenceConfig:
        pp_size (int): Pipeline parallel size. Defaults to 1.
        micro_batch_size (int): The micro batch size. Defaults to 1. Only effective when `pp_size` > 1.
        micro_batch_buffer_size (int): The buffer size for micro batches. Normally it should equal the number of pipeline stages.
        use_cuda_graph (bool): Whether to use CUDA graph execution. If False, CUDA graphs are disabled and the model always runs in eager mode. If True, captured CUDA graphs are replayed where available and execution falls back to eager mode otherwise (hybrid execution).
        max_context_len_to_capture (int): The maximum context length covered by the captured CUDA graphs; steps with a longer context fall back to eager execution.
    """
@@ -90,6 +115,10 @@ class InferenceConfig:
    micro_batch_size: int = 1
    micro_batch_buffer_size: int = None

    # cuda graph
    use_cuda_graph: bool = False
    # NOTE: evaluated from the class-level defaults of max_input_len and
    # max_output_len at class-definition time, not from per-instance values.
    max_context_len_to_capture: int = max_input_len * max_output_len

    def __post_init__(self):
        self._verify_config()
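To illustrate the intent behind `use_cuda_graph` and `max_context_len_to_capture`, here is a generic capture-and-replay sketch in plain PyTorch. The `GraphRunnerSketch` class, buffer names, and `run_step` helper are assumptions for illustration only; they are not the APIs introduced by this commit.

```python
import torch

class GraphRunnerSketch:
    """Illustrative CUDA-graph wrapper: capture one forward pass, replay it later."""

    def __init__(self, model, static_inputs):
        self.model = model
        self.static_inputs = static_inputs  # pre-allocated, fixed-shape input buffers
        # Warm-up passes are required before capture.
        for _ in range(3):
            model(*static_inputs)
        torch.cuda.synchronize()
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.static_outputs = model(*static_inputs)

    def replay(self, *inputs):
        # Copy fresh data into the captured buffers, then replay the recorded kernels.
        for buf, x in zip(self.static_inputs, inputs):
            buf.copy_(x)
        self.graph.replay()
        return self.static_outputs

def run_step(runner, model, inputs, context_len, config):
    # Replay only when graphs are enabled and the context fits what was
    # captured (cf. max_context_len_to_capture); otherwise run eagerly.
    if config.use_cuda_graph and context_len <= config.max_context_len_to_capture:
        return runner.replay(*inputs)
    return model(*inputs)
```

Replaying a captured graph avoids per-kernel launch overhead, which matters most for small-batch decoding steps where launch latency dominates.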