[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,8 +1,6 @@
import numpy as np
import torch
import tqdm
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
accumulate_bwd_flop = 0
block_nodes = []
for node in gm.graph.nodes:
if 'block_split' in node.name:
if "block_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
accumulate_bwd_flop += node.bwd_flop
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
with gm.graph.inserting_after(node):
block_node = gm.graph.create_node('call_function', block_split)
setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
block_node = gm.graph.create_node("call_function", block_split)
setattr(block_node, "fwd_flop", accumulate_fwd_flop)
setattr(block_node, "bwd_flop", accumulate_bwd_flop)
accumulate_fwd_flop = 0
accumulate_bwd_flop = 0
block_nodes.append(block_node)
@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
def remove_blocks(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
if (node.op, node.target) == ('call_function', block_split):
if (node.op, node.target) == ("call_function", block_split):
gm.graph.erase_node(node)
@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
num_nodes = len(node_list)
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0):
for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False):
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
all_compute_cost[start, end] = sum(selected_flops)
@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
# record start node index for next stage in this partition
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
f[0, num_nodes] = 0
for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
for s in tqdm.tqdm(
range(1, num_stages + 1), desc="stage", position=2, leave=False
): # pylint: disable=too-many-nested-blocks
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False):
for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False):
stage_cost = compute_costs[i, k - 1]
new_cost = f[s - 1, k] + stage_cost
if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
if stage_cost <= max_compute_cost and new_cost < f[s, i]:
f[s, i] = new_cost
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
f_argmin[s, i] = k
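
The hunk above is the core of the pipeline-partitioning DP: f[s, i] is the best cost of packing nodes i onward into s contiguous stages, and a candidate stage [i, k-1] is admitted only if its cost stays within max_compute_cost. A minimal standalone sketch of that recurrence follows; the toy costs, the dp_split helper name, and the boundary read-back are illustrative additions, not part of the pass:

```python
import numpy as np

def dp_split(node_costs, num_stages, max_compute_cost):
    """Toy version of the recurrence in do_dp_split_gpipe_impl: contiguous stages only."""
    num_nodes = len(node_costs)
    # compute_costs[i, k] = total cost of nodes i..k (inclusive)
    compute_costs = np.full((num_nodes, num_nodes), np.inf)
    for i in range(num_nodes):
        acc = 0.0
        for k in range(i, num_nodes):
            acc += node_costs[k]
            compute_costs[i, k] = acc

    f = np.full((num_stages + 1, num_nodes + 1), np.inf)
    f_stage_max = np.zeros((num_stages + 1, num_nodes + 1))
    f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
    f[0, num_nodes] = 0.0
    for s in range(1, num_stages + 1):
        for i in range(num_nodes - 1, -1, -1):
            for k in range(num_nodes, i, -1):       # first stage covers nodes i..k-1
                stage_cost = compute_costs[i, k - 1]
                new_cost = f[s - 1, k] + stage_cost
                if stage_cost <= max_compute_cost and new_cost < f[s, i]:
                    f[s, i] = new_cost
                    f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
                    f_argmin[s, i] = k

    if np.isinf(f[num_stages, 0]):
        return None                                  # infeasible under this bound
    # read the chosen stage boundaries back from f_argmin
    boundaries, i = [], 0
    for s in range(num_stages, 0, -1):
        k = f_argmin[s, i]
        boundaries.append((i, k - 1))
        i = k
    return float(f[num_stages, 0]), float(f_stage_max[num_stages, 0]), boundaries

# e.g. six nodes packed into three stages, each stage bounded by a cost of 6
print(dp_split([4, 1, 1, 4, 2, 2], num_stages=3, max_compute_cost=6))
```

On the toy costs this yields a total of 14.0, a per-stage maximum of 6.0 and the boundaries [(0, 2), (3, 4), (5, 5)]; the surrounding do_dp_split_gpipe loop then sweeps max_compute_cost to trade the stage maximum against feasibility.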
@@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
best_cost = np.inf
best_solution = None
last_max_compute_cost = 0.0
gap = 1e6 # temporary magic number, unit: flops
gap = 1e6 # temporary magic number, unit: flops
for max_compute_cost in tqdm.tqdm(max_compute_costs):
# Pruning to reduce search space.
@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
if max_compute_cost - last_max_compute_cost < gap:
continue
cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
max_compute_cost)
cost, solution = do_dp_split_gpipe_impl(
len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost
)
if cost < best_cost:
best_cost = cost
@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
# split_mode:
# 'node': fx_node
# 'block': many fx_nodes construct a block
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
assert mode in ['node', 'block']
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01):
assert mode in ["node", "block"]
# nodes or blocks will be used in partition.
node_list = []
if mode == 'node':
if mode == "node":
for node in gm.graph.nodes:
node_list.append(node)
elif mode == 'block':
elif mode == "block":
node_list = construct_blocks(gm, limit=block_limit)
else:
pass
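
A hedged end-to-end sketch of how this pass might be invoked. The import paths are assumptions inferred from the modules touched elsewhere in this commit (adding_split_node_pass, meta_info_prop), and metainfo_trace is the annotation helper shown later in this diff; treat it as an illustration rather than verified API usage:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# assumed import paths, following the module names visible elsewhere in this commit
from colossalai.fx.passes.meta_info_prop import metainfo_trace
from colossalai.fx.passes.adding_split_node_pass import gpipe_dp_split_pass

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
gm = symbolic_trace(model)

# metainfo_trace attaches fwd_flop / bwd_flop to every node, which is what
# construct_blocks and the DP above consume as compute costs
gm = metainfo_trace(gm, torch.randn(8, 64))

# ask for 2 pipeline stages over 4 microbatches, grouping nodes into blocks first
annotated = gpipe_dp_split_pass(gm, pp_size=2, num_microbatches=4, mode="block")
print([node.name for node in annotated.graph.nodes if "pipe_split" in node.name])
```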
@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
for (_, next_start_node) in best_solution:
for _, next_start_node in best_solution:
if pp_size <= 1:
break
node = node_list[next_start_node]
with gm.graph.inserting_before(node):
split_node = gm.graph.create_node('call_function', pipe_split)
split_node = gm.graph.create_node("call_function", pipe_split)
pp_size -= 1
# remove block node if possible
if mode == 'block':
if mode == "block":
remove_blocks(gm)
gm.recompile()
@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
# To use avgcompute_split_pass, we need to run the meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to the normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta:
if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_fwd_flop = 0
@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
if 'pipe_split' in node.name:
if "pipe_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
if accumulate_fwd_flop >= partition_flop:
@@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_flop = total_fwd_flop // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0
pp_size -= 1
if node.next.op == 'output':
if node.next.op == "output":
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
# If the next node is an output node, we will insert the split annotation before
# the node to make sure there is at least one node in the last partition.
if node.next.op == 'output':
if node.next.op == "output":
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
if node.op == 'placeholder':
if node.op == "placeholder":
continue
elif node_counter == 0:
node_counter += 1
@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
# To use balanced_split_pass_v2, we need to run the meta_info_prop interpreter first.
# If nodes don't have meta info, this pass will fall back to the normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
if 'tensor_meta' not in check_node.meta:
if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_element_size = 0
@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
if 'pipe_split' in node.name:
if "pipe_split" in node.name:
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
accumulate_layer_amount = 0
pp_size -= 1
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
def split_callback(n: torch.fx.Node):
nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split):
if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
@@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split):
if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
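
The split itself relies on a partition-index callback handed to split_module: every pipe_split marker bumps the index, and nodes sharing an index end up in the same submodule. The same mechanism can be exercised with the stock torch.fx split_module; a minimal sketch, with an illustrative counting rule standing in for the markers:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
gm = symbolic_trace(model)

seen = []

def split_callback(n: torch.fx.Node) -> int:
    # illustrative rule: the first two compute nodes go to partition 0, the rest to
    # partition 1; the pass above instead bumps the index at every pipe_split node
    seen.append(n)
    return 0 if len(seen) <= 2 else 1

split_mod = split_module(gm, model, split_callback)
for name, submodule in split_mod.named_children():
    print(name, [node.op for node in submodule.graph.nodes])
```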

View File

@@ -1,5 +1,5 @@
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.fx
@@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
self._is_proped = True
result, meta_info = super().run_node(n)
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
n.meta['type'] = type(result)
setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
@@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_function(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_method(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_module(submod, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
"""
return self.run(*args)
def summary(self, unit: str = 'MB') -> str:
def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
print(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
@@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str:
unit_divisor_map = {
'kb': 1024,
'mb': 1024**2,
'gb': 1024**3,
'tb': 1024**4,
"kb": 1024,
"mb": 1024**2,
"gb": 1024**3,
"tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes:
node: Node
node_summaries.append([
node.op,
str(node),
time_repr(node.meta['fwd_time']),
time_repr(node.meta['bwd_time']),
node.meta['save_fwd_in'],
mem_repr(node.meta['fwd_mem_out']),
mem_repr(node.meta['fwd_mem_tmp']),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
node_summaries.append(
[
node.op,
str(node),
time_repr(node.meta["fwd_time"]),
time_repr(node.meta["bwd_time"]),
node.meta["save_fwd_in"],
mem_repr(node.meta["fwd_mem_out"]),
mem_repr(node.meta["fwd_mem_tmp"]),
mem_repr(node.meta["bwd_mem_out"]),
mem_repr(node.meta["bwd_mem_tmp"]),
]
)
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Forward time',
'Backward time',
'SAVE_FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
'BWD_TMP',
"Op type",
"Op",
"Forward time",
"Backward time",
"SAVE_FWD_IN",
"FWD_OUT",
"FWD_TMP",
"BWD_OUT",
"BWD_TMP",
]
return tabulate(node_summaries, headers=headers, stralign='right')
return tabulate(node_summaries, headers=headers, stralign="right")

View File

@@ -1,14 +1,11 @@
import torch
from typing import List
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
from copy import deepcopy
from typing import List
import torch
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
def apply(*args, **kwargs):
@@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
setattr(node, 'best_strategy', strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
setattr(node, "best_strategy", strategies_vector[strategy_index])
setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec)
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec
# apply the sharding spec of parameters
for node in nodes:
if node.op == 'call_module':
if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
setattr(target_module.weight, "sharding_spec", origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
apply(target_module.weight, target_weight_sharding_spec)
@@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
# add above dicts into graph
for node in nodes:
if node.op != 'placeholder':
if node.op != "placeholder":
with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
@@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
node_to_index_dict = {}
index = 0
for node in nodes:
if node.target == 'sharding_spec_convert_dict':
if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
if node.target == 'origin_node_sharding_spec_dict':
if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
@@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
# add shape consistency apply function into graph
for node in nodes:
if not hasattr(node, 'best_strategy'):
if not hasattr(node, "best_strategy"):
continue
with mod_graph.inserting_after(node):
origin_spec_node = mod_graph.create_node('call_function',
operator.getitem,
args=(origin_dict_node, node_to_index_dict[node]))
origin_spec_node = mod_graph.create_node(
"call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node])
)
with mod_graph.inserting_after(origin_spec_node):
set_sharding_spec_node = mod_graph.create_node('call_function',
builtins.setattr,
args=(node, 'sharding_spec', origin_spec_node))
set_sharding_spec_node = mod_graph.create_node(
"call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node)
)
for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node):
input_specs_node = mod_graph.create_node('call_function',
operator.getitem,
args=(input_dict_node, node_to_index_dict[node]))
input_specs_node = mod_graph.create_node(
"call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node])
)
with mod_graph.inserting_before(user_node):
sharding_spec_node = mod_graph.create_node('call_function',
operator.getitem,
args=(input_specs_node, node_index))
sharding_spec_node = mod_graph.create_node(
"call_function", operator.getitem, args=(input_specs_node, node_index)
)
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node))
return gm
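
All of these rewrites use the same torch.fx idiom: choose an insertion point with inserting_before/inserting_after, create a "call_function" node, and recompile the module. A minimal standalone sketch of that idiom on a toy model (the inserted torch.relu and the module are illustrative, not taken from this pass):

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

gm = symbolic_trace(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)))
mod_graph = gm.graph

for node in list(mod_graph.nodes):
    if node.op == "call_module" and node.target == "0":   # the first Linear
        with mod_graph.inserting_after(node):
            relu_node = mod_graph.create_node("call_function", torch.relu, args=(node,))
        # route every downstream consumer of the Linear output through the new node ...
        node.replace_all_uses_with(relu_node)
        # ... and restore the new node's own input, which the call above also rewrote
        relu_node.args = (node,)
        break

mod_graph.lint()
gm.recompile()
print(gm.code)   # forward now applies torch.relu between the two Linear layers
```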

View File

@@ -109,13 +109,13 @@ class MetaInfoProp(torch.fx.Interpreter):
return TensorMetadata(None, None, False, None, 0, False)
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
n.meta["tensor_meta"] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
n.meta['type'] = type(result)
setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0)))
setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0))
setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0))
n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
@@ -125,7 +125,7 @@ class MetaInfoProp(torch.fx.Interpreter):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -146,7 +146,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -165,7 +165,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -184,7 +184,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -202,7 +202,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -224,7 +224,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -240,7 +240,7 @@ class MetaInfoProp(torch.fx.Interpreter):
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
if hasattr(args[0], '_tensor'):
if hasattr(args[0], "_tensor"):
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True)
@@ -257,7 +257,7 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
return super().run(*args)
def summary(self, unit: str = 'MB') -> str:
def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -267,9 +267,11 @@ class MetaInfoProp(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
print(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
@@ -278,10 +280,10 @@ class MetaInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str:
unit_divisor_map = {
'kb': 1024,
'mb': 1024**2,
'gb': 1024**3,
'tb': 1024**4,
"kb": 1024,
"mb": 1024**2,
"gb": 1024**3,
"tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -292,35 +294,37 @@ class MetaInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes:
node: Node
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
node_summaries.append([
node.op,
str(node),
flops_repr(node.meta['fwd_flop']),
flops_repr(node.meta['bwd_flop']),
mem_repr(accumulate_size),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
node_summaries.append(
[
node.op,
str(node),
flops_repr(node.meta["fwd_flop"]),
flops_repr(node.meta["bwd_flop"]),
mem_repr(accumulate_size),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
mem_repr(node.meta["bwd_mem_out"]),
mem_repr(node.meta["bwd_mem_tmp"]),
]
)
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Forward FLOPs',
'Backward FLOPs',
'Accumulated Memory',
'FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
'BWD_TMP',
"Op type",
"Op",
"Forward FLOPs",
"Backward FLOPs",
"Accumulated Memory",
"FWD_IN",
"FWD_OUT",
"FWD_TMP",
"BWD_OUT",
"BWD_TMP",
]
return tabulate(node_summaries, headers=headers, stralign='right')
return tabulate(node_summaries, headers=headers, stralign="right")
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
"""
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
interp.propagate(*args, **kwargs)
if verbose:
interp.summary(unit)
gm.to('cpu')
gm.to("cpu")
del interp
return gm
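
A hedged usage sketch for metainfo_trace as defined above. The import path is assumed from the colossalai.fx.passes.meta_info_prop module referenced elsewhere in this commit, and printing the summary additionally requires tabulate:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# assumed import path for the function defined above
from colossalai.fx.passes.meta_info_prop import metainfo_trace

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
gm = symbolic_trace(model)

# annotates every node with fwd_flop / bwd_flop and memory metadata; with
# verbose=True it also prints the tabulated MetaInfoProp.summary() report
gm = metainfo_trace(gm, torch.randn(16, 32), verbose=True, unit="MB")
print([getattr(n, "fwd_flop", None) for n in gm.graph.nodes])
```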

View File

@@ -5,7 +5,6 @@ import torch
from packaging import version
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split
from colossalai.fx.passes.meta_info_prop import TensorMetadata
@@ -13,9 +12,9 @@ from colossalai.fx.passes.split_module import Partition
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
'''
"""
This pass is only used for the GPT-2 performance test; it may be moved into adding_split_node_pass.py and will be deprecated in the future.
'''
"""
mod_graph = gm.graph
valid_children_size = 0
valid_children = []
@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti
part_index += 1
pp_size -= 1
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
'''
"""
This pass will be used in the GPT-2 test; only a part of its changes may be added into
split_with_split_nodes_pass, and it will be deprecated in the future.
'''
"""
part_idx = 0
def eliminate_unused_placeholders(gm):
for node in gm.graph.nodes:
if node.op == 'placeholder':
if node.op == "placeholder":
if not len(node.users):
gm.graph.erase_node(node)
gm.recompile()
return gm
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
'''
"""
This method eliminates the outputs of the previous partition that are unused in the next partition.
The split module pass treats partitions as a DAG, but in pipeline parallelism we need to treat them as a singly linked list.
The difference: if an output of partition 0 is an input argument of partition 3, the DAG will not pass it
through partition 1 and partition 2, whereas the linked-list view requires us to do so.
'''
"""
output_type = None
output_args = []
non_output_list = []
new_placeholder_list = []
for node in gm.graph.nodes:
if node.op == 'output':
if node.op == "output":
if isinstance(node.args[0], (tuple, list)):
output_type = node.args[0].__class__
output_args.extend([n.name for n in node.args[0]])
@@ -114,7 +113,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
continue
for node in gm.graph.nodes:
if node.op == 'placeholder':
if node.op == "placeholder":
new_placeholder_list.append(node.name)
if output_type is not None:
gm.graph.output(output_type(output_args))
@@ -125,7 +124,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
def split_callback(n: torch.fx.Node):
nonlocal part_idx
if (n.op, n.target) == ('call_function', pipe_split):
if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
@@ -134,7 +133,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
if (node.op, node.target) == ('call_function', pipe_split):
if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
@@ -200,13 +199,12 @@ def split_module_for_gpt2_test(
_gen_all_ancestors_set(node)
for n in list(all_ancestors):
if n.op != 'placeholder' and n._fx_partition > partition_name:
if n.op != "placeholder" and n._fx_partition > partition_name:
n._fx_partition = partition_name
def record_cross_partition_use(def_node: torch.fx.node.Node,
use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
# if 'tensor_meta' in def_node.meta:
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
@@ -237,7 +235,7 @@ def split_module_for_gpt2_test(
if node.op in ["placeholder"]:
continue
if node.op == 'output':
if node.op == "output":
# partition_name = str(split_callback(node))
# def _set_output_args_partition(n, partition_name):
# n._fx_partition = partition_name
@@ -252,12 +250,12 @@ def split_module_for_gpt2_test(
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
origin_partition_name = getattr(node, '_fx_partition', None)
origin_partition_name = getattr(node, "_fx_partition", None)
if origin_partition_name is None:
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
@@ -287,7 +285,7 @@ def split_module_for_gpt2_test(
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
@@ -295,26 +293,24 @@ def split_module_for_gpt2_test(
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']:
if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
target_atoms = node.target.split('.')
target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = '_'.join(target_atoms)
target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs,
name=node.name)
new_node = partition.graph.create_node(
op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name
)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
@@ -323,14 +319,14 @@ def split_module_for_gpt2_test(
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
if node.op == "placeholder":
if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name] = base_mod_graph.placeholder(
node.name, type_expr=node.type, default_value=default_value
)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
@@ -344,13 +340,14 @@ def split_module_for_gpt2_test(
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
partition.graph) # noqa: B950
submod_name = f"submod_{partition_name}"
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
partition.targets, partition.graph
) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
@@ -358,14 +355,14 @@ def split_module_for_gpt2_test(
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
if node.op == "output":
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)

View File

@@ -9,8 +9,19 @@ from colossalai.legacy.tensor.distspec import ShardSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
torch.add,
operator.add,
torch.abs,
torch.cos,
torch.exp,
torch.mul,
operator.mul,
operator.floordiv,
operator.truediv,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
]
@@ -72,7 +83,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# traverse the graph to look for consecutive linear layers
is_linear_module = False
if node.op == 'call_module':
if node.op == "call_module":
# look for the linear layer
module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear):
@@ -82,31 +93,31 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# it means the first linear has been found and the current module
# is the second linear
# set the current linear module to be row-sharded
annotation_record['row'] = module
annotation_record["row"] = module
for shard_type, module in annotation_record.items():
# add row sharding spec
if shard_type == 'row':
if shard_type == "row":
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', dist_spec)
setattr(module.weight, 'comp_spec', comp_spec)
elif shard_type == 'col':
setattr(module.weight, "pg", process_group)
setattr(module.weight, "dist_spec", dist_spec)
setattr(module.weight, "comp_spec", comp_spec)
elif shard_type == "col":
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
setattr(module.weight, 'pg', process_group)
setattr(module.weight, 'dist_spec', weight_dist_spec)
setattr(module.weight, 'comp_spec', weight_comp_spec)
setattr(module.weight, "pg", process_group)
setattr(module.weight, "dist_spec", weight_dist_spec)
setattr(module.weight, "comp_spec", weight_comp_spec)
if module.bias is not None:
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
setattr(module.bias, 'pg', process_group)
setattr(module.bias, 'dist_spec', bias_dist_spec)
setattr(module.bias, 'comp_spec', bias_comp_spec)
setattr(module.bias, "pg", process_group)
setattr(module.bias, "dist_spec", bias_dist_spec)
setattr(module.bias, "comp_spec", bias_comp_spec)
start_tracking = False
annotation_record.clear()
else:
@@ -114,16 +125,16 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# it means the current layer is the first linear
# set the linear layer to be col-sharded
start_tracking = True
annotation_record['col'] = module
annotation_record["col"] = module
if start_tracking and not is_linear_module:
# check against the white list
# if non-element wise op is found, we reset the tracking
if node.op == 'call_module':
if node.op == "call_module":
module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False
elif node.op == 'call_function' or node.op == 'call_method':
elif node.op == "call_function" or node.op == "call_method":
if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False
elif len(node.users.keys()) > 1:
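
The col-then-row annotation above matches the standard Megatron-style MLP sharding: the first linear is split along its output features (dims=[0] of an nn.Linear weight), the second along its input features (dims=[-1]), so the elementwise activation can be applied locally on each rank and a single all-reduce (a sum) at the end recovers the full result. A small numeric check of that identity, with two fake ranks held as tensor slices (shapes and names illustrative):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
w1 = torch.randn(8, 16)   # first MLP weight in math layout (in, out): column-parallel
w2 = torch.randn(16, 8)   # second MLP weight in math layout (in, out): row-parallel
# note: nn.Linear stores weights as (out, in), which is why the pass above shards
# dims=[0] for the column-parallel layer and dims=[-1] for the row-parallel one

ref = torch.relu(x @ w1) @ w2                  # unsharded reference

w1_shards = torch.chunk(w1, 2, dim=1)          # split output columns across 2 ranks
w2_shards = torch.chunk(w2, 2, dim=0)          # split input rows across 2 ranks

# each rank works on its slice independently; ReLU is elementwise, so it commutes
# with the column split of the first matmul
partials = [torch.relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]

out = partials[0] + partials[1]                # the all-reduce is just a sum here
print(torch.allclose(ref, out, atol=1e-5))     # True
```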

View File

@@ -25,12 +25,14 @@ class Partition:
self.targets: Dict[str, Any] = {}
def __repr__(self) -> str:
return f"name: {self.name},\n" \
f" nodes: {self.node_names},\n" \
f" inputs: {self.inputs},\n" \
f" outputs: {self.outputs},\n" \
f" partitions dependent on: {self.partitions_dependent_on},\n" \
return (
f"name: {self.name},\n"
f" nodes: {self.node_names},\n"
f" inputs: {self.inputs},\n"
f" outputs: {self.outputs},\n"
f" partitions dependent on: {self.partitions_dependent_on},\n"
f" partition dependents: {self.partition_dependents}"
)
# Creates subgraphs out of main graph
@@ -117,10 +119,9 @@ def split_module(
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}
def record_cross_partition_use(def_node: torch.fx.node.Node,
use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
@@ -134,7 +135,7 @@ def split_module(
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
@@ -161,7 +162,7 @@ def split_module(
if node.op in ["placeholder"]:
continue
if node.op == 'output':
if node.op == "output":
if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else:
@@ -178,7 +179,7 @@ def split_module(
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
@@ -208,7 +209,7 @@ def split_module(
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
@@ -216,25 +217,24 @@ def split_module(
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
if node.op not in ['call_module', 'get_attr']:
if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
target_atoms = node.target.split('.')
target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = '_'.join(target_atoms)
target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs)
new_node = partition.graph.create_node(
op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
@@ -243,14 +243,14 @@ def split_module(
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
if node.op == "placeholder":
if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name] = base_mod_graph.placeholder(
node.target, type_expr=node.type, default_value=default_value
)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
@@ -264,13 +264,14 @@ def split_module(
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
partition.graph) # noqa: B950
submod_name = f"submod_{partition_name}"
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
partition.targets, partition.graph
) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
@@ -278,15 +279,15 @@ def split_module(
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
if node.op == "output":
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
for partition_name in sorted_partitions:
partition = partitions[partition_name]

View File

@@ -1,7 +1,9 @@
import torch
from typing import Dict
from torch.fx.node import Node, map_arg
import torch
from torch.fx.graph import Graph
from torch.fx.node import Node, map_arg
def get_comm_size(prev_partition, next_partition):
"""
@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes:
comm_size += n.meta['tensor_meta'].numel
comm_size += n.meta["tensor_meta"].numel
visited_nodes.add(n)
return comm_size
@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
"""
input_nodes: Dict[Node, None] = {}
for node in graph.nodes:
if node.op == 'output':
if node.op == "output":
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
placeholder_nodes = []
for node in input_nodes.keys():
if node.op == 'placeholder':
if node.op == "placeholder":
placeholder_nodes.append(node)
for node in placeholder_nodes:
input_nodes.pop(node)
@@ -60,13 +62,13 @@ def get_top(graph: Graph):
"""
top_node_list = set()
for node in graph.nodes:
if node.op == 'output':
if node.op == "output":
continue
is_top = False
def _get_top(node):
nonlocal is_top
if node.op == 'placeholder':
if node.op == "placeholder":
is_top = True
map_arg(node.args, lambda n: _get_top(n))
@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
def get_all_consumers(graph: Graph, node: Node):
"""
Given a graph and a node of this graph, return all consumers of the node.
Returns:
List of ``Node``s whose ``args`` or ``kwargs`` contain the given node.
"""
@@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
for node in gm.graph.nodes:
if hasattr(node, 'bfs_level'):
print(node.name, node.bfs_level)
Output:
graph():
%x : [#users=2] = placeholder[target=x]
@@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
while nodes_to_process:
new_process_list = []
for node in nodes_to_process:
if node.op == 'output':
if node.op == "output":
continue
node.bfs_level = current_level
new_process_list.extend(get_all_consumers(graph, node))
@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
torch.nn.Module: the module associated with the given node
"""
assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
assert (
node.graph.owning_module is not None
), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
module = node.graph.owning_module.get_submodule(node.target)
return module