diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index e096956d3..b3d2bc7bd 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -117,7 +117,8 @@ class InferenceEngine:
         max_context_len_to_capture = self.inference_config.max_context_len_to_capture
         max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
         input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+        self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
         self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
         self.graph_block_tables[0, :] = np.arange(
             0, max_num_blocks
diff --git a/tests/test_infer/test_cuda_graph.py b/tests/test_infer/test_cuda_graph.py
index 0810c356a..9c1d5de1b 100644
--- a/tests/test_infer/test_cuda_graph.py
+++ b/tests/test_infer/test_cuda_graph.py
@@ -34,7 +34,9 @@ def check_inference_engine(use_cuda_graph=False, batch_size=32):
 
     prompts_token_ids = []
     for i in range(batch_size):
-        prompts_token_ids.append(np.random.randint(low=0, high=100, size=random.randint(1, 1024)).tolist())
+        prompts_token_ids.append(
+            np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist()
+        )
 
     input_len = 1024
     output_len = 128
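
Not part of the patch, just a minimal sketch of what the engine.py change produces, with _BATCH_SIZES_TO_CAPTURE and max_num_blocks shrunk to small illustrative values. The likely motivation: 0 is a valid physical block id, so a zeros-initialized dummy block table makes every padding slot alias block 0 during CUDA graph capture, whereas -1 is an unambiguous "no block assigned" sentinel.

    import numpy as np

    _BATCH_SIZES_TO_CAPTURE = [1, 2, 4]  # illustrative stand-in for the real capture list
    max_num_blocks = 3                   # illustrative stand-in for the computed value

    # Old: np.zeros(...) made every untouched slot read as physical block 0.
    # New: fill with -1 so untouched slots are clearly unallocated.
    block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
    block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
    block_tables[0, :] = np.arange(0, max_num_blocks)

    print(block_tables)
    # [[ 0  1  2]
    #  [ 4 -1 -1]
    #  [ 5 -1 -1]
    #  [ 6 -1 -1]]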
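
Likewise for the test change, a small sketch of how capping the random prompt length at max(1024 // batch_size, 32) keeps the worst-case total number of prompt tokens roughly constant, where the old fixed cap of 1024 per prompt grew linearly with batch size. The numbers below are only the cap arithmetic, not measured KV-cache usage.

    import random

    import numpy as np

    for batch_size in (1, 8, 32, 64):
        cap = max(1024 // batch_size, 32)
        prompts_token_ids = [
            np.random.randint(low=0, high=100, size=random.randint(1, cap)).tolist()
            for _ in range(batch_size)
        ]
        print(f"batch_size={batch_size:2d}  per-prompt cap={cap:4d}  worst-case total={batch_size * cap}")
    # batch_size= 1  per-prompt cap=1024  worst-case total=1024
    # batch_size= 8  per-prompt cap= 128  worst-case total=1024
    # batch_size=32  per-prompt cap=  32  worst-case total=1024
    # batch_size=64  per-prompt cap=  32  worst-case total=2048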