[fix] multi-graph capture error

Runyu Lu
2024-03-11 10:49:31 +08:00
parent cefaeb5fdd
commit b2c0d9ff2b
4 changed files with 27 additions and 30 deletions

@@ -29,6 +29,8 @@ _supported_models = [
     "LlamaForCausalLM",
 ]
 
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
+
 class InferenceEngine:
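
The capture-size list now lives at module scope, so the capture loop and the runtime dispatch below can share it: graphs are recorded for batch sizes 1, 2, 4, and every multiple of 8 up to 256. A live batch whose size falls between two buckets has to be padded up to the next captured size. Below is a minimal sketch of that rounding; pick_graph_batch_size is a hypothetical helper, not part of this commit.

import bisect
from typing import Optional

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

def pick_graph_batch_size(batch_size: int) -> Optional[int]:
    # Smallest captured size that can hold this batch; None means the
    # batch exceeds every captured graph and must run in eager mode.
    idx = bisect.bisect_left(_BATCH_SIZES_TO_CAPTURE, batch_size)
    return _BATCH_SIZES_TO_CAPTURE[idx] if idx < len(_BATCH_SIZES_TO_CAPTURE) else None

assert pick_graph_batch_size(3) == 4
assert pick_graph_batch_size(9) == 16
assert pick_graph_batch_size(300) is None
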
@@ -108,54 +110,49 @@ class InferenceEngine:
         t_capture_begin = time.perf_counter()
 
-        _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
         block_size = self.inference_config.block_size
         head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
         max_context_len_to_capture = self.inference_config.max_context_len_to_capture
         max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
-        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
+        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
         output_tensor = torch.zeros(
             (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
         )
         fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
         max_num_seqs = self.inference_config.max_batch_size
         batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
         sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
 
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
-        for batch_size in reversed(batch_size_capture_list[-1:]):
-            batch_bucket_for_capture = copy.deepcopy(self.request_handler.running_bb)
-            batch_bucket_for_capture.fd_interm_tensor = self.request_handler.running_bb.fd_interm_tensor
+        for batch_size in reversed(batch_size_capture_list):
             if self.verbose:
                 self.logger.info(f"batch size {batch_size} graph capturing")
 
-            # generate dummy input
-            for i in range(batch_size):
-                sequence = Sequence(
-                    i,
-                    None,
-                    input_tokens[i],
-                    block_size,
-                    None,
-                    self.tokenizer.eos_token_id,
-                    self.tokenizer.pad_token_id,
-                    self.inference_config.max_output_len,
-                )
-                sequence.output_token_id = [0]  # only capture the graph of decoding
-                batch_bucket_for_capture.add_seq(sequence, alloc_block_table=block_tables[i])
-            input_data = self.prepare_input(batch_bucket_for_capture)
-            input_tokens_ids, output_tensor, inputmetadata = input_data
+            input_meta_data = InputMetaData(
+                block_tables=block_tables[:batch_size],
+                sequence_lengths=sequence_lengths[:batch_size],
+                fd_inter_tensor=fd_inter_tensor,
+                batch_size=batch_size,
+                is_prompts=False,
+                use_cuda_graph=True,
+                kv_seq_len=sequence_lengths[:batch_size].max().item(),
+                head_dim=head_dim,
+            )
 
             graph_runner = CUDAGraphRunner(self.model)
             graph_runner.capture(
-                input_tokens_ids,
-                output_tensor,
-                inputmetadata,
+                input_tokens_ids[:batch_size],
+                output_tensor[:batch_size],
+                input_meta_data,
                 k_caches=k_cache,
                 v_caches=v_cache,
                 memory_pool=self.graph_memory_pool,
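
The loop above records one CUDA graph per eligible batch size and threads self.graph_memory_pool through every capture so the graphs share a single device-memory pool; that sharing is also why the NOTE suggests capturing the largest batch first, since the first capture sizes the pool. The commit does not show CUDAGraphRunner itself; the sketch below is only the generic torch.cuda capture/replay pattern such a runner typically follows, with all names illustrative and all inputs assumed to be tensors.

import torch

class SimpleGraphRunner:
    def __init__(self, model):
        self.model = model
        self.graph = None
        self.static_inputs = None
        self.static_output = None

    def capture(self, *inputs, memory_pool=None):
        # Warm up on a side stream so lazy allocations and autotuning
        # happen before recording starts.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            self.model(*inputs)
        torch.cuda.current_stream().wait_stream(s)

        # Record the decode step; passing a shared pool lets graphs for
        # different batch sizes reuse the same device memory.
        self.graph = torch.cuda.CUDAGraph()
        self.static_inputs = inputs
        with torch.cuda.graph(self.graph, pool=memory_pool):
            self.static_output = self.model(*inputs)
        return self.graph.pool()  # hand the pool to the next capture

    def __call__(self, *inputs):
        # Replay: refresh the static buffers in place, then launch the
        # pre-recorded kernel sequence. Shapes must match the capture.
        for static_t, new_t in zip(self.static_inputs, inputs):
            static_t.copy_(new_t)
        self.graph.replay()
        return self.static_output
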
@@ -412,8 +409,10 @@ class InferenceEngine:
 
         if input_meta_data.use_cuda_graph:
             model_executable = self.graph_runners[input_meta_data.batch_size]
+            # self.logger.info("run cuda graph")
         else:
             model_executable = self.model
+            # self.logger.info("run original model")
 
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
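
At decode time the engine swaps in the pre-captured runner keyed by the batch size and invokes it exactly like the eager model. Because a replayed graph has frozen shapes, the graph path is only valid for single-token decoding at a batch size that was actually captured; an assumed guard for that condition (not code from this commit) might read:

def can_use_cuda_graph(meta, graph_runners, max_context_len_to_capture):
    # Decode-only (prefill lengths vary), a graph must exist for this
    # batch size, and the KV length must fit what was captured.
    return (
        not meta.is_prompts
        and meta.batch_size in graph_runners
        and meta.kv_seq_len <= max_context_len_to_capture
    )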