[test] fixed gemini plugin test (#3411)

* [test] fixed gemini plugin test

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-03 17:12:22 +08:00
committed by GitHub
parent 30412866e0
commit 638a07a7f9
7 changed files with 124 additions and 131 deletions

View File

@@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import List
import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
from .region import Region
@@ -12,6 +15,7 @@ class NodeInfo:
runtime_fwd_mem: float = 0
runtime_bwd_mem: float = 0
class NvDevicePower:
"""
NVIDIA GPU computing performance (TFLOPs).
@@ -30,12 +34,14 @@ class NvDevicePower:
A100_FP32 = 19.5
class GlobalRuntimeInfo:
h2d_stream = torch.cuda.Stream()
d2h_stream = torch.cuda.Stream()
fwd_prefetch_event_map = {}
bwd_prefetch_event_map = {}
region_list = []
class GlobalRuntimeInfo(metaclass=SingletonMeta):
def __init__(self):
self.h2d_stream = torch.cuda.Stream()
self.d2h_stream = torch.cuda.Stream()
self.fwd_prefetch_event_map = {}
self.bwd_prefetch_event_map = {}
self.region_list = []
def compute_act_peak_mem(region_list: List[Region]) -> float:
@@ -70,21 +76,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
return act_peak_mem
def compute_max_param_mem(region_list: List[Region]) -> float:
return max(region.param_size for region in region_list)
def compute_total_param_mem(region_list: List[Region]) -> float:
return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)
def requires_upload_p_in_fwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or (
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
and shared_reg.need_offload)
def requires_release_p_in_bwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or (
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
and shared_reg.need_offload)
def requires_offload_g_in_bwd(region: Region):
return region.param_size and (region.r_id <= region.shared_rid)