Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
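The page preserves only the commit message, not the hook configuration it changed, so the snippet below is just an illustrative sketch of how a `.pre-commit-config.yaml` entry can make clang-format skip CUDA sources, per "[misc] ignore cuda for clang-format"; the hook revision and the exclude pattern are assumptions, not the repository's actual settings.

    # Hypothetical excerpt of .pre-commit-config.yaml (rev and pattern are illustrative)
    repos:
      - repo: https://github.com/pre-commit/mirrors-clang-format
        rev: v16.0.6              # assumed pin, not taken from this commit
        hooks:
          - id: clang-format
            exclude: \.(cu|cuh)$  # skip CUDA sources so clang-format ignores them

Reformatting the whole tree, as in "[misc] run pre-commit", is typically done with `pre-commit autoupdate` followed by `pre-commit run --all-files`; the Python churn in the diff below (quote style, wrapped calls collapsed, import order) is consistent with formatters such as black and isort being applied repository-wide.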
@@ -1,10 +1,11 @@
-from typing import List, Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple

 import torch
 from torch.fx import Graph, Node

+from .region import Region
 from .solver import SolverFactory
 from .training_simulator import TrainingSimulator
-from .region import Region
 from .util import NodeInfo

+
@@ -19,14 +20,9 @@ class RegionManager:
         cnode (List[str], optional): Common node List, should be the subset of input.
     """

-    def __init__(self,
-                 graph: Graph,
-                 solver_name: str = 'asyn',
-                 memory_budget: float = -1.0,
-                 cnode: List[str] = None):
-
+    def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
         self.graph = graph
-        assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+        assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
         self.root_module = self.graph.owning_module
         self.nodes = list(graph.nodes)
         self.cnode = cnode
@@ -39,7 +35,7 @@ class RegionManager:
         self.memory_budget = memory_budget

         self.solver_name = solver_name
-        self.require_pool: bool = solver_name == 'asyn'
+        self.require_pool: bool = solver_name == "asyn"

         self.reg_to_block: Dict[int, int] = dict()

@@ -61,22 +57,19 @@ class RegionManager:
         self._post_process(solver.best_ts)

     def _pre_process(self):
-
         init_region_list = self._linearize_graph()

         if len(self.shared_region_pairs) > 1:
-            raise NotImplementedError(
-                'The current version only considers at most one pair of parameter sharing.')
+            raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")

         elif len(self.shared_region_pairs) == 1:
             shared_regs = self.shared_region_pairs[0]
-            assert shared_regs[0].shared_rid == shared_regs[1].r_id \
-                and shared_regs[1].shared_rid == shared_regs[0].r_id
+            assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
             fst_id = shared_regs[0].r_id
             lst_id = shared_regs[1].r_id
-            regs_left_out = init_region_list[:fst_id + 1]
+            regs_left_out = init_region_list[: fst_id + 1]
             regs_right_out = init_region_list[lst_id:]
-            hold_regs = init_region_list[fst_id + 1:lst_id]
+            hold_regs = init_region_list[fst_id + 1 : lst_id]
         else:
             regs_left_out = []
             regs_right_out = []
@@ -122,12 +115,9 @@ class RegionManager:
         it may not find a suitable region placement strategy for the given execution flow.
         """

-        reg_flow = torch.cat(
-            [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
-        mem_block_num = torch.max(
-            torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
-        coexist_matrix = torch.logical_or(
-            ts.fwd_reg_flow, ts.bwd_reg_flow)
+        reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
+        mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
+        coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)

         block_to_regs = {}
         for block_idx in range(mem_block_num):
@@ -135,8 +125,7 @@ class RegionManager:
         for reg in self.region_list:
             if reg.r_id in self.rid_in_pool:
                 cur_reg_appears = coexist_matrix[:, reg.r_id]
-                cur_reg_coexists = torch.sum(
-                    coexist_matrix[cur_reg_appears], dim=0).bool()
+                cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
                 for block_idx in range(mem_block_num):
                     if not any(cur_reg_coexists[block_to_regs[block_idx]]):
                         block_to_regs[block_idx].append(reg.r_id)
@@ -145,9 +134,12 @@ class RegionManager:

                 if reg.r_id not in self.reg_to_block:
                     raise NotImplementedError(
-                        f'can not find a block from the memory pool to store parameters of the region')
-        self.memory_pool = torch.chunk(torch.zeros(int(
-            mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
+                        f"can not find a block from the memory pool to store parameters of the region"
+                    )
+        self.memory_pool = torch.chunk(
+            torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
+            chunks=int(mem_block_num),
+        )

     def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
         """
@@ -178,10 +170,9 @@ class RegionManager:

         return region_list

-    def _search_block_size(self,
-                           region_list: List[Region],
-                           search_interval_byte: int = 1024,
-                           search_range_byte: int = 128 * 1024 ** 2) -> int:
+    def _search_block_size(
+        self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
+    ) -> int:
         """
         Search for a suitable memory block size.

@@ -208,11 +199,10 @@ class RegionManager:
                     acc_wasted += blk_size - left
             return acc_wasted

-        param_size_list = [
-            region.param_size for region in region_list if region.r_id == region.shared_rid]
+        param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]

         start_size = max(param_size_list)
-        min_mem_waste = float('+inf')
+        min_mem_waste = float("+inf")
         best_block_size = start_size

         for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
@@ -229,7 +219,7 @@ class RegionManager:
         Initialize region data, which maps the parameters in the region to a contiguous memory space.
         """

-        self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
+        self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)

         for region in self.region_list:
             pre_alloc_tensor = None
@@ -244,8 +234,7 @@ class RegionManager:
                 region.fp16_data = shared_region.fp16_data
                 region.fp32_data = shared_region.fp32_data
                 region.param_to_range = shared_region.param_to_range
-                region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
-                )
+                region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()

         torch.cuda.empty_cache()

@@ -259,13 +248,14 @@ class RegionManager:
         former_reg, latter_reg = self.shared_region_pairs[0]
         assert latter_reg.param_num >= former_reg.param_num
         embedding_node = former_reg.nodes[-1]
-        assert embedding_node.op == 'call_module' and isinstance(
-            self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
+        assert embedding_node.op == "call_module" and isinstance(
+            self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
+        )
         if latter_reg.param_num > former_reg.param_num:
             for idx, n in enumerate(latter_reg.nodes):
-                if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
-                                                         torch.nn.Linear)) or \
-                        (n.op == 'call_function' and n.target is torch.nn.functional.linear):
+                if (
+                    n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
+                ) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
                     cut_node_idx = idx + 1
                     break
             assert len(latter_reg.fp16_params) == 2
@@ -273,7 +263,7 @@ class RegionManager:
         for p in new_reg.fp16_params:
             self.param_region_map[p] = new_reg
         self.region_list.insert(new_reg.r_id, new_reg)
-        for reg in self.region_list[new_reg.r_id + 1:]:
+        for reg in self.region_list[new_reg.r_id + 1 :]:
             reg.r_id += 1
         latter_reg.shared_rid = former_reg.r_id
         former_reg.shared_rid = latter_reg.r_id
@@ -344,8 +334,8 @@ class RegionManager:
                 target = n.target
                 submod = self.root_module.get_submodule(target)
                 if (
-                        len(list(submod.named_parameters(recurse=False))) != 0
-                        or len(list(submod.named_buffers(recurse=False))) != 0
+                    len(list(submod.named_parameters(recurse=False))) != 0
+                    or len(list(submod.named_buffers(recurse=False))) != 0
                 ):
                     label = True

@@ -362,14 +352,12 @@ class RegionManager:
             """

             def _is_inplace(n: Node):
-                """Get the inplace argument from ``torch.fx.Node``
-                """
+                """Get the inplace argument from ``torch.fx.Node``"""
                 inplace = False
                 if n.op == "call_function":
                     inplace = n.kwargs.get("inplace", False)
                 elif n.op == "call_module":
-                    inplace = getattr(n.graph.owning_module.get_submodule(
-                        n.target), "inplace", False)
+                    inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
                 return inplace

             label = False
@@ -378,28 +366,30 @@ class RegionManager:
                 target = n.target
                 submod = self.root_module.get_submodule(target)
                 if (
-                        len(list(submod.named_parameters(recurse=False))) != 0
-                        or len(list(submod.named_buffers(recurse=False))) != 0
+                    len(list(submod.named_parameters(recurse=False))) != 0
+                    or len(list(submod.named_buffers(recurse=False))) != 0
                 ):
                     label = True

             elif n.op == "call_function":
                 label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
-                    map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
+                    map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)
+                )

             return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))

         def _exception_node_handling():
             # TODO meta info prop bug
-            if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
-                n.meta['fwd_out'] = []
+            if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2:
+                n.meta["fwd_out"] = []

         # make sure that item in cnode is valid
         if self.cnode:
             for name in self.cnode:
                 try:
-                    assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
-                        f"Common node {name} is not an input of the model."
+                    assert (
+                        next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+                    ), f"Common node {name} is not an input of the model."
                 except StopIteration:
                     raise ValueError(f"Common node name {name} not in graph.")
         else:
@@ -428,8 +418,8 @@ class RegionManager:
                     ns = []
                     border_n_idx = region.nodes.index(act_n)
                     if border_n_idx < len(region.nodes):
-                        ns = region.nodes[border_n_idx + 1:]
-                        region.nodes = region.nodes[:border_n_idx + 1]
+                        ns = region.nodes[border_n_idx + 1 :]
+                        region.nodes = region.nodes[: border_n_idx + 1]
                     region_list.append(region)
                     region_id += 1
                     region = Region(r_id=region_id)
@@ -448,19 +438,21 @@ class RegionManager:
                     region = Region(r_id=region_id)

                 # propagate common node attr if possible
-                if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
-                                                  ]) or _is_cop(n.target):
+                if len(n.all_input_nodes) == len(
+                    [node for node in n.all_input_nodes if node.name in self.cnode]
+                ) or _is_cop(n.target):
                     self.cnode.append(n.name)
                 else:
-                    deps[n] = len(
-                        [user for user in n.users if user.op != "output"])
+                    deps[n] = len([user for user in n.users if user.op != "output"])

                 # propagate param node attr if possible
-                if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
-                                                  ]) or n.op == "get_attr":
+                if (
+                    len(n.all_input_nodes)
+                    == len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
+                    or n.op == "get_attr"
+                ):
                     self.only_param_ops.append(n.name)
-                    param_op_deps[n] = len(
-                        [user for user in n.users if user.op != "output"])
+                    param_op_deps[n] = len([user for user in n.users if user.op != "output"])

                 # record last activation node
                 if _is_act(n._meta_data):
@@ -472,19 +464,16 @@ class RegionManager:
         return region_list

     def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
-
         cur_n.node_info = NodeInfo(node_id)

-        if cur_n.op == 'call_module':
+        if cur_n.op == "call_module":
             target = cur_n.target
             submod = self.root_module.get_submodule(target)
             for p in list(submod.parameters(recurse=False)):
-
                 if p in self.param_region_map:
                     cur_reg.shared_rid = self.param_region_map[p].r_id
                     self.param_region_map[p].shared_rid = cur_reg.r_id
-                    self.shared_region_pairs.append(
-                        (self.param_region_map[p], cur_reg))
+                    self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
                 else:
                     self.param_region_map[p] = cur_reg

@@ -499,12 +488,10 @@ class RegionManager:
                 attr_itr = getattr(attr_itr, atom)

             if isinstance(attr_itr, torch.nn.Parameter):
-
                 if attr_itr in self.param_region_map:
                     cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
                     self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
-                    self.shared_region_pairs.append(
-                        (self.param_region_map[attr_itr], cur_reg))
+                    self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
                 else:
                     self.param_region_map[attr_itr] = cur_reg
