Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[test] fixed gemini plugin test (#3411)
* [test] fixed gemini plugin test
* polish code
* polish code
@@ -1,4 +1,5 @@
 from typing import Dict
+
 import torch
 import torch.fx
 from torch.fx import GraphModule
@@ -7,10 +8,11 @@ from torch.utils._pytree import tree_map
 from colossalai.fx import ColoTracer, is_compatible_with_meta
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp

-from .region_manager import RegionManager
-from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
 from .base_offload_module import BaseOffloadModule
-from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
+from .region_manager import RegionManager
+from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
+from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
+

 def memory_optimize(model: torch.nn.Module,
                     inps: Dict[str, torch.Tensor],
@@ -29,13 +31,14 @@ def memory_optimize(model: torch.nn.Module,

     region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
     region_manager._build_regions()
-    GlobalRuntimeInfo.region_list = region_manager.region_list
+    GlobalRuntimeInfo().region_list = region_manager.region_list

-    act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
-    max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
-    total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
+    act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2
+    max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2
+    total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2
     print(
-        f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
+        f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
+    )

     if solver_name == 'syn':
         gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
@@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module,
         raise TypeError(f"Unknown solver name {solver_name}!")

     gm.recompile()
-    optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
+    optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
     return optimized_model
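
Note on the GlobalRuntimeInfo change above: writing through an instance (GlobalRuntimeInfo().region_list = ...) instead of the class suggests that GlobalRuntimeInfo now hands out a single shared instance, i.e. a singleton, so the state stored by this pass stays visible to the runtime offload hooks. A minimal sketch of that pattern, assuming a singleton metaclass that this diff does not itself show:

    # Sketch only: a singleton-style GlobalRuntimeInfo. The metaclass name
    # and attributes are assumptions for illustration, not taken from this diff.
    class SingletonMeta(type):
        _instances = {}

        def __call__(cls, *args, **kwargs):
            # Construct at most one instance per class and keep returning it.
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
            return cls._instances[cls]

    class GlobalRuntimeInfo(metaclass=SingletonMeta):
        def __init__(self):
            self.region_list = []

    # Every GlobalRuntimeInfo() call yields the same object, so an assignment
    # made inside memory_optimize is visible everywhere else:
    GlobalRuntimeInfo().region_list = ["region0", "region1"]
    assert GlobalRuntimeInfo().region_list == ["region0", "region1"]

Under this pattern reads behave the same as with a plain class attribute, but the instance form keeps working if region_list later becomes per-instance state initialized in __init__.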
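For reference, a hedged sketch of driving the patched memory_optimize entry point end to end. The module, shapes, budget value, and import path are illustrative assumptions; only the parameter names (inps, memory_budget, solver_name) and the 'syn'/'asyn' solver choices come from the diff itself:

    # Hypothetical usage; the import path and all concrete values are assumed.
    import torch
    import torch.nn as nn

    from colossalai.auto_parallel.offload.mem_optimize import memory_optimize  # assumed path

    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
    inps = {"input": torch.randn(8, 1024)}  # inps: Dict[str, torch.Tensor]

    # 'syn' applies runtime_syn_offload_apply_pass, 'asyn' the asynchronous pass;
    # any other name raises TypeError(f"Unknown solver name {solver_name}!").
    optimized_model = memory_optimize(model, inps,
                                      memory_budget=512 * 1024**2,  # assumed to be in bytes
                                      solver_name='syn')
    # memory_optimize returns a BaseOffloadModule wrapping the recompiled GraphModule.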