[test] fixed gemini plugin test (#3411)

* [test] fixed gemini plugin test

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-03 17:12:22 +08:00
committed by GitHub
parent 30412866e0
commit 638a07a7f9
7 changed files with 124 additions and 131 deletions

View File

@@ -1,10 +1,11 @@
from typing import Optional, Set
from functools import partial
from typing import Optional, Set
import torch
import torch.nn as nn
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.gemini.tensor_utils import free_storage
from colossalai.nn.parallel.data_parallel import _cast_float
from .region_manager import RegionManager
from .util import GlobalRuntimeInfo
@@ -20,10 +21,7 @@ class BaseOffloadModule:
is_sync (bool): synchronous mode or not.
"""
def __init__(self,
model: nn.Module,
region_manager: RegionManager,
is_sync=True):
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
self.model = model
self.region_manager = region_manager
@@ -69,8 +67,8 @@ class BaseOffloadModule:
for p in self.model.parameters():
p.grad = None
GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
GlobalRuntimeInfo.bwd_prefetch_event_map.clear()
GlobalRuntimeInfo().fwd_prefetch_event_map.clear()
GlobalRuntimeInfo().bwd_prefetch_event_map.clear()
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
@@ -82,7 +80,7 @@ class BaseOffloadModule:
self.overflow_counter += region.has_inf_or_nan
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.grad_offload_stream):
GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream)
region.move_grad_to_cpu()
return empty_grad

View File

@@ -1,4 +1,5 @@
from typing import Dict
import torch
import torch.fx
from torch.fx import GraphModule
@@ -7,10 +8,11 @@ from torch.utils._pytree import tree_map
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from .region_manager import RegionManager
from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
from .base_offload_module import BaseOffloadModule
from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
from .region_manager import RegionManager
from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
def memory_optimize(model: torch.nn.Module,
inps: Dict[str, torch.Tensor],
@@ -29,13 +31,14 @@ def memory_optimize(model: torch.nn.Module,
region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
region_manager._build_regions()
GlobalRuntimeInfo.region_list = region_manager.region_list
GlobalRuntimeInfo().region_list = region_manager.region_list
act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2
print(
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)
if solver_name == 'syn':
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
@@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module,
raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile()
optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
return optimized_model

View File

@@ -1,4 +1,5 @@
from typing import List
import torch
from torch.fx.node import Node
@@ -23,13 +24,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
ctx.bwd_info = bwd_info
d2h_rid = fwd_info.get('d2h_rid', None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo.region_list[d2h_rid]
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()
h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region)
h2d_region.move_param_to_cuda()
@@ -40,7 +41,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
pref_region.move_param_to_cuda()
@@ -65,23 +66,22 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
sync_rid = fwd_info.get('sync_rid', None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(
sync_rid, None)
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda()
prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event
prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event
return input_
@@ -90,10 +90,9 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
sync_rid = ctx.bwd_info.get('sync_rid', None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo.region_list[sync_rid]
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region)
prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(
sync_rid, None)
prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
else:
@@ -101,16 +100,16 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda()
prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event
return grad_output, None, None
@@ -129,6 +128,7 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
Convert Prefetch and Offload operation into runtime action.
@@ -189,7 +189,8 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
new_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node)
@@ -206,11 +207,11 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# upload parameters of the first region
last_inp_node = tuple(mod_graph.nodes)[0]
first_region_with_p = [
region for region in region_list if region.param_size][0]
first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
upload_apply_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, {}))
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
@@ -225,19 +226,20 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx-1].need_offload:
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx-1].need_offload:
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['sync_rid'] = r_idx - 1
if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node)
@@ -246,7 +248,8 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
if region.bwd_prefetch_region:
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, {}, bwd_info))
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()

View File

@@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import List
import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
from .region import Region
@@ -12,6 +15,7 @@ class NodeInfo:
runtime_fwd_mem: float = 0
runtime_bwd_mem: float = 0
class NvDevicePower:
"""
NVIDIA GPU computing performance (TFLOPs).
@@ -30,12 +34,14 @@ class NvDevicePower:
A100_FP32 = 19.5
class GlobalRuntimeInfo:
h2d_stream = torch.cuda.Stream()
d2h_stream = torch.cuda.Stream()
fwd_prefetch_event_map = {}
bwd_prefetch_event_map = {}
region_list = []
class GlobalRuntimeInfo(metaclass=SingletonMeta):
def __init__(self):
self.h2d_stream = torch.cuda.Stream()
self.d2h_stream = torch.cuda.Stream()
self.fwd_prefetch_event_map = {}
self.bwd_prefetch_event_map = {}
self.region_list = []
def compute_act_peak_mem(region_list: List[Region]) -> float:
@@ -70,21 +76,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
return act_peak_mem
def compute_max_param_mem(region_list: List[Region]) -> float:
return max(region.param_size for region in region_list)
def compute_total_param_mem(region_list: List[Region]) -> float:
return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)
def requires_upload_p_in_fwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or (
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
and shared_reg.need_offload)
def requires_release_p_in_bwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or (
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
and shared_reg.need_offload)
def requires_offload_g_in_bwd(region: Region):
return region.param_size and (region.r_id <= region.shared_rid)