[feat] CUDA graph support and refactor non-functional API

Runyu Lu
2024-03-08 14:19:35 +08:00
parent 593a72e4d5
commit cefaeb5fdd
5 changed files with 281 additions and 43 deletions
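The mechanism this commit builds on is PyTorch's public CUDA graph capture/replay API. Before reading the diff, here is a minimal, self-contained sketch of that pattern (plain torch, none of it ColossalAI code): warm up the model, capture its forward pass into a graph with static buffers, then run later steps by copying fresh data into those buffers and replaying.

import torch

# Minimal capture/replay sketch using only the public torch.cuda API.
model = torch.nn.Linear(128, 128).cuda()
static_input = torch.zeros(8, 128, device="cuda")

# Warm up on a side stream so capture begins from a quiescent state.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)  # records the kernels into the graph

# Replay: refill the captured input buffer in place, then replay the graph.
static_input.copy_(torch.randn(8, 128, device="cuda"))
graph.replay()
print(static_output[0, :4])  # results land in the captured output buffer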


@@ -1,5 +1,7 @@
import copy
import time
from itertools import count
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -7,7 +9,9 @@ import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
@@ -81,11 +85,89 @@ class InferenceEngine:
self.logger = get_dist_logger(__name__)
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
self.counter = count()
self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
self.capture_model(self.k_cache, self.v_cache)
@torch.inference_mode()
def capture_model(self, k_cache: torch.Tensor, v_cache: torch.Tensor):
assert self.use_cuda_graph, "Please set use_cuda_graph=True in the inference config before capturing"
if self.verbose:
self.logger.info("Colossal AI CUDA Graph Capture begin")
t_capture_begin = time.perf_counter()
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
block_size = self.inference_config.block_size
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
self.graph_block_tables = np.zeros((max_batch_size, max_num_blocks), dtype=np.int32)
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
batch_bucket_for_capture = copy.deepcopy(self.request_handler.running_bb)
# share, rather than deep-copy, the flash-decoding intermediate tensor
batch_bucket_for_capture.fd_interm_tensor = self.request_handler.running_bb.fd_interm_tensor
if self.verbose:
self.logger.info(f"batch size {batch_size} graph capturing")
# generate dummy input
for i in range(batch_size):
sequence = Sequence(
i,
None,
input_tokens[i],
block_size,
None,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
)
sequence.output_token_id = [0]  # a non-empty output list marks the sequence as decoding; only decoding graphs are captured
batch_bucket_for_capture.add_seq(sequence, alloc_block_table=block_tables[i])
input_data = self.prepare_input(batch_bucket_for_capture)
input_tokens_ids, output_tensor, inputmetadata = input_data
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens_ids,
output_tensor,
inputmetadata,
k_caches=k_cache,
v_caches=v_cache,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
t_capture_end = time.perf_counter()
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_config(self) -> None:
"""
Verify the input config
@@ -278,13 +360,47 @@ class InferenceEngine:
)
self.request_handler.add_sequence(sequence)
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
dtype=batch.dtype,
device=batch.device,
)
else:
output_tensor = torch.zeros(
(batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
# The CUDA graph can be used only if a graph was captured for this exact decoding batch size.
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners:
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=batch.fd_inter_tensor,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_graph=use_cuda_graph,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
)
return input_ids, output_tensor, input_meta_data
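An aside on the placeholder sizing above: prefill allocates one output row per input token across the whole batch, while decode allocates one row per sequence. A tiny illustration (hypothetical head counts, not from the commit):

def expected_rows(is_prompts: bool, sequence_lengths: list, batch_size: int) -> int:
    # Mirrors the two torch.zeros(...) branches in prepare_input.
    return sum(sequence_lengths) if is_prompts else batch_size

# Two sequences of lengths 5 and 3, hypothetical 32 heads of dim 128:
assert expected_rows(True, [5, 3], 2) == 8    # prefill: (8, 32 * 128)
assert expected_rows(False, [5, 3], 2) == 2   # decode:  (2, 32 * 128)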
def step(self) -> List[str]:
"""
In each step, do the following:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Run model to generate the next token
3. Update waiting list and running list in RequestHandler and get finished sequences.
4. Decode and return finished sequences.
2. Get the input token IDs, input metadata, and the output placeholder from the BatchBucket
3. Run model to generate the next token
4. Update waiting list and running list in RequestHandler and get finished sequences.
5. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
@@ -292,12 +408,15 @@ class InferenceEngine:
batch = self.request_handler.schedule()
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# TODO: padding_id is used for generating attn_mask and will be removed once the no-pad version is supported.
logits = self.model(
batch,
self.k_cahce,
self.v_cache,
)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
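For context, enabling the path added in this commit would look roughly like the sketch below. The config fields (use_cuda_graph, max_context_len_to_capture, block_size, max_batch_size) all appear in this diff; the InferenceEngine constructor signature and import paths are assumptions based on the surrounding code, not shown in this hunk.

from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
# Assumed import path for the engine file edited in this diff:
from colossalai.inference.core.engine import InferenceEngine

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").cuda()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

config = InferenceConfig(
    max_batch_size=8,                # bounds batch_size_capture_list
    block_size=16,
    use_cuda_graph=True,             # triggers capture_model() in __init__
    max_context_len_to_capture=512,  # sizes the dummy block tables
)
engine = InferenceEngine(model, tokenizer, inference_config=config, verbose=True)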