[fix] pytest and fix dyn grid bug

Runyu Lu
2024-03-13 17:28:32 +08:00
parent 633e95b301
commit 1821a6dab0
4 changed files with 135 additions and 8 deletions

View File

@@ -10,6 +10,8 @@ import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
GibiByte = 1024**3
logger = logging.Logger(__name__)
@@ -45,13 +47,16 @@ class InputMetaData:
block_tables: torch.Tensor = None
sequence_lengths: torch.Tensor = None
fd_inter_tensor: torch.Tensor = None
fd_inter_tensor: FDIntermTensors = None
batch_size: int = 64 # current_batch_size
is_prompts: bool = False
use_cuda_graph: bool = False
kv_seq_len: int = 512
head_dim: int = 32
def __repr__(self) -> str:
return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
@dataclass
class InferenceConfig:
@@ -117,9 +122,10 @@ class InferenceConfig:
# cuda_graph
use_cuda_graph: bool = False
max_context_len_to_capture: int = max_input_len * max_output_len
max_context_len_to_capture: int = 512
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
def _verify_config(self) -> None:
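Note: the __post_init__ change above recomputes the capture length from the instance's own max_input_len and max_output_len, instead of baking a product of the class-level defaults into the field default. A minimal sketch of that pattern, with field names taken from the diff and default values that are purely illustrative:

from dataclasses import dataclass

@dataclass
class CaptureConfigSketch:
    max_input_len: int = 256
    max_output_len: int = 256
    max_context_len_to_capture: int = 512  # placeholder; recomputed per instance below

    def __post_init__(self):
        # derived after per-instance fields are set, mirroring the hunk above
        self.max_context_len_to_capture = self.max_input_len + self.max_output_len

cfg = CaptureConfigSketch(max_input_len=128, max_output_len=64)
assert cfg.max_context_len_to_capture == 192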

View File

@@ -118,6 +118,10 @@ class InferenceEngine:
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
) # NOTE: hack to ensure the CUDA graph captures a fixed flash-decoding kernel grid, by making the first sequence length the max capture length
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
@@ -127,6 +131,10 @@ class InferenceEngine:
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE: hack to ensure the CUDA graph captures a fixed flash-decoding kernel grid, by making the first sequence length the max capture length
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of the CUDA graph.
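Note: both hacks above serve the same purpose. During capture, the first sequence is made to look like the longest sequence the graph will ever serve (its block-table row references every block, and its length is pinned just below max_context_len_to_capture), so the flash-decoding kernel is captured with a grid sized for the worst case. A standalone sketch of the dummy capture inputs, assuming the names from the diff; the concrete values are illustrative and the .cuda() calls are omitted so it runs on CPU:

import numpy as np
import torch

# assumptions mirroring the diff
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8]
block_size = 16
max_context_len_to_capture = 512
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size

# row 0 references every block so the captured kernel grid covers the
# longest sequence the graph should ever see
graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
graph_block_tables[0, :] = np.arange(0, max_num_blocks)
block_tables = torch.from_numpy(graph_block_tables)

# sequence 0 is pinned just below the capture limit for the same reason
sequence_lengths = torch.ones(max(_BATCH_SIZES_TO_CAPTURE), dtype=torch.int)
sequence_lengths[0] = max_context_len_to_capture - 1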
@@ -385,6 +393,13 @@ class InferenceEngine:
head_dim=batch.head_dim,
)
# if not batch.is_prompts:
# self.logger.info(f"decoding")
# self.logger.info(f"input metadata is: {input_meta_data}")
# else:
# self.logger.info(f"prefill")
# self.logger.info(f"input metadata is: {input_meta_data}")
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
@@ -414,6 +429,9 @@ class InferenceEngine:
# TODO: padding_id is used for generating attn_mask and will be removed once the no-pad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
# logits_ = self.model(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
# assert torch.all(logits == logits_), f"mismatch between the original model ({logits_[-1]}) and the CUDA graph ({logits[-1]})"
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)

View File

@@ -27,8 +27,7 @@ class CUDAGraphRunner:
assert self.graph is None
# run the model once outside the graph to warm up / cache kernels and avoid stream capture errors
hidden_states = self.model(
# batch,
hidden_states_origin_model = self.model(
input_tokens_ids,
output_tensor,
inputmetadata,
@@ -41,7 +40,7 @@ class CUDAGraphRunner:
# self.logger.info(f"begin capture model...")
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool):
hidden_states = self.model(
hidden_states_cuda_graph = self.model(
input_tokens_ids,
output_tensor,
inputmetadata,
@@ -52,15 +51,16 @@ class CUDAGraphRunner:
# Save the input and output buffers, because replay always uses the same virtual memory space
self.input_buffers = {
# "batch": batch,
"input_tokens_ids": input_tokens_ids,
"output_tensor": output_tensor,
"block_tables": inputmetadata.block_tables,
"sequence_lengths": inputmetadata.sequence_lengths,
# "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output,
# "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse,
"k_caches": k_caches,
"v_caches": v_caches,
}
self.output_buffers = {"logits": hidden_states}
self.output_buffers = {"logits": hidden_states_cuda_graph}
return
def forward(
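Note: the capture path above follows the standard PyTorch CUDA-graph recipe: run the model once outside the graph to warm up and cache kernels, capture a second run inside torch.cuda.graph, and keep references to the static input/output buffers that replay will reuse. A minimal, model-agnostic sketch of that recipe (requires a CUDA device; the function and variable names here are illustrative, not the repository's API):

import torch

def capture_graph(model, static_input, memory_pool=None):
    # 1. warm-up run outside the graph so kernels are compiled/cached and
    #    no lazy initialization happens during stream capture
    model(static_input)
    torch.cuda.synchronize()

    # 2. capture one forward pass; every tensor touched here becomes part of
    #    the fixed memory the replayed graph will read from and write to
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=memory_pool):
        static_output = model(static_input)
    return graph, static_output

def replay_graph(graph, static_input, static_output, new_input):
    # 3. replay: refresh the captured input buffer in place, replay the
    #    recorded kernels, then read results from the captured output buffer
    static_input.copy_(new_input, non_blocking=True)
    graph.replay()
    return static_output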
@@ -74,9 +74,18 @@ class CUDAGraphRunner:
# Copy the input tensors to the input buffers.
self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
self.input_buffers["block_tables"].copy_(inputmetadata.block_tables, non_blocking=True)
# block_tables may change shape between steps: reset the captured buffer, then copy only the valid region
self.input_buffers["block_tables"].fill_(-1)
M, N = inputmetadata.block_tables.shape
self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True)
self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)
# there is a single global fd_inter_tensor, so it does not need to be copied
# self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True)
# self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True)
# KV caches are fixed tensors, so we don't need to copy them.
# self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
# self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)