[fx] add more meta_registry for MetaTensor execution. (#2000)

* [sc] add examples for auto checkpoint.

* merge upstream

* [fx] add more meta_registry for MetaTensor execution.
Super Daniel authored 2022-11-23 10:55:46 +08:00, committed by GitHub
commit 2edbef13cc, parent d00d905b86
4 changed files with 70 additions and 21 deletions


@@ -18,13 +18,11 @@ def bench(gm: torch.fx.GraphModule,
          data_gen: Callable,
          num_steps: int = 5) -> Tuple[int, int]:
    """Benchmarking a given graph module
    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
        data_gen (Callable): Data generator.
        num_steps (int, optional): Number of test steps. Defaults to 5.
    Returns:
        Tuple[int, int]: peak memory in MB and step time in MS.
    """
@@ -69,7 +67,6 @@ def bench_rotor(gm: torch.fx.GraphModule,
                start_factor: int = 4) -> Tuple[np.array, list, list]:
    """Auto Checkpoint Rotor Algorithm benchmarking
    Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.
    Args:
        gm (torch.fx.GraphModule): The graph module to benchmark.
        criterion (torch.nn.Module): Loss function.
@@ -79,7 +76,6 @@ def bench_rotor(gm: torch.fx.GraphModule,
        free_memory (int, optional): Max memory budget in Byte. Defaults to torch.cuda.mem_get_info()[0].
        start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget
            will be free_memory / start_factor. Defaults to 4.
    Returns:
        Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS).
    """