mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[fx] add more meta_registry for MetaTensor execution. (#2000)
* [sc] add examples for auto checkpoint. * merge upstream * [fx] add more meta_registry for MetaTensor execution.
This commit is contained in:
@@ -18,13 +18,11 @@ def bench(gm: torch.fx.GraphModule,
|
||||
data_gen: Callable,
|
||||
num_steps: int = 5) -> Tuple[int, int]:
|
||||
"""Benchmarking a given graph module
|
||||
|
||||
Args:
|
||||
gm (torch.fx.GraphModule): The graph module to benchmark.
|
||||
criterion (torch.nn.Module): Loss function.
|
||||
data_gen (Callable): Data generator.
|
||||
num_steps (int, optional): Number of test steps. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: Peak memory in MB and step time in ms.
|
||||
"""
|
||||
@@ -69,7 +67,6 @@ def bench_rotor(gm: torch.fx.GraphModule,
|
||||
start_factor: int = 4) -> Tuple[np.array, list, list]:
|
||||
"""Auto Checkpoint Rotor Algorithm benchmarking
|
||||
Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.
|
||||
|
||||
Args:
|
||||
gm (torch.fx.GraphModule): The graph module to benchmark.
|
||||
criterion (torch.nn.Module): Loss function.
|
||||
@@ -79,7 +76,6 @@ def bench_rotor(gm: torch.fx.GraphModule,
|
||||
free_memory (int, optional): Max memory budget in bytes. Defaults to torch.cuda.mem_get_info()[0].
|
||||
start_factor (int, optional): Start memory budget factor for benchmark; the start memory budget
|
||||
will be free_memory / start_factor. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
Tuple[np.array, list, list]: Budgets vector (MB), peak memory vector (MB), step time vector (ms).
|
||||
"""
|
||||
|
Reference in New Issue
Block a user