mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[fix] pytest and fix dyn grid bug
This commit is contained in:
@@ -27,8 +27,7 @@ class CUDAGraphRunner:
|
||||
assert self.graph is None
|
||||
|
||||
# run kernel once to cache the kernel, avoid stream capture error
|
||||
hidden_states = self.model(
|
||||
# batch,
|
||||
hidden_states_origin_model = self.model(
|
||||
input_tokens_ids,
|
||||
output_tensor,
|
||||
inputmetadata,
|
||||
@@ -41,7 +40,7 @@ class CUDAGraphRunner:
|
||||
# self.logger.info(f"begin capture model...")
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph, pool=memory_pool):
|
||||
hidden_states = self.model(
|
||||
hidden_states_cuda_graph = self.model(
|
||||
input_tokens_ids,
|
||||
output_tensor,
|
||||
inputmetadata,
|
||||
@@ -52,15 +51,16 @@ class CUDAGraphRunner:
|
||||
|
||||
# Save the input and output buffers, because replay always uses the same virtual memory space
|
||||
self.input_buffers = {
|
||||
# "batch": batch,
|
||||
"input_tokens_ids": input_tokens_ids,
|
||||
"output_tensor": output_tensor,
|
||||
"block_tables": inputmetadata.block_tables,
|
||||
"sequence_lengths": inputmetadata.sequence_lengths,
|
||||
# "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output,
|
||||
# "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse,
|
||||
"k_caches": k_caches,
|
||||
"v_caches": v_caches,
|
||||
}
|
||||
self.output_buffers = {"logits": hidden_states}
|
||||
self.output_buffers = {"logits": hidden_states_cuda_graph}
|
||||
return
|
||||
|
||||
def forward(
|
||||
@@ -74,9 +74,18 @@ class CUDAGraphRunner:
|
||||
# Copy the input tensors to the input buffers.
|
||||
self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
|
||||
self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
|
||||
self.input_buffers["block_tables"].copy_(inputmetadata.block_tables, non_blocking=True)
|
||||
|
||||
# for flexible block_table
|
||||
self.input_buffers["block_tables"].fill_(-1)
|
||||
M, N = inputmetadata.block_tables.shape
|
||||
self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True)
|
||||
|
||||
self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)
|
||||
|
||||
# we only have a global fd_inter_tensor so we don't need to copy them
|
||||
# self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True)
|
||||
# self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True)
|
||||
|
||||
# KV caches are fixed tensors, so we don't need to copy them.
|
||||
# self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True)
|
||||
# self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True)
|
||||
|
Reference in New Issue
Block a user