mirror of https://github.com/hpcaitech/ColossalAI.git
[fix] pytest and fix dyn grid bug
@@ -118,6 +118,10 @@ class InferenceEngine:
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
    0, max_num_blocks
)  # NOTE: this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
    (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
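For context (this sketch is not part of the commit), the dummy block table built above can be illustrated with made-up sizes: every row gets a placeholder block id in column 0, and row 0 is then made to span all max_num_blocks columns so the longest-sequence code path is exercised during capture.

# Minimal sketch with hypothetical sizes; mirrors the construction in the hunk above.
import numpy as np

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8]  # hypothetical capture batch sizes
max_num_blocks = 6                      # hypothetical ceil(max_context_len_to_capture / block_size)

graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
graph_block_tables[0, :] = np.arange(0, max_num_blocks)  # first sequence touches every block slot

print(graph_block_tables)
# [[ 0  1  2  3  4  5]
#  [ 7  0  0  0  0  0]
#  ...
#  [13  0  0  0  0  0]]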
@@ -127,6 +131,10 @@ class InferenceEngine:
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE: this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by making the first seqlen the max_seq_len
sequence_lengths[0] = torch.tensor(
    self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()

# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of the CUDA graph.
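To motivate the two NOTE hacks above (an explanatory sketch, not code from this commit): a flash-decoding style kernel typically derives its launch grid from the longest sequence in the batch, and a CUDA graph replays exactly the grid it captured, so the dummy capture inputs pin the first sequence to the maximum length. The grid computation below is hypothetical and only illustrates that dependency.

# Hypothetical flash-decoding grid: one block per KV chunk of the longest sequence.
import torch

def fake_flash_decoding_grid(sequence_lengths: torch.Tensor, block_size: int = 16):
    max_seq_len = int(sequence_lengths.max())
    num_kv_chunks = (max_seq_len + block_size - 1) // block_size
    return (sequence_lengths.numel(), num_kv_chunks)  # a captured CUDA graph replays this grid verbatim

# If every dummy length were 1 at capture time, the graph would bake in a grid of
# (batch_size, 1), and later replays with long sequences would launch too few blocks.
capture_lengths = torch.ones(8, dtype=torch.int)
capture_lengths[0] = 2048 - 1  # same idea as sequence_lengths[0] = max_context_len_to_capture - 1
print(fake_flash_decoding_grid(capture_lengths))  # -> (8, 128)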
@@ -385,6 +393,13 @@ class InferenceEngine:
head_dim=batch.head_dim,
)

# if not batch.is_prompts:
#     self.logger.info(f"decoding")
#     self.logger.info(f"input metadata is: {input_meta_data}")
# else:
#     self.logger.info(f"prefill")
#     self.logger.info(f"input metadata is: {input_meta_data}")

return input_ids, output_tensor, input_meta_data

def step(self) -> List[str]:
@@ -414,6 +429,9 @@ class InferenceEngine:
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

# logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
# assert torch.all(logits == logits_), f"error! not equal between origin model({logits_[-1]}) and CUDA Graph({logits[-1]})"

if self.inference_config.pad_input:
    logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
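A side note on the commented-out consistency check above (a suggestion, not part of this commit): exact equality between eager and CUDA-graph logits is brittle for floating point; if the check is re-enabled for debugging, a tolerance-based comparison is usually safer. The helper below is a hypothetical sketch of such a check.

import torch

def check_graph_matches_eager(graph_logits: torch.Tensor, eager_logits: torch.Tensor) -> None:
    # Tolerances are illustrative; torch.all(a == b) can fail on harmless fp16/bf16 kernel differences.
    assert torch.allclose(graph_logits, eager_logits, rtol=1e-3, atol=1e-3), (
        f"CUDA graph logits diverge from eager logits: {graph_logits[-1]} vs {eager_logits[-1]}"
    )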