[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author:       Hongxin Liu
Date:         2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent:       3c6b831c26
Commit:       079bf3cb26

1268 changed files with 50037 additions and 38444 deletions
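The last two commit-message bullets concern the pre-commit configuration itself. As a rough sketch only (the hook revision and exclude pattern below are assumptions for illustration, not the commit's actual values), ignoring CUDA sources for clang-format in a .pre-commit-config.yaml looks like this:

    repos:
      - repo: https://github.com/pre-commit/mirrors-clang-format
        rev: v13.0.1  # hypothetical pin, not necessarily what this commit uses
        hooks:
          - id: clang-format
            # "[misc] ignore cuda for clang-format": skip .cu/.cuh sources
            exclude: '.*\.(cu|cuh)$'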


@@ -1,10 +1,11 @@
-from typing import List, Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple
import torch
from torch.fx import Graph, Node
+from .region import Region
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
-from .region import Region
from .util import NodeInfo
@@ -19,14 +20,9 @@ class RegionManager:
cnode (List[str], optional): Common node List, should be the subset of input.
"""
-def __init__(self,
-graph: Graph,
-solver_name: str = 'asyn',
-memory_budget: float = -1.0,
-cnode: List[str] = None):
+def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
self.graph = graph
-assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.cnode = cnode
@@ -39,7 +35,7 @@ class RegionManager:
self.memory_budget = memory_budget
self.solver_name = solver_name
-self.require_pool: bool = solver_name == 'asyn'
+self.require_pool: bool = solver_name == "asyn"
self.reg_to_block: Dict[int, int] = dict()
@@ -61,22 +57,19 @@ class RegionManager:
self._post_process(solver.best_ts)
def _pre_process(self):
init_region_list = self._linearize_graph()
if len(self.shared_region_pairs) > 1:
-raise NotImplementedError(
-'The current version only considers at most one pair of parameter sharing.')
+raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")
elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0]
-assert shared_regs[0].shared_rid == shared_regs[1].r_id \
-and shared_regs[1].shared_rid == shared_regs[0].r_id
+assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
fst_id = shared_regs[0].r_id
lst_id = shared_regs[1].r_id
-regs_left_out = init_region_list[:fst_id + 1]
+regs_left_out = init_region_list[: fst_id + 1]
regs_right_out = init_region_list[lst_id:]
-hold_regs = init_region_list[fst_id + 1:lst_id]
+hold_regs = init_region_list[fst_id + 1 : lst_id]
else:
regs_left_out = []
regs_right_out = []
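The slice rewrites in this hunk follow Black's PEP 8 treatment of slice colons: when a bound is a compound expression such as fst_id + 1, the colon gets a space on both sides; a bare bound stays tight. A minimal illustration (names are made up, behavior is unchanged):

    items = list(range(10))
    fst_id, lst_id = 2, 6
    # Identical slices; only the formatting differs.
    assert items[fst_id + 1 : lst_id] == items[fst_id + 1:lst_id]
    assert items[:3] == items[: 2 + 1]  # simple bound tight, compound bound spaced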
@@ -122,12 +115,9 @@ class RegionManager:
it may not find a suitable region placement strategy for the given execution flow.
"""
-reg_flow = torch.cat(
-[ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
-mem_block_num = torch.max(
-torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
-coexist_matrix = torch.logical_or(
-ts.fwd_reg_flow, ts.bwd_reg_flow)
+reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
+mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
+coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)
block_to_regs = {}
for block_idx in range(mem_block_num):
@@ -135,8 +125,7 @@ class RegionManager:
for reg in self.region_list:
if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id]
-cur_reg_coexists = torch.sum(
-coexist_matrix[cur_reg_appears], dim=0).bool()
+cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id)
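For readers skimming the hunk above: regions that are never resident at the same time step may share one memory block, and the loop assigns each pooled region to the first block whose current occupants it never coexists with. A hedged, self-contained sketch with made-up data (coexist is a steps-by-regions boolean residency matrix):

    import torch

    coexist = torch.tensor([[True, False, False],
                            [False, True, False],
                            [False, False, True]])
    num_blocks = int(torch.max(torch.sum(coexist, dim=1)))  # peak residency
    block_to_regs = {b: [] for b in range(num_blocks)}
    for rid in range(coexist.shape[1]):
        steps = coexist[:, rid]                              # steps where rid is live
        partners = torch.sum(coexist[steps], dim=0).bool()   # regions live alongside it
        for b in range(num_blocks):
            if not any(partners[block_to_regs[b]]):          # no conflict in this block
                block_to_regs[b].append(rid)
                break
    # Here no two regions ever coexist, so all three share one block: {0: [0, 1, 2]}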
@@ -145,9 +134,12 @@ class RegionManager:
if reg.r_id not in self.reg_to_block:
raise NotImplementedError(
-f'can not find a block from the memory pool to store parameters of the region')
-self.memory_pool = torch.chunk(torch.zeros(int(
-mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
+f"can not find a block from the memory pool to store parameters of the region"
+)
+self.memory_pool = torch.chunk(
+torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
+chunks=int(mem_block_num),
+)
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
"""
@@ -178,10 +170,9 @@ class RegionManager:
return region_list
-def _search_block_size(self,
-region_list: List[Region],
-search_interval_byte: int = 1024,
-search_range_byte: int = 128 * 1024 ** 2) -> int:
+def _search_block_size(
+self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
+) -> int:
"""
Search for a suitable memory block size.
@@ -208,11 +199,10 @@ class RegionManager:
acc_wasted += blk_size - left
return acc_wasted
-param_size_list = [
-region.param_size for region in region_list if region.r_id == region.shared_rid]
+param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]
start_size = max(param_size_list)
-min_mem_waste = float('+inf')
+min_mem_waste = float("+inf")
best_block_size = start_size
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
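The surrounding hunk searches block sizes by total waste: each region's parameters fill some whole blocks plus a partial one, and the unused tail of that last block is counted as waste (the acc_wasted += blk_size - left line above). A hedged sketch of that accounting with invented sizes:

    def wasted_bytes(param_sizes, blk_size):
        acc = 0
        for size in param_sizes:
            left = size % blk_size
            if left:                     # a partly filled trailing block
                acc += blk_size - left
        return acc

    param_sizes = [3000, 5000, 7000]     # hypothetical per-region parameter bytes
    start = max(param_sizes)             # a block must hold the largest region
    best = min(range(start, start + 64 * 1024 + 1, 1024),
               key=lambda b: wasted_bytes(param_sizes, b))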
@@ -229,7 +219,7 @@ class RegionManager:
Initialize region data, which maps the parameters in the region to a contiguous memory space.
"""
-self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
+self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)
for region in self.region_list:
pre_alloc_tensor = None
@@ -244,8 +234,7 @@ class RegionManager:
region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range
-region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
-)
+region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()
torch.cuda.empty_cache()
@@ -259,13 +248,14 @@ class RegionManager:
former_reg, latter_reg = self.shared_region_pairs[0]
assert latter_reg.param_num >= former_reg.param_num
embedding_node = former_reg.nodes[-1]
-assert embedding_node.op == 'call_module' and isinstance(
-self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
+assert embedding_node.op == "call_module" and isinstance(
+self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
+)
if latter_reg.param_num > former_reg.param_num:
for idx, n in enumerate(latter_reg.nodes):
-if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
-torch.nn.Linear)) or \
-(n.op == 'call_function' and n.target is torch.nn.functional.linear):
+if (
+n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
+) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
cut_node_idx = idx + 1
break
assert len(latter_reg.fp16_params) == 2
@@ -273,7 +263,7 @@ class RegionManager:
for p in new_reg.fp16_params:
self.param_region_map[p] = new_reg
self.region_list.insert(new_reg.r_id, new_reg)
-for reg in self.region_list[new_reg.r_id + 1:]:
+for reg in self.region_list[new_reg.r_id + 1 :]:
reg.r_id += 1
latter_reg.shared_rid = former_reg.r_id
former_reg.shared_rid = latter_reg.r_id
@@ -344,8 +334,8 @@ class RegionManager:
target = n.target
submod = self.root_module.get_submodule(target)
if (
-len(list(submod.named_parameters(recurse=False))) != 0
-or len(list(submod.named_buffers(recurse=False))) != 0
+len(list(submod.named_parameters(recurse=False))) != 0
+or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
@@ -362,14 +352,12 @@ class RegionManager:
"""
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
"""Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
-inplace = getattr(n.graph.owning_module.get_submodule(
-n.target), "inplace", False)
+inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
label = False
@@ -378,28 +366,30 @@ class RegionManager:
target = n.target
submod = self.root_module.get_submodule(target)
if (
-len(list(submod.named_parameters(recurse=False))) != 0
-or len(list(submod.named_buffers(recurse=False))) != 0
+len(list(submod.named_parameters(recurse=False))) != 0
+or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
elif n.op == "call_function":
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
-map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
+map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)
+)
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
def _exception_node_handling():
# TODO meta info prop bug
if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
n.meta['fwd_out'] = []
if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2:
n.meta["fwd_out"] = []
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
-assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
-f"Common node {name} is not an input of the model."
+assert (
+next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
@@ -428,8 +418,8 @@ class RegionManager:
ns = []
border_n_idx = region.nodes.index(act_n)
if border_n_idx < len(region.nodes):
-ns = region.nodes[border_n_idx + 1:]
-region.nodes = region.nodes[:border_n_idx + 1]
+ns = region.nodes[border_n_idx + 1 :]
+region.nodes = region.nodes[: border_n_idx + 1]
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
@@ -448,19 +438,21 @@ class RegionManager:
region = Region(r_id=region_id)
# propagate common node attr if possible
-if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
-]) or _is_cop(n.target):
+if len(n.all_input_nodes) == len(
+[node for node in n.all_input_nodes if node.name in self.cnode]
+) or _is_cop(n.target):
self.cnode.append(n.name)
else:
-deps[n] = len(
-[user for user in n.users if user.op != "output"])
+deps[n] = len([user for user in n.users if user.op != "output"])
# propagate param node attr if possible
-if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
-]) or n.op == "get_attr":
+if (
+len(n.all_input_nodes)
+== len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
+or n.op == "get_attr"
+):
self.only_param_ops.append(n.name)
-param_op_deps[n] = len(
-[user for user in n.users if user.op != "output"])
+param_op_deps[n] = len([user for user in n.users if user.op != "output"])
# record last activation node
if _is_act(n._meta_data):
@@ -472,19 +464,16 @@ class RegionManager:
return region_list
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
cur_n.node_info = NodeInfo(node_id)
-if cur_n.op == 'call_module':
+if cur_n.op == "call_module":
target = cur_n.target
submod = self.root_module.get_submodule(target)
for p in list(submod.parameters(recurse=False)):
if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id
-self.shared_region_pairs.append(
-(self.param_region_map[p], cur_reg))
+self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
else:
self.param_region_map[p] = cur_reg
@@ -499,12 +488,10 @@ class RegionManager:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.Parameter):
if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
-self.shared_region_pairs.append(
-(self.param_region_map[attr_itr], cur_reg))
+self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
else:
self.param_region_map[attr_itr] = cur_reg
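Both branches above implement the same bookkeeping: the first region to own a parameter claims it in param_region_map, and a later region that encounters the same Parameter object forms a shared pair with it (weight tying, e.g. an embedding and its output projection). A stripped-down sketch, with names simplified from the diff:

    import torch

    param_region_map = {}    # Parameter -> region id of its first owner
    shared_region_pairs = []

    def register(param, region_id):
        if param in param_region_map:
            shared_region_pairs.append((param_region_map[param], region_id))
        else:
            param_region_map[param] = region_id

    tied = torch.nn.Parameter(torch.zeros(4))
    register(tied, 0)   # e.g. the embedding's region
    register(tied, 7)   # the tied output head sees the same object
    assert shared_region_pairs == [(0, 7)]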