mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 06:30:41 +00:00
[test] fixed gemini plugin test (#3411)
* [test] fixed gemini plugin test
* polish code
* polish code
@@ -1,10 +1,11 @@
-from typing import Optional, Set
 from functools import partial
+from typing import Optional, Set
+
 import torch
 import torch.nn as nn

-from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.gemini.tensor_utils import free_storage
+from colossalai.nn.parallel.data_parallel import _cast_float

 from .region_manager import RegionManager
 from .util import GlobalRuntimeInfo
@@ -20,10 +21,7 @@ class BaseOffloadModule:
         is_sync (bool): synchronous mode or not.
     """

-    def __init__(self,
-                 model: nn.Module,
-                 region_manager: RegionManager,
-                 is_sync=True):
+    def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):

         self.model = model
         self.region_manager = region_manager
@@ -69,8 +67,8 @@ class BaseOffloadModule:
         for p in self.model.parameters():
             p.grad = None

-        GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
-        GlobalRuntimeInfo.bwd_prefetch_event_map.clear()
+        GlobalRuntimeInfo().fwd_prefetch_event_map.clear()
+        GlobalRuntimeInfo().bwd_prefetch_event_map.clear()

     def grad_handle(self, p, grad):
         empty_grad = torch.empty_like(grad)
@@ -82,7 +80,7 @@ class BaseOffloadModule:
         self.overflow_counter += region.has_inf_or_nan
         master_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self.grad_offload_stream):
-            GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
+            GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream)
             region.move_grad_to_cpu()
         return empty_grad
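The grad_handle hunk above keeps the existing offload pattern: gradients are copied to CPU on a dedicated device-to-host stream that first waits on the compute stream, so the copy never races with the kernel that produced the gradient. Below is a minimal standalone sketch of that pattern, not ColossalAI's actual code; offload_grad_to_cpu, offload_stream, and cpu_grad are illustrative names.

import torch

def offload_grad_to_cpu(grad: torch.Tensor, offload_stream: torch.cuda.Stream) -> torch.Tensor:
    """Sketch: copy a CUDA gradient to pinned CPU memory on a side stream."""
    compute_stream = torch.cuda.current_stream()
    # Pinned host memory lets the copy below run asynchronously with compute.
    cpu_grad = torch.empty(grad.shape, dtype=grad.dtype, device='cpu', pin_memory=True)
    with torch.cuda.stream(offload_stream):
        # The offload stream must not start copying before the compute stream
        # has finished producing `grad`.
        offload_stream.wait_stream(compute_stream)
        cpu_grad.copy_(grad, non_blocking=True)
    # Tell the caching allocator that `grad` is still in use by the offload
    # stream so its memory is not reused before the copy completes.
    grad.record_stream(offload_stream)
    return cpu_grad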
@@ -1,4 +1,5 @@
 from typing import Dict
+
 import torch
 import torch.fx
 from torch.fx import GraphModule
@@ -7,10 +8,11 @@ from torch.utils._pytree import tree_map
 from colossalai.fx import ColoTracer, is_compatible_with_meta
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp

-from .region_manager import RegionManager
-from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
 from .base_offload_module import BaseOffloadModule
-from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
+from .region_manager import RegionManager
+from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
+from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
+

 def memory_optimize(model: torch.nn.Module,
                     inps: Dict[str, torch.Tensor],
@@ -29,13 +31,14 @@ def memory_optimize(model: torch.nn.Module,

     region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
     region_manager._build_regions()
-    GlobalRuntimeInfo.region_list = region_manager.region_list
+    GlobalRuntimeInfo().region_list = region_manager.region_list

-    act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
-    max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
-    total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
+    act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2
+    max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2
+    total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2
     print(
-        f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
+        f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
+    )

     if solver_name == 'syn':
         gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
@@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module,
         raise TypeError(f"Unknown solver name {solver_name}!")

     gm.recompile()
-    optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
+    optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
     return optimized_model
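For context, memory_optimize traces the model with ColoTracer, builds regions with RegionManager, applies the synchronous or asynchronous offload pass, and wraps the recompiled fx.GraphModule in a BaseOffloadModule. A hypothetical usage sketch follows; the module path, default arguments, and the unit of memory_budget are assumptions rather than facts taken from this diff.

import torch
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize  # module path assumed

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1024, 1024)
        self.fc2 = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = MLP().cuda()
# Keys are assumed to match the forward argument names used during tracing.
inps = {'x': torch.randn(16, 1024, device='cuda')}

# 'syn' applies runtime_syn_offload_apply_pass; 'asyn' applies the
# asynchronous prefetching variant shown below.
optimized_model = memory_optimize(model, inps, memory_budget=512 * 1024**2, solver_name='asyn')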
@@ -1,4 +1,5 @@
 from typing import List
+
 import torch
 from torch.fx.node import Node

@@ -23,13 +24,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
         ctx.bwd_info = bwd_info
         d2h_rid = fwd_info.get('d2h_rid', None)
         if d2h_rid is not None:
-            free_region = GlobalRuntimeInfo.region_list[d2h_rid]
+            free_region = GlobalRuntimeInfo().region_list[d2h_rid]
             assert isinstance(free_region, Region)
             free_region.free_cuda_data()

         h2d_rid = fwd_info.get('h2d_rid', None)
         if h2d_rid is not None:
-            h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
+            h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(h2d_region, Region)
             h2d_region.move_param_to_cuda()

@@ -40,7 +41,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):

         h2d_rid = ctx.bwd_info.get('h2d_rid', None)
         if h2d_rid is not None:
-            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
             pref_region.move_param_to_cuda()

@@ -65,23 +66,22 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):

         sync_rid = fwd_info.get('sync_rid', None)
         if sync_rid is not None:
-            prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(
-                sync_rid, None)
+            prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
             if prefetch_event:
                 prefetch_event.wait()

         h2d_rid = fwd_info.get('h2d_rid', None)
         if h2d_rid is not None:
-            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
             master_stream = torch.cuda.current_stream()
-            with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
-                GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
+            with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
+                GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
                 pref_region.move_param_to_cuda()

             prefetch_event = torch.cuda.Event()
-            prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
-            GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event
+            prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
+            GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event

         return input_

@@ -90,10 +90,9 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):

         sync_rid = ctx.bwd_info.get('sync_rid', None)
         if sync_rid is not None:
-            wait_region = GlobalRuntimeInfo.region_list[sync_rid]
+            wait_region = GlobalRuntimeInfo().region_list[sync_rid]
             assert isinstance(wait_region, Region)
-            prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(
-                sync_rid, None)
+            prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
             if prefetch_event:
                 prefetch_event.wait()
             else:
@@ -101,16 +100,16 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):

         h2d_rid = ctx.bwd_info.get('h2d_rid', None)
         if h2d_rid is not None:
-            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
             master_stream = torch.cuda.current_stream()
-            with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
-                GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
+            with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
+                GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
                 pref_region.move_param_to_cuda()

             prefetch_event = torch.cuda.Event()
-            prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
-            GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
+            prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
+            GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event
         return grad_output, None, None

@@ -129,6 +128,7 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
     ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
     return ret

+
 def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
     '''
     Convert Prefetch and Offload operation into runtime action.
@@ -189,7 +189,8 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R

         if fwd_info or bwd_info:
             with mod_graph.inserting_after(last_inp_node):
-                new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
+                new_node = mod_graph.create_node('call_function',
+                                                 convert_fwd_upload_bwd_offload_to_action,
                                                  args=(last_inp_node, fwd_info, bwd_info))
             replace_node_users(last_inp_node, new_node)

@@ -206,11 +207,11 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[

     # upload parameters of the first region
     last_inp_node = tuple(mod_graph.nodes)[0]
-    first_region_with_p = [
-        region for region in region_list if region.param_size][0]
+    first_region_with_p = [region for region in region_list if region.param_size][0]
     fwd_info = {"h2d_rid": first_region_with_p.r_id}
     with mod_graph.inserting_after(last_inp_node):
-        upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
+        upload_apply_node = mod_graph.create_node('call_function',
+                                                  convert_fwd_upload_bwd_offload_to_action,
                                                   args=(last_inp_node, fwd_info, {}))
     replace_node_users(last_inp_node, upload_apply_node)
     last_inp_node = upload_apply_node
@@ -225,19 +226,20 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
             fwd_info['h2d_rid'] = fwd_prefetch_region.r_id

         # forward offload
-        if r_idx > 0 and region_list[r_idx-1].need_offload:
+        if r_idx > 0 and region_list[r_idx - 1].need_offload:
             fwd_info['d2h_rid'] = r_idx - 1

         bwd_info = {}
         # backward prefetch
-        if r_idx > 0 and region_list[r_idx-1].need_offload:
+        if r_idx > 0 and region_list[r_idx - 1].need_offload:
             bwd_info['sync_rid'] = r_idx - 1
-        if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region:
-            bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id
+        if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
+            bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id

         if fwd_info or bwd_info:
             with mod_graph.inserting_after(last_inp_node):
-                new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
+                new_node = mod_graph.create_node('call_function',
+                                                 convert_fwd_prefetch_bwd_offload_to_action,
                                                  args=(last_inp_node, fwd_info, bwd_info))
             replace_node_users(last_inp_node, new_node)

@@ -246,7 +248,8 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
     if region.bwd_prefetch_region:
         bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
         with mod_graph.inserting_after(last_inp_node):
-            new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
+            new_node = mod_graph.create_node('call_function',
+                                             convert_fwd_prefetch_bwd_offload_to_action,
                                              args=(last_inp_node, {}, bwd_info))
         replace_node_users(last_inp_node, new_node)
     # gm.graph.print_tabular()
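The asynchronous pass above uploads the parameters of a later region on a dedicated host-to-device stream and records a CUDA event per region id, so the consumer can later synchronize on that single event instead of blocking the host or the whole stream. Here is a self-contained sketch of that event-based handshake; h2d_stream, prefetch, and wait_for_prefetch are illustrative names rather than ColossalAI APIs, and the copy is only truly asynchronous when the CPU tensor lives in pinned memory.

import torch

h2d_stream = torch.cuda.Stream()
prefetch_events = {}

def prefetch(rid: int, cpu_param: torch.Tensor) -> torch.Tensor:
    """Upload a parameter on the side stream and record a per-region event."""
    compute_stream = torch.cuda.current_stream()
    with torch.cuda.stream(h2d_stream):
        # Do not start the upload before the compute stream reaches this point.
        h2d_stream.wait_stream(compute_stream)
        cuda_param = cpu_param.to('cuda', non_blocking=True)
    event = torch.cuda.Event()
    event.record(h2d_stream)
    prefetch_events[rid] = event
    return cuda_param

def wait_for_prefetch(rid: int) -> None:
    """Make the current stream wait for the prefetch; the host is not blocked."""
    event = prefetch_events.pop(rid, None)
    if event is not None:
        event.wait()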
@@ -1,6 +1,9 @@
 from dataclasses import dataclass
 from typing import List
+
 import torch

+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
+
 from .region import Region
@@ -12,6 +15,7 @@ class NodeInfo:
     runtime_fwd_mem: float = 0
     runtime_bwd_mem: float = 0

+
 class NvDevicePower:
     """
     NVIDIA GPU computing performance (TFLOPs).
@@ -30,12 +34,14 @@ class NvDevicePower:
     A100_FP32 = 19.5


-class GlobalRuntimeInfo:
-    h2d_stream = torch.cuda.Stream()
-    d2h_stream = torch.cuda.Stream()
-    fwd_prefetch_event_map = {}
-    bwd_prefetch_event_map = {}
-    region_list = []
+class GlobalRuntimeInfo(metaclass=SingletonMeta):
+
+    def __init__(self):
+        self.h2d_stream = torch.cuda.Stream()
+        self.d2h_stream = torch.cuda.Stream()
+        self.fwd_prefetch_event_map = {}
+        self.bwd_prefetch_event_map = {}
+        self.region_list = []


 def compute_act_peak_mem(region_list: List[Region]) -> float:
@@ -70,21 +76,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:

     return act_peak_mem


 def compute_max_param_mem(region_list: List[Region]) -> float:
     return max(region.param_size for region in region_list)


 def compute_total_param_mem(region_list: List[Region]) -> float:
     return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)


 def requires_upload_p_in_fwd(shared_reg: Region):
-    return (shared_reg.r_id >= shared_reg.shared_rid) or (
-        shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
+    return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
+                                                          and shared_reg.need_offload)


 def requires_release_p_in_bwd(shared_reg: Region):
-    return (shared_reg.r_id >= shared_reg.shared_rid) or (
-        shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
+    return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
+                                                          and shared_reg.need_offload)


 def requires_offload_g_in_bwd(region: Region):
     return region.param_size and (region.r_id <= region.shared_rid)
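The last hunk is the core of the change: GlobalRuntimeInfo no longer creates CUDA streams as class attributes at import time; it becomes a SingletonMeta-based singleton whose streams and maps are built lazily on first instantiation, which is why every call site above switches from GlobalRuntimeInfo.x to GlobalRuntimeInfo().x. The sketch below shows how a SingletonMeta-style metaclass typically behaves; it is an assumption for illustration, and colossalai.context.singleton_meta.SingletonMeta is the actual implementation used by the patch.

class SingletonMeta(type):
    """Metaclass that caches one instance per class (sketch, not ColossalAI's code)."""
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Construct the instance once; every later call returns the cached one.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class GlobalRuntimeInfoSketch(metaclass=SingletonMeta):
    """Hypothetical stand-in for GlobalRuntimeInfo, without the CUDA streams."""

    def __init__(self):
        self.fwd_prefetch_event_map = {}
        self.region_list = []


# Every call site gets the same shared state.
assert GlobalRuntimeInfoSketch() is GlobalRuntimeInfoSketch()
GlobalRuntimeInfoSketch().region_list.append('region-0')
assert GlobalRuntimeInfoSketch().region_list == ['region-0']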