mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[test] fixed gemini plugin test (#3411)
* [test] fixed gemini plugin test
* polish code
* polish code
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
|
||||
|
||||
from .region import Region
|
||||
@@ -12,6 +15,7 @@ class NodeInfo:
|
||||
runtime_fwd_mem: float = 0
|
||||
runtime_bwd_mem: float = 0
|
||||
|
||||
|
||||
class NvDevicePower:
|
||||
"""
|
||||
NVIDIA GPU computing performance (TFLOPs).
|
||||
@@ -30,12 +34,14 @@ class NvDevicePower:
|
||||
A100_FP32 = 19.5
|
||||
|
||||
|
||||
class GlobalRuntimeInfo:
|
||||
h2d_stream = torch.cuda.Stream()
|
||||
d2h_stream = torch.cuda.Stream()
|
||||
fwd_prefetch_event_map = {}
|
||||
bwd_prefetch_event_map = {}
|
||||
region_list = []
|
||||
class GlobalRuntimeInfo(metaclass=SingletonMeta):
|
||||
|
||||
def __init__(self):
|
||||
self.h2d_stream = torch.cuda.Stream()
|
||||
self.d2h_stream = torch.cuda.Stream()
|
||||
self.fwd_prefetch_event_map = {}
|
||||
self.bwd_prefetch_event_map = {}
|
||||
self.region_list = []
|
||||
|
||||
|
||||
def compute_act_peak_mem(region_list: List[Region]) -> float:
|
||||
@@ -70,21 +76,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
|
||||
|
||||
return act_peak_mem
|
||||
|
||||
|
||||
def compute_max_param_mem(region_list: List[Region]) -> float:
|
||||
return max(region.param_size for region in region_list)
|
||||
|
||||
|
||||
def compute_total_param_mem(region_list: List[Region]) -> float:
|
||||
return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)
|
||||
|
||||
|
||||
def requires_upload_p_in_fwd(shared_reg: Region):
|
||||
return (shared_reg.r_id >= shared_reg.shared_rid) or (
|
||||
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
|
||||
return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
|
||||
and shared_reg.need_offload)
|
||||
|
||||
|
||||
def requires_release_p_in_bwd(shared_reg: Region):
|
||||
return (shared_reg.r_id >= shared_reg.shared_rid) or (
|
||||
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
|
||||
return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
|
||||
and shared_reg.need_offload)
|
||||
|
||||
|
||||
def requires_offload_g_in_bwd(region: Region):
|
||||
return region.param_size and (region.r_id <= region.shared_rid)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user