[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -1,7 +1,7 @@
 import bisect
-from typing import List, Dict
-from collections import OrderedDict
 from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import Dict, List

 from torch.fx.node import Node
@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
         link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
     """

-    def __init__(self,
-                 region_list: List[Region],
-                 comp_power: float,
-                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
+    def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
         self.region_list = region_list
         self.region_num = len(region_list)
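Most hunks in this sweep follow the pattern above: a hanging-indent signature is collapsed onto one line, single-quoted strings become double-quoted, and backslash continuations are folded. A minimal before/after sketch of that pattern (assuming a Black-style formatter with a 120-character line limit; the previous formatter and the exact hook versions are not named in this excerpt):

    # before: hanging-indent signature, single quotes (hypothetical snippet)
    def __init__(self,
                 region_list,
                 comp_power):
        self.flop_key = 'fwd_flop'

    # after: one-line signature within the 120-character limit, double quotes
    def __init__(self, region_list, comp_power):
        self.flop_key = "fwd_flop"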
@@ -87,11 +84,7 @@ class TrainingSimulator(ABC):

 class SynTrainingSimulator(TrainingSimulator):
-    def __init__(self,
-                 region_list: List[Region],
-                 comp_power: float,
-                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
+    def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
         super().__init__(region_list, comp_power, link_to_bw)

     def execute(self):
@@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator):
             self.runtime_mem += region.param_size

         for node in region.nodes:
-            self.runtime_mem += calculate_fwd_tmp(node) + \
-                calculate_fwd_out(node)
+            self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
             self.fwd_node_mem[node] = self.runtime_mem
             self.peak_mem = max(self.runtime_mem, self.peak_mem)
             self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
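The forward loop in the hunk above is a running-sum peak-memory estimate: each node's temporary and output activations are added to the resident total, and the peak is tracked with max. A self-contained sketch of that accounting pattern, with hypothetical byte counts standing in for calculate_fwd_tmp/calculate_fwd_out:

    # running/peak memory accounting, mirroring the forward loop above
    # (sizes are hypothetical stand-ins for calculate_fwd_tmp/calculate_fwd_out)
    runtime_mem, peak_mem = 0, 0
    node_sizes = [(1024, 4096), (2048, 1024), (512, 512)]  # (fwd_tmp, fwd_out) per node
    for fwd_tmp, fwd_out in node_sizes:
        runtime_mem += fwd_tmp + fwd_out  # activations stay resident through the forward pass
        peak_mem = max(runtime_mem, peak_mem)
    print(f"peak: {peak_mem / 1024:.1f} KB")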
@@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator):
             self.runtime_mem += region.param_size

         for node in region.nodes.__reversed__():
             self.runtime_mem -= calculate_fwd_out(node)
-            self.runtime_mem += node.meta['bwd_mem_tmp'] + \
-                node.meta['bwd_mem_out']
+            self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
             self.peak_mem = max(self.runtime_mem, self.peak_mem)

             # The memory savings of a node may be negative due to parameter prefetch.
             self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
             self.bwd_node_mem[node] = self.runtime_mem
-            self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
-                                 calculate_fwd_tmp(node))
+            self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)

             # free bwd_mem_out
             self.bwd_node_deps[node] = len(node.all_input_nodes)
@@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator):
                 if user_node in self.bwd_node_deps:
                     self.bwd_node_deps[user_node] -= 1
                     if self.bwd_node_deps[user_node] <= 0:
-                        self.runtime_mem -= user_node.meta['bwd_mem_out']
+                        self.runtime_mem -= user_node.meta["bwd_mem_out"]

             if self.runtime_mem < 0:
-                raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
-                                 f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
-                                 f"runtime memory computed less than 0, which is miscalculated!")
+                raise ValueError(
+                    f"region id: {region.r_id}, node name: {node.name}, "
+                    f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+                    f"runtime memory computed less than 0, which is miscalculated!"
+                )

         # release parameter and offload gradient in region
         if region.r_id == region.shared_rid:
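The bwd_node_deps bookkeeping in the hunk above is a reference count: a node's bwd_mem_out can only be freed once every consumer has finished its backward step. A minimal sketch of the same release rule on a toy dependency map (all names and sizes hypothetical):

    # free a producer's backward output once all of its consumers are done
    bwd_node_deps = {"conv1": 2, "conv2": 1}  # node -> number of unfinished consumers
    bwd_mem_out = {"conv1": 4096, "conv2": 2048}
    runtime_mem = 8192

    def finish_consumer_of(node: str) -> None:
        global runtime_mem
        bwd_node_deps[node] -= 1
        if bwd_node_deps[node] <= 0:
            runtime_mem -= bwd_mem_out[node]  # last consumer done: release the output

    finish_consumer_of("conv1")
    finish_consumer_of("conv1")  # second consumer finishes -> 4096 bytes released
    assert runtime_mem == 4096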
@@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator):

 class AsynTrainingSimulator(TrainingSimulator):
-    def __init__(self,
-                 region_list: List[Region],
-                 comp_power: float,
-                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
+    def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
         super().__init__(region_list, comp_power, link_to_bw)

         self.iter_end_time: int = 0
         # the last computation execution period
-        self.last_comp: ExecutionPeriod = ExecutionPeriod(
-            start_time=0, end_time=0)
+        self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
         # the last parameter prefetch execution period
-        self.last_h2d: ExecutionPeriod = ExecutionPeriod(
-            start_time=0, end_time=0)
+        self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
         # the last gradient offload execution period
-        self.last_d2h: ExecutionPeriod = ExecutionPeriod(
-            start_time=0, end_time=0)
+        self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
         # the forward computation execution period of the region
         self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
         # the forward parameter prefetch execution period of the region
@@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator):
         self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
         # the gradient offload execution period of the region
         # which is divided into those that are waiting and those that have been released
-        self.bwd_reg_to_offl_waiting: OrderedDict[int,
-                                                  ExecutionPeriod] = OrderedDict()
-        self.bwd_reg_to_offl_freed: OrderedDict[int,
-                                                ExecutionPeriod] = OrderedDict()
+        self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
+        self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()

         # the region buffer, which records regions that are offloaded but not released
         self.reg_buffer_to_free: List[int] = []
@@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator):
         # the region execution flow,
         # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
         # when the execution reaches the i-th region.
-        self.fwd_reg_flow = torch.zeros(
-            (self.region_num, self.region_num)).bool()
-        self.bwd_reg_flow = torch.zeros(
-            (self.region_num, self.region_num)).bool()
+        self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
+        self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()

     def execute(self):
         """
@@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator):
         for reg in self.region_list:
             if reg.param_size and reg.r_id < self.region_num - 1:
-                for nr in self.region_list[reg.r_id + 1:]:
+                for nr in self.region_list[reg.r_id + 1 :]:
                     if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
                         reg.fwd_prefetch_region = nr
                         break
@@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator):
                 self.runtime_mem -= self.region_list[reg_id].param_size
             self.bwd_reg_to_offl_waiting.clear()

-        self.iter_end_time = max(
-            self.last_comp.end_time, self.last_d2h.end_time)
+        self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)

     def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
         """
@@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator):
         """
         pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
-        pref_end_time = pref_start_time + \
-            2.0 * self._get_communication_overhead('h2d', region.param_size)
-        pref_ep = ExecutionPeriod(
-            start_time=pref_start_time, end_time=pref_end_time)
+        pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size)
+        pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time)
         if is_fwd:
             self.fwd_reg_to_pref[region.r_id] = pref_ep
         else:
@@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator):
         if is_fwd:
             reg_to_comp = self.fwd_reg_to_comp
             reg_to_pref = self.fwd_reg_to_pref
-            flop_key = 'fwd_flop'
+            flop_key = "fwd_flop"
         else:
             reg_to_comp = self.bwd_reg_to_comp
             reg_to_pref = self.bwd_reg_to_pref
-            flop_key = 'bwd_flop'
-        comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(
-            region.r_id, ExecutionPeriod(0, 0)).end_time)
-        comp_end_time = comp_start_time + \
-            sum([self._get_computing_overhead(node.meta.get(flop_key, 0))
-                 for node in region.nodes])
-        comp_ep = ExecutionPeriod(
-            start_time=comp_start_time, end_time=comp_end_time)
+            flop_key = "bwd_flop"
+        comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
+        comp_end_time = comp_start_time + sum(
+            [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]
+        )
+        comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)

         reg_to_comp[region.r_id] = comp_ep
         self.last_comp = comp_ep
@@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator):
         """
         offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
-        offl_end_time = offl_start_time + \
-            self._get_communication_overhead('d2h', region.param_size)
-        offl_ep = ExecutionPeriod(
-            start_time=offl_start_time, end_time=offl_end_time)
+        offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size)
+        offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time)
         self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
         self.last_d2h = offl_ep
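The _insert_h2d_exec, _insert_comp_exec, and _insert_d2h_exec helpers in the hunks above model an overlap schedule on separate streams: an operation starts no earlier than both the previous operation on its own stream and the last computation, so prefetch and offload can overlap compute but never reorder against it. A compact sketch of that start-time rule, using a local stand-in for the codebase's ExecutionPeriod:

    from dataclasses import dataclass

    @dataclass
    class ExecutionPeriod:  # hypothetical stand-in for the class used in the file
        start_time: float = 0.0
        end_time: float = 0.0

    last_comp = ExecutionPeriod(0.0, 5.0)  # computation stream busy until t=5
    last_h2d = ExecutionPeriod(0.0, 2.0)   # h2d stream became free at t=2

    # a new prefetch waits for both its own stream and the last computation
    pref_start = max(last_h2d.end_time, last_comp.end_time)  # -> 5.0
    pref_end = pref_start + 1.5                              # + assumed transfer time
    last_h2d = ExecutionPeriod(pref_start, pref_end)
    print(last_h2d)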
@@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator):
                 self.fwd_reg_flow[region.r_id, region.r_id] = True
         else:
             self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
-        self.fwd_reg_flow[region.r_id,
-                          self.reg_buffer_to_free] = False
+        self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
         self.reg_buffer_to_free.clear()

         # prefetch parameters of the next region
         fwd_prefetch_region = region.fwd_prefetch_region
         if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
             self.runtime_mem += fwd_prefetch_region.param_size
-            self.fwd_reg_flow[region.r_id,
-                              fwd_prefetch_region.r_id] = True
+            self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True

         for node in region.nodes:
-            self.runtime_mem += calculate_fwd_tmp(node) + \
-                calculate_fwd_out(node)
+            self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
             self.peak_mem = max(self.runtime_mem, self.peak_mem)
             self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
@@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator):
         if region.need_offload:
             self.runtime_mem -= region.param_size

-            assert len(
-                self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
+            assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}"
             self.reg_buffer_to_free.append(region.r_id)

     def _eval_bwd_cost_per_region(self, region: Region):
@@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator):
             self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
         else:
             self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
-        self.bwd_reg_flow[region.r_id,
-                          self.reg_buffer_to_free] = False
+        self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False

         # free gradients in the buffer
         while len(self.reg_buffer_to_free):
@@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator):
         bwd_prefetch_region = region.bwd_prefetch_region
         if bwd_prefetch_region:
             self.runtime_mem += bwd_prefetch_region.param_size
-            self.bwd_reg_flow[region.r_id,
-                              bwd_prefetch_region.r_id] = True
+            self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True

         # add the gradient of the parameter
         if region.r_id < region.shared_rid:
@@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator):
             self.runtime_mem += region.param_size

         for node in region.nodes.__reversed__():
             self.runtime_mem -= calculate_fwd_out(node)
-            self.runtime_mem += node.meta['bwd_mem_tmp'] + \
-                node.meta['bwd_mem_out']
+            self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
             self.peak_mem = max(self.runtime_mem, self.peak_mem)

             # The memory savings of a node may be negative due to parameter prefetch.
@@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator):
             self.bwd_node_mem[node] = self.runtime_mem

-            self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
-                                 calculate_fwd_tmp(node))
+            self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)

             # free bwd_mem_out
             self.bwd_node_deps[node] = len(node.all_input_nodes)
@@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator):
                 if user_node in self.bwd_node_deps:
                     self.bwd_node_deps[user_node] -= 1
                     if self.bwd_node_deps[user_node] <= 0:
-                        self.runtime_mem -= user_node.meta['bwd_mem_out']
+                        self.runtime_mem -= user_node.meta["bwd_mem_out"]

             if self.runtime_mem < 0:
-                raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
-                                 f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
-                                 f"runtime memory computed less than 0, which is miscalculated!")
+                raise ValueError(
+                    f"region id: {region.r_id}, node name: {node.name}, "
+                    f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+                    f"runtime memory computed less than 0, which is miscalculated!"
+                )

         # release parameters of the region
         if requires_release_p_in_bwd(self.region_list[region.shared_rid]):