Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[feat] cuda graph support and refactor non-functional api
@@ -1,5 +1,7 @@
+import copy
+import time
 from itertools import count
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -7,7 +9,9 @@ import torch.nn as nn
 from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.inference.config import InferenceConfig
+from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.config import InferenceConfig, InputMetaData
+from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.struct import Sequence
 from colossalai.logging import get_dist_logger
@@ -81,11 +85,89 @@ class InferenceEngine:
         self.logger = get_dist_logger(__name__)
 
         self.request_handler = RequestHandler(self.inference_config, self.model_config)
-        self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
+        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
         # DISCUSS maybe move this into batch info?
 
         self.counter = count()
 
+        self.use_cuda_graph = self.inference_config.use_cuda_graph
+        if self.use_cuda_graph:
+            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+            self.graph_memory_pool = None  # Set during graph capture.
+            if verbose:
+                self.logger.info("Colossal AI CUDA Graph Capture on")
+
+            self.capture_model(self.k_cache, self.v_cache)
+
+    @torch.inference_mode()
+    def capture_model(self, k_cache: torch.Tensor, v_cache: torch.Tensor):
+        assert self.use_cuda_graph, "please turn on the cuda graph"
+
+        if self.verbose:
+            self.logger.info("Colossal AI CUDA Graph Capture begin")
+
+        t_capture_begin = time.perf_counter()
+
+        _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
+
+        block_size = self.inference_config.block_size
+
+        # Prepare dummy inputs. These will be reused for all batch sizes.
+        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+
+        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
+        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
+        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
+        self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        max_num_seqs = self.inference_config.max_batch_size
+        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
+
+        # NOTE: Capturing the largest batch size first may help reduce the
+        # memory usage of CUDA graph.
+        for batch_size in reversed(batch_size_capture_list[-1:]):
+            batch_bucket_for_capture = copy.deepcopy(self.request_handler.running_bb)
+            batch_bucket_for_capture.fd_interm_tensor = self.request_handler.running_bb.fd_interm_tensor
+
+            if self.verbose:
+                self.logger.info(f"batch size {batch_size} graph capturing")
+
+            # generate dummy input
+            for i in range(batch_size):
+                sequence = Sequence(
+                    i,
+                    None,
+                    input_tokens[i],
+                    block_size,
+                    None,
+                    self.tokenizer.eos_token_id,
+                    self.tokenizer.pad_token_id,
+                    self.inference_config.max_output_len,
+                )
+                sequence.output_token_id = [0]  # only capture the graph of decoding
+                batch_bucket_for_capture.add_seq(sequence, alloc_block_table=block_tables[i])
+
+            input_data = self.prepare_input(batch_bucket_for_capture)
+
+            input_tokens_ids, output_tensor, inputmetadata = input_data
+
+            graph_runner = CUDAGraphRunner(self.model)
+            graph_runner.capture(
+                input_tokens_ids,
+                output_tensor,
+                inputmetadata,
+                k_caches=k_cache,
+                v_caches=v_cache,
+                memory_pool=self.graph_memory_pool,
+            )
+            self.graph_memory_pool = graph_runner.graph.pool()
+            self.graph_runners[batch_size] = graph_runner
+
+        t_capture_end = time.perf_counter()
+
+        if self.verbose:
+            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
+
     def _verify_config(self) -> None:
         """
         Verify the input config
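The capture loop above relies on CUDAGraphRunner from colossalai.inference.graph_runner, which this diff imports but does not show. Below is a minimal sketch of how such a runner is typically structured, matching only the call sites above (CUDAGraphRunner(self.model), capture(..., memory_pool=...), graph.pool(), and invoking the runner in place of the model); the class name, buffer names, and internals are assumptions, not the actual file.

```python
# Hedged sketch of a CUDA graph runner with the interface used above
# (colossalai/inference/graph_runner.py is not shown in this diff, so
# buffer names and internals here are assumptions, not the actual code).
from typing import Dict, Optional

import torch
from torch import nn


class CUDAGraphRunnerSketch:
    def __init__(self, model: nn.Module):
        self.model = model
        self.graph: Optional[torch.cuda.CUDAGraph] = None
        # Static tensors whose storage addresses are baked into the captured graph.
        self.static_inputs: Dict[str, torch.Tensor] = {}
        self.static_output: Optional[torch.Tensor] = None

    @torch.inference_mode()
    def capture(self, input_tokens_ids, output_tensor, inputmetadata, k_caches, v_caches, memory_pool=None):
        # Warm up once outside the graph so lazy CUDA work (autotuning,
        # workspace allocation) is not recorded into the graph.
        self.model(input_tokens_ids, output_tensor, inputmetadata, k_caches, v_caches)
        torch.cuda.synchronize()

        # Record one decoding step. Sharing `memory_pool` across runners lets
        # graphs for different batch sizes reuse the same memory arena.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            self.static_output = self.model(input_tokens_ids, output_tensor, inputmetadata, k_caches, v_caches)
        torch.cuda.synchronize()

        # Keep references to the exact tensors the graph reads from.
        self.static_inputs = {
            "input_tokens_ids": input_tokens_ids,
            "block_tables": inputmetadata.block_tables,
            "sequence_lengths": inputmetadata.sequence_lengths,
        }

    def __call__(self, input_tokens_ids, output_tensor, inputmetadata, k_caches, v_caches):
        # Replay: copy the new batch into the captured static buffers and
        # launch the pre-recorded kernels; KV caches are updated in place.
        self.static_inputs["input_tokens_ids"].copy_(input_tokens_ids)
        self.static_inputs["block_tables"].copy_(inputmetadata.block_tables)
        self.static_inputs["sequence_lengths"].copy_(inputmetadata.sequence_lengths)
        self.graph.replay()
        return self.static_output
```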
@@ -278,13 +360,47 @@ class InferenceEngine:
         )
         self.request_handler.add_sequence(sequence)
 
+    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
+        input_ids = batch.get_1D_inputs()
+
+        sequence_lengths = batch.get_sequence_lengths()
+        if batch.is_prompts:
+            output_tensor = torch.zeros(
+                (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim),
+                dtype=batch.dtype,
+                device=batch.device,
+            )
+        else:
+            output_tensor = torch.zeros(
+                (batch.current_batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
+            )
+
+        # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
+        use_cuda_graph = False
+        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
+            use_cuda_graph = True
+
+        input_meta_data = InputMetaData(
+            block_tables=batch.get_block_table_tensor(),
+            sequence_lengths=sequence_lengths,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            batch_size=batch.current_batch_size,
+            is_prompts=batch.is_prompts,
+            use_cuda_graph=use_cuda_graph,
+            kv_seq_len=sequence_lengths.max().item(),
+            head_dim=batch.head_dim,
+        )
+
+        return input_ids, output_tensor, input_meta_data
+
     def step(self) -> List[str]:
         """
         In each step, do the follows:
         1. Run RequestHandler.schedule() and get the batch used for inference.
-        2. Run model to generate the next token
-        3. Update waiting list and running list in RequestHandler and get finished sequences.
-        4. Decode and return finished sequences.
+        2. Get the input, inputinfo and output placeholder from the batchbucket
+        3. Run model to generate the next token
+        4. Update waiting list and running list in RequestHandler and get finished sequences.
+        5. Decode and return finished sequences.
 
         Returns:
             List[str]: Decoded finished sequences generated by one step.
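prepare_input marks a batch as CUDA-graph-eligible only when it is a decoding batch whose current size has a captured runner, and it sizes the output placeholder differently for prefill and decode. A small standalone illustration of both rules, with made-up numbers:

```python
# Standalone illustration of the eligibility rule and the output placeholder
# shapes built in prepare_input (numbers are made up for the example).
import torch

captured_batch_sizes = {1, 2, 4, 8, 16}   # keys of self.graph_runners after capture
num_heads, head_dim = 32, 128

def output_placeholder(seq_lens: torch.Tensor, is_prompts: bool) -> torch.Tensor:
    if is_prompts:
        # Prefill: one row per input token across the whole batch.
        return torch.zeros(int(seq_lens.sum().item()), num_heads * head_dim)
    # Decoding: exactly one new token per sequence.
    return torch.zeros(seq_lens.numel(), num_heads * head_dim)

seq_lens = torch.tensor([13, 7, 21, 4])
print(output_placeholder(seq_lens, is_prompts=True).shape)   # torch.Size([45, 4096])
print(output_placeholder(seq_lens, is_prompts=False).shape)  # torch.Size([4, 4096])

# CUDA graph replay is used only for decoding batches whose size was captured.
for is_prompts, batch_size in [(True, 4), (False, 4), (False, 5)]:
    use_cuda_graph = (not is_prompts) and batch_size in captured_batch_sizes
    print(is_prompts, batch_size, "->", use_cuda_graph)  # only (False, 4) -> True
```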
@@ -292,12 +408,15 @@
 
         batch = self.request_handler.schedule()
 
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
-        logits = self.model(
-            batch,
-            self.k_cahce,
-            self.v_cache,
-        )
+        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
 
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
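With these changes, graphs are captured once in __init__ and step() transparently dispatches eligible decoding batches to the matching runner. A hedged usage sketch follows; the InferenceEngine constructor, add_request, and generate calls are assumptions based on the surrounding engine API rather than code shown in this diff, while use_cuda_graph, max_context_len_to_capture, max_batch_size, and max_output_len do appear above.

```python
# Hedged usage sketch: enable CUDA graph capture when building the engine.
# Constructor/add_request/generate signatures are assumptions; only
# use_cuda_graph, max_context_len_to_capture, max_batch_size,
# max_output_len, and step() are visible in this diff.
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").cuda()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

inference_config = InferenceConfig(
    max_batch_size=8,
    max_output_len=256,
    use_cuda_graph=True,              # triggers capture_model() in __init__
    max_context_len_to_capture=512,   # bounds the dummy block tables used for capture
)

engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# Decoding steps whose batch size has a captured graph are replayed via the
# corresponding CUDAGraphRunner instead of calling the model directly.
engine.add_request(prompts=["Introduce some landmarks in Beijing"])
output = engine.generate()
print(output)
```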