[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,6 +1,6 @@
import time
from typing import List, Dict, Type
from abc import ABC, abstractmethod
from typing import Dict, List, Type
NOT_NVML = False
try:
@@ -10,10 +10,11 @@ except:
import torch
from torch.fx.node import Node
from colossalai.utils.cuda import get_current_device
from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
from .util import NodeInfo, NvDevicePower
@@ -49,19 +50,14 @@ class Solver(ABC):
It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
"""
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0,
error_factor: float = 0.95) -> None:
def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None:
self.region_list = region_list
self.error_factor: float = error_factor
if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor
else:
self.memory_budget = torch.cuda.get_device_properties(
get_current_device()).total_memory * self.error_factor
self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power()
@@ -94,7 +90,7 @@ class Solver(ABC):
if extra_cost == 0:
# means data transfer overhead can be completely overlapped
return (float('inf'), total_mem_saving, peak_mem_saving)
return (float("inf"), total_mem_saving, peak_mem_saving)
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
@@ -122,9 +118,7 @@ class Solver(ABC):
self.best_ts = best_ts
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
def _update_node_mem_info(self,
fwd_mem_info: Dict[Node, float],
bwd_mem_info: Dict[Node, float]):
def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]):
"""
Update the runtime memory information of the node.
@@ -134,12 +128,10 @@ class Solver(ABC):
"""
for node, mem in fwd_mem_info.items():
assert hasattr(node, 'node_info') and isinstance(
node.node_info, NodeInfo)
assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_fwd_mem = mem
for node, mem in bwd_mem_info.items():
assert hasattr(node, 'node_info') and isinstance(
node.node_info, NodeInfo)
assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_bwd_mem = mem
def _extract_computing_power(self):
@@ -159,12 +151,12 @@ class Solver(ABC):
return NvDevicePower.RTX3080_FP16 * units
elif device_name.__contains__("RTX 3090"):
return NvDevicePower.RTX3090_FP16 * units
elif device_name.__contains__('V100'):
elif device_name.__contains__("V100"):
return NvDevicePower.V100_FP16 * units
elif device_name.__contains__("A100"):
return NvDevicePower.A100_FP16 * units
else:
raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
raise TypeError(f"Unknown NVIDIA GPU device name {device_name}")
def _profile_bandwidth(self):
"""
@@ -172,9 +164,9 @@ class Solver(ABC):
using data volumes ranging from 1KB to 1GB.
"""
print('profiling bandwidth ......')
print("profiling bandwidth ......")
link_to_bandwidth = {}
links = ['h2d', 'd2h']
links = ["h2d", "d2h"]
for link in links:
t_size = 1024
@@ -182,24 +174,22 @@ class Solver(ABC):
# from 1KB to 1GB
for i in range(21):
if link == 'h2d':
src_tensor = torch.ones(
int(t_size), dtype=torch.int8, pin_memory=True)
dst_tensor = torch.ones(
(int(t_size)), dtype=torch.int8, device='cuda')
elif link == 'd2h':
src_tensor = torch.ones(
int(t_size), dtype=torch.int8, device='cuda')
dst_tensor = torch.ones(
(int(t_size)), dtype=torch.int8, pin_memory=True)
if link == "h2d":
src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda")
elif link == "d2h":
src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda")
dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)
def func():
dst_tensor.copy_(src_tensor)
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
print(f'size: {t_size / 1024 ** 2:.3f} MB, '
f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
print(
f"size: {t_size / 1024 ** 2:.3f} MB, "
f"{src_tensor.device.type}-to-{dst_tensor.device.type} "
f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s"
)
t_size *= 2
@@ -208,10 +198,7 @@ class Solver(ABC):
class SynGreedySolver(Solver):
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0) -> None:
def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None:
super().__init__(region_list, memory_budget)
self.best_ts: SynTrainingSimulator = None
@@ -258,7 +245,8 @@ class SynGreedySolver(Solver):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
)
def _call_solver_l2l(self):
"""
@@ -270,7 +258,6 @@ class SynGreedySolver(Solver):
region.is_syn = True
def _try_to_offload(self, offload_region: Region):
# record previous information
orig_need_offload = offload_region.need_offload
assert not orig_need_offload
@@ -297,23 +284,17 @@ class SynGreedySolver(Solver):
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
extra_comm_cost = 2.0 * \
ts._get_communication_overhead('h2d', offload_region.param_size)
extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size)
# the shared region needs to be moved twice
if offload_region.r_id < offload_region.shared_rid:
extra_comm_cost *= 2.0
profit = self._compute_offload_profit(
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class AsynGreedySolver(Solver):
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0,
search_window_size: int = 3):
def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3):
super().__init__(region_list, memory_budget)
self.search_window_size = search_window_size
@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver):
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
self._update_state(ts)
print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")
def _call_solver(self):
"""
@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver):
best_pref_ts = None
# search when to prefetch the region offloaded
for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
if host_region.bwd_prefetch_region is not None:
continue
temp_ts, profit = self._try_to_offload(
host_region, region)
temp_ts, profit = self._try_to_offload(host_region, region)
if self._compare_profit(profit, max_prefetch_profit):
region_to_region_map[region.r_id] = host_region
max_prefetch_profit = profit
best_pref_ts = temp_ts
if profit[0] == float('inf'):
if profit[0] == float("inf"):
break
if self._compare_profit(max_prefetch_profit, max_offload_profit):
@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
)
region_to_region_map.clear()
@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver):
peak_mem_saving = 0
while len(self.region_to_region_map) and peak_mem_saving <= 0:
max_profit = (0,)
best_ts = None
undo_host_region = None
@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver):
assert offload_region.need_offload
assert not offload_region.is_syn
ts, profit = self._try_convert_to_syn_upload(host_region,
offload_region)
ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)
if self._compare_profit(profit, max_profit):
undo_host_region = host_region
@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver):
best_ts = ts
if best_ts is None:
raise NotImplementedError('repair error!')
raise NotImplementedError("repair error!")
assert not undo_offload_region.is_syn
undo_offload_region.is_syn = True
@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver):
ts.execute()
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
profit = self._compute_offload_profit(
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class SolverFactory:
solvers: Dict[str, Type[Solver]] = {
'syn': SynGreedySolver,
'asyn': AsynGreedySolver
}
solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}
@staticmethod
def create(solver_name: str) -> Type[Solver]: