[fix] multi graphs capture error

This commit is contained in:
Runyu Lu
2024-03-11 10:49:31 +08:00
parent cefaeb5fdd
commit b2c0d9ff2b
4 changed files with 27 additions and 30 deletions

View File

@@ -42,7 +42,6 @@ class CUDAGraphRunner:
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool):
hidden_states = self.model(
# batch,
input_tokens_ids,
output_tensor,
inputmetadata,