mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
from typing import List, Dict, Type
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Type
|
||||
|
||||
NOT_NVML = False
|
||||
try:
|
||||
@@ -10,10 +10,11 @@ except:
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
|
||||
from .region import Region
|
||||
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
|
||||
from .util import NodeInfo, NvDevicePower
|
||||
|
||||
|
||||
@@ -49,19 +50,14 @@ class Solver(ABC):
|
||||
It is used to reduce the memory budget because of errors in the estimation of peak memory and execution time.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
region_list: List[Region],
|
||||
memory_budget: float = -1.0,
|
||||
error_factor: float = 0.95) -> None:
|
||||
|
||||
def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None:
|
||||
self.region_list = region_list
|
||||
|
||||
self.error_factor: float = error_factor
|
||||
if memory_budget > 0:
|
||||
self.memory_budget = memory_budget * self.error_factor
|
||||
else:
|
||||
self.memory_budget = torch.cuda.get_device_properties(
|
||||
get_current_device()).total_memory * self.error_factor
|
||||
self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
|
||||
|
||||
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
|
||||
self.comp_power: float = self._extract_computing_power()
|
||||
@@ -94,7 +90,7 @@ class Solver(ABC):
|
||||
|
||||
if extra_cost == 0:
|
||||
# means data transfer overhead can be completely overlapped
|
||||
return (float('inf'), total_mem_saving, peak_mem_saving)
|
||||
return (float("inf"), total_mem_saving, peak_mem_saving)
|
||||
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
|
||||
|
||||
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
|
||||
@@ -122,9 +118,7 @@ class Solver(ABC):
|
||||
self.best_ts = best_ts
|
||||
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
|
||||
|
||||
def _update_node_mem_info(self,
|
||||
fwd_mem_info: Dict[Node, float],
|
||||
bwd_mem_info: Dict[Node, float]):
|
||||
def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]):
|
||||
"""
|
||||
Update the runtime memory information of the node.
|
||||
|
||||
@@ -134,12 +128,10 @@ class Solver(ABC):
|
||||
"""
|
||||
|
||||
for node, mem in fwd_mem_info.items():
|
||||
assert hasattr(node, 'node_info') and isinstance(
|
||||
node.node_info, NodeInfo)
|
||||
assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
|
||||
node.node_info.runtime_fwd_mem = mem
|
||||
for node, mem in bwd_mem_info.items():
|
||||
assert hasattr(node, 'node_info') and isinstance(
|
||||
node.node_info, NodeInfo)
|
||||
assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
|
||||
node.node_info.runtime_bwd_mem = mem
|
||||
|
||||
def _extract_computing_power(self):
|
||||
@@ -159,12 +151,12 @@ class Solver(ABC):
|
||||
return NvDevicePower.RTX3080_FP16 * units
|
||||
elif device_name.__contains__("RTX 3090"):
|
||||
return NvDevicePower.RTX3090_FP16 * units
|
||||
elif device_name.__contains__('V100'):
|
||||
elif device_name.__contains__("V100"):
|
||||
return NvDevicePower.V100_FP16 * units
|
||||
elif device_name.__contains__("A100"):
|
||||
return NvDevicePower.A100_FP16 * units
|
||||
else:
|
||||
raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
|
||||
raise TypeError(f"Unknown NVIDIA GPU device name {device_name}")
|
||||
|
||||
def _profile_bandwidth(self):
|
||||
"""
|
||||
@@ -172,9 +164,9 @@ class Solver(ABC):
|
||||
using data volumes ranging from 1KB to 1GB.
|
||||
"""
|
||||
|
||||
print('profiling bandwidth ......')
|
||||
print("profiling bandwidth ......")
|
||||
link_to_bandwidth = {}
|
||||
links = ['h2d', 'd2h']
|
||||
links = ["h2d", "d2h"]
|
||||
|
||||
for link in links:
|
||||
t_size = 1024
|
||||
@@ -182,24 +174,22 @@ class Solver(ABC):
|
||||
|
||||
# from 1KB to 1GB
|
||||
for i in range(21):
|
||||
if link == 'h2d':
|
||||
src_tensor = torch.ones(
|
||||
int(t_size), dtype=torch.int8, pin_memory=True)
|
||||
dst_tensor = torch.ones(
|
||||
(int(t_size)), dtype=torch.int8, device='cuda')
|
||||
elif link == 'd2h':
|
||||
src_tensor = torch.ones(
|
||||
int(t_size), dtype=torch.int8, device='cuda')
|
||||
dst_tensor = torch.ones(
|
||||
(int(t_size)), dtype=torch.int8, pin_memory=True)
|
||||
if link == "h2d":
|
||||
src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
|
||||
dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda")
|
||||
elif link == "d2h":
|
||||
src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda")
|
||||
dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)
|
||||
|
||||
def func():
|
||||
dst_tensor.copy_(src_tensor)
|
||||
|
||||
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
|
||||
print(f'size: {t_size / 1024 ** 2:.3f} MB, '
|
||||
f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
|
||||
f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
|
||||
print(
|
||||
f"size: {t_size / 1024 ** 2:.3f} MB, "
|
||||
f"{src_tensor.device.type}-to-{dst_tensor.device.type} "
|
||||
f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s"
|
||||
)
|
||||
|
||||
t_size *= 2
|
||||
|
||||
@@ -208,10 +198,7 @@ class Solver(ABC):
|
||||
|
||||
|
||||
class SynGreedySolver(Solver):
|
||||
|
||||
def __init__(self,
|
||||
region_list: List[Region],
|
||||
memory_budget: float = -1.0) -> None:
|
||||
def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None:
|
||||
super().__init__(region_list, memory_budget)
|
||||
|
||||
self.best_ts: SynTrainingSimulator = None
|
||||
@@ -258,7 +245,8 @@ class SynGreedySolver(Solver):
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
|
||||
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
|
||||
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
|
||||
)
|
||||
|
||||
def _call_solver_l2l(self):
|
||||
"""
|
||||
@@ -270,7 +258,6 @@ class SynGreedySolver(Solver):
|
||||
region.is_syn = True
|
||||
|
||||
def _try_to_offload(self, offload_region: Region):
|
||||
|
||||
# record previous information
|
||||
orig_need_offload = offload_region.need_offload
|
||||
assert not orig_need_offload
|
||||
@@ -297,23 +284,17 @@ class SynGreedySolver(Solver):
|
||||
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
|
||||
ts.execute()
|
||||
|
||||
extra_comm_cost = 2.0 * \
|
||||
ts._get_communication_overhead('h2d', offload_region.param_size)
|
||||
extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size)
|
||||
# the shared region needs to be moved twice
|
||||
if offload_region.r_id < offload_region.shared_rid:
|
||||
extra_comm_cost *= 2.0
|
||||
profit = self._compute_offload_profit(
|
||||
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
|
||||
profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
|
||||
|
||||
return ts, profit
|
||||
|
||||
|
||||
class AsynGreedySolver(Solver):
|
||||
|
||||
def __init__(self,
|
||||
region_list: List[Region],
|
||||
memory_budget: float = -1.0,
|
||||
search_window_size: int = 3):
|
||||
def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3):
|
||||
super().__init__(region_list, memory_budget)
|
||||
|
||||
self.search_window_size = search_window_size
|
||||
@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver):
|
||||
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
|
||||
ts.execute()
|
||||
self._update_state(ts)
|
||||
print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
|
||||
print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")
|
||||
|
||||
def _call_solver(self):
|
||||
"""
|
||||
@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver):
|
||||
best_pref_ts = None
|
||||
|
||||
# search when to prefetch the region offloaded
|
||||
for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
|
||||
for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
|
||||
if host_region.bwd_prefetch_region is not None:
|
||||
continue
|
||||
|
||||
temp_ts, profit = self._try_to_offload(
|
||||
host_region, region)
|
||||
temp_ts, profit = self._try_to_offload(host_region, region)
|
||||
|
||||
if self._compare_profit(profit, max_prefetch_profit):
|
||||
region_to_region_map[region.r_id] = host_region
|
||||
max_prefetch_profit = profit
|
||||
best_pref_ts = temp_ts
|
||||
if profit[0] == float('inf'):
|
||||
if profit[0] == float("inf"):
|
||||
break
|
||||
|
||||
if self._compare_profit(max_prefetch_profit, max_offload_profit):
|
||||
@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver):
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
|
||||
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
|
||||
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
|
||||
)
|
||||
|
||||
region_to_region_map.clear()
|
||||
|
||||
@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver):
|
||||
|
||||
peak_mem_saving = 0
|
||||
while len(self.region_to_region_map) and peak_mem_saving <= 0:
|
||||
|
||||
max_profit = (0,)
|
||||
best_ts = None
|
||||
undo_host_region = None
|
||||
@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver):
|
||||
assert offload_region.need_offload
|
||||
assert not offload_region.is_syn
|
||||
|
||||
ts, profit = self._try_convert_to_syn_upload(host_region,
|
||||
offload_region)
|
||||
ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)
|
||||
|
||||
if self._compare_profit(profit, max_profit):
|
||||
undo_host_region = host_region
|
||||
@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver):
|
||||
best_ts = ts
|
||||
|
||||
if best_ts is None:
|
||||
raise NotImplementedError('repair error!')
|
||||
raise NotImplementedError("repair error!")
|
||||
|
||||
assert not undo_offload_region.is_syn
|
||||
undo_offload_region.is_syn = True
|
||||
@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver):
|
||||
ts.execute()
|
||||
|
||||
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
|
||||
profit = self._compute_offload_profit(
|
||||
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
|
||||
profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
|
||||
|
||||
return ts, profit
|
||||
|
||||
|
||||
class SolverFactory:
|
||||
solvers: Dict[str, Type[Solver]] = {
|
||||
'syn': SynGreedySolver,
|
||||
'asyn': AsynGreedySolver
|
||||
}
|
||||
solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}
|
||||
|
||||
@staticmethod
|
||||
def create(solver_name: str) -> Type[Solver]:
|
||||
|
Reference in New Issue
Block a user