[fix] pytest and fix dyn grid bug

Runyu Lu
2024-03-13 17:28:32 +08:00
parent 633e95b301
commit 1821a6dab0
4 changed files with 135 additions and 8 deletions


@@ -118,6 +118,10 @@ class InferenceEngine:
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
)  # NOTE: this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
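For reference, a minimal NumPy sketch of the block-table layout this hack produces (the sizes below are illustrative, not the engine's real configuration): every row gets a distinct first block id, and row 0 additionally references all max_num_blocks blocks, so the first sequence in the capture batch looks like the longest sequence the recorded kernel grid will ever have to cover.

import numpy as np

# Illustrative sizes only; the engine derives these from its own config.
batch_sizes_to_capture = (1, 2, 4, 8)
max_num_blocks = 16

block_tables = np.zeros((max(batch_sizes_to_capture), max_num_blocks), dtype=np.int32)
block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(batch_sizes_to_capture))
block_tables[0, :] = np.arange(0, max_num_blocks)  # row 0 spans every block -> worst-case grid
print(block_tables[0])  # prints 0..15: the first sequence covers max_num_blocks blocks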
@@ -127,6 +131,10 @@ class InferenceEngine:
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE: this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
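A hedged sketch of why the first sequence length is pinned near max_context_len_to_capture during capture (flash_decoding_grid, kv_split_size, and num_heads below are illustrative, not the real kernel's launch parameters): a CUDA graph replays exactly the kernel grid it recorded, and flash decoding typically sizes one grid dimension from the longest sequence in the batch, so capturing with all-ones sequence lengths would bake in a grid too small for real decoding batches.

# Illustrative grid calculation; the actual kernel's split size and grid layout may differ.
def flash_decoding_grid(seq_lens, num_heads=32, kv_split_size=256):
    max_seq_len = max(seq_lens)
    num_kv_splits = (max_seq_len + kv_split_size - 1) // kv_split_size
    return (len(seq_lens), num_heads, num_kv_splits)

max_context_len_to_capture = 512
# With the hack: sequence 0 pretends to be max-length, so the captured grid covers the worst case.
print(flash_decoding_grid([max_context_len_to_capture - 1, 1, 1, 1]))  # (4, 32, 2)
# Without it, every capture-time seqlen is 1 and the recorded grid has a single KV split,
# which a graph replay cannot grow for longer real sequences.
print(flash_decoding_grid([1, 1, 1, 1]))  # (4, 32, 1)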
@@ -385,6 +393,13 @@ class InferenceEngine:
head_dim=batch.head_dim,
)
# if not batch.is_prompts:
# self.logger.info(f"decoding")
# self.logger.info(f"input metadata is: {input_meta_data}")
# else:
# self.logger.info(f"prefill")
# self.logger.info(f"input metadata is: {input_meta_data}")
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
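The commented-out lines above log the input metadata separately for the prefill and decode paths. A hedged alternative that keeps the same information behind a switch instead of comments (verbose is an illustrative flag, not an existing config field):

if getattr(self.inference_config, "verbose", False):  # illustrative flag, not a real config field
    stage = "prefill" if batch.is_prompts else "decoding"
    self.logger.info(f"{stage} input metadata is: {input_meta_data}")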
@@ -414,6 +429,9 @@ class InferenceEngine:
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
# logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
# assert torch.all(logits == logits_), f"error! mismatch between the original model ({logits_[-1]}) and CUDA Graph ({logits[-1]})"
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
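The commented-out lines above hint at a sanity check: run the eager model and the CUDA-graph-backed executable on the same inputs and confirm the logits agree before trusting the captured graph. A hedged sketch of such a check (graph_runner and eager_model are illustrative names, not engine attributes; a tolerance-based comparison is used instead of exact equality):

import torch

def check_graph_matches_eager(graph_runner, eager_model, inputs, atol=1e-3, rtol=1e-3):
    # Run both paths on identical inputs and compare their logits.
    with torch.inference_mode():
        graph_logits = graph_runner(*inputs)
        eager_logits = eager_model(*inputs)
    torch.testing.assert_close(graph_logits, eager_logits, atol=atol, rtol=rtol)
    return graph_logits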