mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.fx.passes.split_module import split_module
|
||||
|
||||
@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
|
||||
accumulate_bwd_flop = 0
|
||||
block_nodes = []
|
||||
for node in gm.graph.nodes:
|
||||
if 'block_split' in node.name:
|
||||
if "block_split" in node.name:
|
||||
continue
|
||||
accumulate_fwd_flop += node.fwd_flop
|
||||
accumulate_bwd_flop += node.bwd_flop
|
||||
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
|
||||
with gm.graph.inserting_after(node):
|
||||
block_node = gm.graph.create_node('call_function', block_split)
|
||||
setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
|
||||
setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
|
||||
block_node = gm.graph.create_node("call_function", block_split)
|
||||
setattr(block_node, "fwd_flop", accumulate_fwd_flop)
|
||||
setattr(block_node, "bwd_flop", accumulate_bwd_flop)
|
||||
accumulate_fwd_flop = 0
|
||||
accumulate_bwd_flop = 0
|
||||
block_nodes.append(block_node)
|
||||
@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
|
||||
|
||||
def remove_blocks(gm: torch.fx.GraphModule):
|
||||
for node in gm.graph.nodes:
|
||||
if (node.op, node.target) == ('call_function', block_split):
|
||||
if (node.op, node.target) == ("call_function", block_split):
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
|
||||
@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
|
||||
num_nodes = len(node_list)
|
||||
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
|
||||
|
||||
for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
|
||||
for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
|
||||
for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0):
|
||||
for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False):
|
||||
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
|
||||
all_compute_cost[start, end] = sum(selected_flops)
|
||||
|
||||
@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
|
||||
# record start node index for next stage in this partition
|
||||
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
|
||||
f[0, num_nodes] = 0
|
||||
for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
|
||||
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
|
||||
for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
|
||||
for s in tqdm.tqdm(
|
||||
range(1, num_stages + 1), desc="stage", position=2, leave=False
|
||||
): # pylint: disable=too-many-nested-blocks
|
||||
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False):
|
||||
for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False):
|
||||
stage_cost = compute_costs[i, k - 1]
|
||||
new_cost = f[s - 1, k] + stage_cost
|
||||
if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
|
||||
if stage_cost <= max_compute_cost and new_cost < f[s, i]:
|
||||
f[s, i] = new_cost
|
||||
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
|
||||
f_argmin[s, i] = k
|
||||
@@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
|
||||
best_cost = np.inf
|
||||
best_solution = None
|
||||
last_max_compute_cost = 0.0
|
||||
gap = 1e6 # temporary magic number, unit: flops
|
||||
gap = 1e6 # temporary magic number, unit: flops
|
||||
|
||||
for max_compute_cost in tqdm.tqdm(max_compute_costs):
|
||||
# Pruning to reduce search space.
|
||||
@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
|
||||
if max_compute_cost - last_max_compute_cost < gap:
|
||||
continue
|
||||
|
||||
cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
|
||||
max_compute_cost)
|
||||
cost, solution = do_dp_split_gpipe_impl(
|
||||
len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost
|
||||
)
|
||||
|
||||
if cost < best_cost:
|
||||
best_cost = cost
|
||||
@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
|
||||
# split_mode:
|
||||
# 'node': fx_node
|
||||
# 'block': many fx_nodes construct a block
|
||||
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
|
||||
assert mode in ['node', 'block']
|
||||
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01):
|
||||
assert mode in ["node", "block"]
|
||||
|
||||
# nodes or blocks will be used in partition.
|
||||
node_list = []
|
||||
if mode == 'node':
|
||||
if mode == "node":
|
||||
for node in gm.graph.nodes:
|
||||
node_list.append(node)
|
||||
elif mode == 'block':
|
||||
elif mode == "block":
|
||||
node_list = construct_blocks(gm, limit=block_limit)
|
||||
else:
|
||||
pass
|
||||
@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
|
||||
|
||||
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
|
||||
|
||||
for (_, next_start_node) in best_solution:
|
||||
for _, next_start_node in best_solution:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
node = node_list[next_start_node]
|
||||
with gm.graph.inserting_before(node):
|
||||
split_node = gm.graph.create_node('call_function', pipe_split)
|
||||
split_node = gm.graph.create_node("call_function", pipe_split)
|
||||
pp_size -= 1
|
||||
|
||||
# remove block node if possible
|
||||
if mode == 'block':
|
||||
if mode == "block":
|
||||
remove_blocks(gm)
|
||||
|
||||
gm.recompile()
|
||||
@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
# To use avgcompute_split_pass, we need run meta_info_prop interpreter first.
|
||||
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
|
||||
check_node = list(mod_graph.nodes)[0]
|
||||
if 'tensor_meta' not in check_node.meta:
|
||||
if "tensor_meta" not in check_node.meta:
|
||||
return balanced_split_pass(gm, pp_size)
|
||||
|
||||
total_fwd_flop = 0
|
||||
@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
for node in mod_graph.nodes:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
if 'pipe_split' in node.name:
|
||||
if "pipe_split" in node.name:
|
||||
continue
|
||||
accumulate_fwd_flop += node.fwd_flop
|
||||
if accumulate_fwd_flop >= partition_flop:
|
||||
@@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
pp_size -= 1
|
||||
partition_flop = total_fwd_flop // pp_size
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
if accumulate_num_node >= avg_num_node:
|
||||
accumulate_num_node = 0
|
||||
pp_size -= 1
|
||||
if node.next.op == 'output':
|
||||
if node.next.op == "output":
|
||||
with mod_graph.inserting_before(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
else:
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
pp_size -= 1
|
||||
# If the next node is output node, we will insert split annotation before
|
||||
# node to make sure there is at least one node in last partition.
|
||||
if node.next.op == 'output':
|
||||
if node.next.op == "output":
|
||||
with mod_graph.inserting_before(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
else:
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
if pp_size > 1:
|
||||
node_counter = 0
|
||||
for node in mod_graph.nodes:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
continue
|
||||
elif node_counter == 0:
|
||||
node_counter += 1
|
||||
@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
pp_size -= 1
|
||||
node_counter = 0
|
||||
with mod_graph.inserting_before(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
|
||||
gm.recompile()
|
||||
return gm
|
||||
@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
|
||||
# To use balanced_split_pass_v2, we need run meta_info_prop interpreter first.
|
||||
# If nodes don't have meta info, this pass will fall back to normal balanced split pass.
|
||||
check_node = list(mod_graph.nodes)[0]
|
||||
if 'tensor_meta' not in check_node.meta:
|
||||
if "tensor_meta" not in check_node.meta:
|
||||
return balanced_split_pass(gm, pp_size)
|
||||
|
||||
total_element_size = 0
|
||||
@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
|
||||
for node in mod_graph.nodes:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
if 'pipe_split' in node.name:
|
||||
if "pipe_split" in node.name:
|
||||
continue
|
||||
accumulate_node_size += node.node_size
|
||||
if accumulate_node_size >= partition_size:
|
||||
@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
|
||||
pp_size -= 1
|
||||
partition_size = total_element_size // pp_size
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
accumulate_layer_amount = 0
|
||||
pp_size -= 1
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
|
||||
|
||||
def split_callback(n: torch.fx.Node):
|
||||
nonlocal part_idx
|
||||
if (n.op, n.target) == ('call_function', pipe_split):
|
||||
if (n.op, n.target) == ("call_function", pipe_split):
|
||||
part_idx += 1
|
||||
return part_idx
|
||||
|
||||
@@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
|
||||
for name, submodule in split_mod.named_modules():
|
||||
if isinstance(submodule, torch.fx.GraphModule):
|
||||
for node in submodule.graph.nodes:
|
||||
if (node.op, node.target) == ('call_function', pipe_split):
|
||||
if (node.op, node.target) == ("call_function", pipe_split):
|
||||
submodule.graph.erase_node(node)
|
||||
submodule.recompile()
|
||||
split_submodules.append(submodule)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
@@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
self._is_proped = True
|
||||
result, meta_info = super().run_node(n)
|
||||
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
|
||||
n.meta['type'] = type(result)
|
||||
setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
|
||||
n.meta["type"] = type(result)
|
||||
|
||||
# retain the autograd graph
|
||||
for param in self.module.parameters():
|
||||
@@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
|
||||
# Main Node running APIs
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``placeholder`` node. Note that this is stateful:
|
||||
``Interpreter`` maintains an internal iterator over
|
||||
@@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return super().placeholder(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``get_attr`` node. Will retrieve an attribute
|
||||
value from the ``Module`` hierarchy of ``self.module``.
|
||||
@@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return super().get_attr(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return profile_function(target, self.device)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return profile_method(target, self.device)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return profile_module(submod, self.device)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute an ``output`` node. This really just retrieves
|
||||
the value referenced by the ``output`` node and returns it.
|
||||
@@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
"""
|
||||
return self.run(*args)
|
||||
|
||||
def summary(self, unit: str = 'MB') -> str:
|
||||
def summary(self, unit: str = "MB") -> str:
|
||||
"""
|
||||
Summarizes the memory and FLOPs statistics of the `GraphModule` in
|
||||
tabular format. Note that this API requires the ``tabulate`` module
|
||||
@@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library.")
|
||||
print(
|
||||
"`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library."
|
||||
)
|
||||
|
||||
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
|
||||
|
||||
@@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
|
||||
def mem_repr(mem: int) -> str:
|
||||
unit_divisor_map = {
|
||||
'kb': 1024,
|
||||
'mb': 1024**2,
|
||||
'gb': 1024**3,
|
||||
'tb': 1024**4,
|
||||
"kb": 1024,
|
||||
"mb": 1024**2,
|
||||
"gb": 1024**3,
|
||||
"tb": 1024**4,
|
||||
}
|
||||
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
|
||||
|
||||
@@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
time_repr(node.meta['fwd_time']),
|
||||
time_repr(node.meta['bwd_time']),
|
||||
node.meta['save_fwd_in'],
|
||||
mem_repr(node.meta['fwd_mem_out']),
|
||||
mem_repr(node.meta['fwd_mem_tmp']),
|
||||
mem_repr(node.meta['bwd_mem_out']),
|
||||
mem_repr(node.meta['bwd_mem_tmp']),
|
||||
])
|
||||
node_summaries.append(
|
||||
[
|
||||
node.op,
|
||||
str(node),
|
||||
time_repr(node.meta["fwd_time"]),
|
||||
time_repr(node.meta["bwd_time"]),
|
||||
node.meta["save_fwd_in"],
|
||||
mem_repr(node.meta["fwd_mem_out"]),
|
||||
mem_repr(node.meta["fwd_mem_tmp"]),
|
||||
mem_repr(node.meta["bwd_mem_out"]),
|
||||
mem_repr(node.meta["bwd_mem_tmp"]),
|
||||
]
|
||||
)
|
||||
|
||||
# Use the ``tabulate`` library to create a well-formatted table
|
||||
# presenting our summary information
|
||||
headers: List[str] = [
|
||||
'Op type',
|
||||
'Op',
|
||||
'Forward time',
|
||||
'Backward time',
|
||||
'SAVE_FWD_IN',
|
||||
'FWD_OUT',
|
||||
'FWD_TMP',
|
||||
'BWD_OUT',
|
||||
'BWD_TMP',
|
||||
"Op type",
|
||||
"Op",
|
||||
"Forward time",
|
||||
"Backward time",
|
||||
"SAVE_FWD_IN",
|
||||
"FWD_OUT",
|
||||
"FWD_TMP",
|
||||
"BWD_OUT",
|
||||
"BWD_TMP",
|
||||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
return tabulate(node_summaries, headers=headers, stralign="right")
|
||||
|
@@ -1,14 +1,11 @@
|
||||
import torch
|
||||
from typing import List
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.node import Node
|
||||
from colossalai.fx.passes.split_module import split_module
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
import builtins
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
def apply(*args, **kwargs):
|
||||
@@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
|
||||
origin_node_sharding_spec_dict = {}
|
||||
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
|
||||
strategies_vector = node.strategies_vector
|
||||
setattr(node, 'best_strategy', strategies_vector[strategy_index])
|
||||
setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
|
||||
setattr(node, "best_strategy", strategies_vector[strategy_index])
|
||||
setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec)
|
||||
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec
|
||||
|
||||
# apply the sharding spec of parameters
|
||||
for node in nodes:
|
||||
if node.op == 'call_module':
|
||||
if node.op == "call_module":
|
||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
|
||||
setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
|
||||
setattr(target_module.weight, "sharding_spec", origin_sharding_spec)
|
||||
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
|
||||
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
|
||||
apply(target_module.weight, target_weight_sharding_spec)
|
||||
@@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi
|
||||
|
||||
# add above dicts into graph
|
||||
for node in nodes:
|
||||
if node.op != 'placeholder':
|
||||
if node.op != "placeholder":
|
||||
with mod_graph.inserting_before(node):
|
||||
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
|
||||
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
|
||||
input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
|
||||
origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
|
||||
break
|
||||
|
||||
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
|
||||
@@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
|
||||
node_to_index_dict = {}
|
||||
index = 0
|
||||
for node in nodes:
|
||||
if node.target == 'sharding_spec_convert_dict':
|
||||
if node.target == "sharding_spec_convert_dict":
|
||||
input_dict_node = node
|
||||
continue
|
||||
if node.target == 'origin_node_sharding_spec_dict':
|
||||
if node.target == "origin_node_sharding_spec_dict":
|
||||
origin_dict_node = node
|
||||
continue
|
||||
if not hasattr(node, 'best_strategy'):
|
||||
if not hasattr(node, "best_strategy"):
|
||||
continue
|
||||
node_to_index_dict[node] = index
|
||||
index += 1
|
||||
@@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
|
||||
|
||||
# add shape consistency apply function into graph
|
||||
for node in nodes:
|
||||
if not hasattr(node, 'best_strategy'):
|
||||
if not hasattr(node, "best_strategy"):
|
||||
continue
|
||||
with mod_graph.inserting_after(node):
|
||||
origin_spec_node = mod_graph.create_node('call_function',
|
||||
operator.getitem,
|
||||
args=(origin_dict_node, node_to_index_dict[node]))
|
||||
origin_spec_node = mod_graph.create_node(
|
||||
"call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node])
|
||||
)
|
||||
with mod_graph.inserting_after(origin_spec_node):
|
||||
set_sharding_spec_node = mod_graph.create_node('call_function',
|
||||
builtins.setattr,
|
||||
args=(node, 'sharding_spec', origin_spec_node))
|
||||
set_sharding_spec_node = mod_graph.create_node(
|
||||
"call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node)
|
||||
)
|
||||
|
||||
for user_node in node.strategies_vector.successor_nodes:
|
||||
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
|
||||
with mod_graph.inserting_before(user_node):
|
||||
input_specs_node = mod_graph.create_node('call_function',
|
||||
operator.getitem,
|
||||
args=(input_dict_node, node_to_index_dict[node]))
|
||||
input_specs_node = mod_graph.create_node(
|
||||
"call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node])
|
||||
)
|
||||
with mod_graph.inserting_before(user_node):
|
||||
sharding_spec_node = mod_graph.create_node('call_function',
|
||||
operator.getitem,
|
||||
args=(input_specs_node, node_index))
|
||||
sharding_spec_node = mod_graph.create_node(
|
||||
"call_function", operator.getitem, args=(input_specs_node, node_index)
|
||||
)
|
||||
with mod_graph.inserting_before(user_node):
|
||||
shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
|
||||
shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node))
|
||||
|
||||
return gm
|
||||
|
@@ -109,13 +109,13 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return TensorMetadata(None, None, False, None, 0, False)
|
||||
|
||||
tensor_meta = tree_map(extract_tensor_meta, result)
|
||||
n.meta['tensor_meta'] = tensor_meta
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
n.meta["tensor_meta"] = tensor_meta
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
|
||||
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
|
||||
setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
|
||||
n.meta['type'] = type(result)
|
||||
setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0)))
|
||||
setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0))
|
||||
setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0))
|
||||
n.meta["type"] = type(result)
|
||||
|
||||
# retain the autograd graph
|
||||
for param in self.module.parameters():
|
||||
@@ -125,7 +125,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
# Main Node running APIs
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``placeholder`` node. Note that this is stateful:
|
||||
``Interpreter`` maintains an internal iterator over
|
||||
@@ -146,7 +146,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return super().placeholder(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``get_attr`` node. Will retrieve an attribute
|
||||
value from the ``Module`` hierarchy of ``self.module``.
|
||||
@@ -165,7 +165,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return super().get_attr(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -184,7 +184,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return profile_function(target)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -202,7 +202,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return profile_method(target)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -224,7 +224,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return profile_module(submod)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute an ``output`` node. This really just retrieves
|
||||
the value referenced by the ``output`` node and returns it.
|
||||
@@ -240,7 +240,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
result (Any): The argument value that was retrieved
|
||||
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
if hasattr(args[0], '_tensor'):
|
||||
if hasattr(args[0], "_tensor"):
|
||||
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
|
||||
return args[0], GraphInfo(save_fwd_in=True)
|
||||
|
||||
@@ -257,7 +257,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
"""
|
||||
return super().run(*args)
|
||||
|
||||
def summary(self, unit: str = 'MB') -> str:
|
||||
def summary(self, unit: str = "MB") -> str:
|
||||
"""
|
||||
Summarizes the memory and FLOPs statistics of the `GraphModule` in
|
||||
tabular format. Note that this API requires the ``tabulate`` module
|
||||
@@ -267,9 +267,11 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library.")
|
||||
print(
|
||||
"`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library."
|
||||
)
|
||||
|
||||
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
|
||||
|
||||
@@ -278,10 +280,10 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
def mem_repr(mem: int) -> str:
|
||||
unit_divisor_map = {
|
||||
'kb': 1024,
|
||||
'mb': 1024**2,
|
||||
'gb': 1024**3,
|
||||
'tb': 1024**4,
|
||||
"kb": 1024,
|
||||
"mb": 1024**2,
|
||||
"gb": 1024**3,
|
||||
"tb": 1024**4,
|
||||
}
|
||||
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
|
||||
|
||||
@@ -292,35 +294,37 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
flops_repr(node.meta['fwd_flop']),
|
||||
flops_repr(node.meta['bwd_flop']),
|
||||
mem_repr(accumulate_size),
|
||||
mem_repr(calculate_fwd_in(node)),
|
||||
mem_repr(calculate_fwd_out(node)),
|
||||
mem_repr(calculate_fwd_tmp(node)),
|
||||
mem_repr(node.meta['bwd_mem_out']),
|
||||
mem_repr(node.meta['bwd_mem_tmp']),
|
||||
])
|
||||
node_summaries.append(
|
||||
[
|
||||
node.op,
|
||||
str(node),
|
||||
flops_repr(node.meta["fwd_flop"]),
|
||||
flops_repr(node.meta["bwd_flop"]),
|
||||
mem_repr(accumulate_size),
|
||||
mem_repr(calculate_fwd_in(node)),
|
||||
mem_repr(calculate_fwd_out(node)),
|
||||
mem_repr(calculate_fwd_tmp(node)),
|
||||
mem_repr(node.meta["bwd_mem_out"]),
|
||||
mem_repr(node.meta["bwd_mem_tmp"]),
|
||||
]
|
||||
)
|
||||
|
||||
# Use the ``tabulate`` library to create a well-formatted table
|
||||
# presenting our summary information
|
||||
headers: List[str] = [
|
||||
'Op type',
|
||||
'Op',
|
||||
'Forward FLOPs',
|
||||
'Backward FLOPs',
|
||||
'Accumulated Memory',
|
||||
'FWD_IN',
|
||||
'FWD_OUT',
|
||||
'FWD_TMP',
|
||||
'BWD_OUT',
|
||||
'BWD_TMP',
|
||||
"Op type",
|
||||
"Op",
|
||||
"Forward FLOPs",
|
||||
"Backward FLOPs",
|
||||
"Accumulated Memory",
|
||||
"FWD_IN",
|
||||
"FWD_OUT",
|
||||
"FWD_TMP",
|
||||
"BWD_OUT",
|
||||
"BWD_TMP",
|
||||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
return tabulate(node_summaries, headers=headers, stralign="right")
|
||||
|
||||
|
||||
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
|
||||
@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
|
||||
Returns:
|
||||
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
|
||||
"""
|
||||
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
interp = MetaInfoProp(gm.to(device))
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
|
||||
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
|
||||
interp.propagate(*args, **kwargs)
|
||||
if verbose:
|
||||
interp.summary(unit)
|
||||
gm.to('cpu')
|
||||
gm.to("cpu")
|
||||
del interp
|
||||
return gm
|
||||
|
@@ -5,7 +5,6 @@ import torch
|
||||
from packaging import version
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split
|
||||
from colossalai.fx.passes.meta_info_prop import TensorMetadata
|
||||
@@ -13,9 +12,9 @@ from colossalai.fx.passes.split_module import Partition
|
||||
|
||||
|
||||
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
|
||||
'''
|
||||
"""
|
||||
This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future.
|
||||
'''
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
valid_children_size = 0
|
||||
valid_children = []
|
||||
@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti
|
||||
part_index += 1
|
||||
pp_size -= 1
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
||||
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
|
||||
'''
|
||||
"""
|
||||
This pass will be used in gpt2 test, only a part of changes may be added into
|
||||
split_with_split_nodes_pass, and it will be deprecated in future.
|
||||
'''
|
||||
"""
|
||||
part_idx = 0
|
||||
|
||||
def eliminate_unused_placeholders(gm):
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
if not len(node.users):
|
||||
gm.graph.erase_node(node)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
|
||||
'''
|
||||
"""
|
||||
This method is used to eliminate the outputs in previous partition which is unused in next partition.
|
||||
In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel.
|
||||
The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it
|
||||
to partition 1 and partition 2. However, in single direction linked list, we need to do so.
|
||||
'''
|
||||
"""
|
||||
output_type = None
|
||||
output_args = []
|
||||
non_output_list = []
|
||||
new_placeholder_list = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
if isinstance(node.args[0], (tuple, list)):
|
||||
output_type = node.args[0].__class__
|
||||
output_args.extend([n.name for n in node.args[0]])
|
||||
@@ -114,7 +113,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||
continue
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
new_placeholder_list.append(node.name)
|
||||
if output_type is not None:
|
||||
gm.graph.output(output_type(output_args))
|
||||
@@ -125,7 +124,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||
|
||||
def split_callback(n: torch.fx.Node):
|
||||
nonlocal part_idx
|
||||
if (n.op, n.target) == ('call_function', pipe_split):
|
||||
if (n.op, n.target) == ("call_function", pipe_split):
|
||||
part_idx += 1
|
||||
return part_idx
|
||||
|
||||
@@ -134,7 +133,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||
for name, submodule in split_mod.named_modules():
|
||||
if isinstance(submodule, torch.fx.GraphModule):
|
||||
for node in submodule.graph.nodes:
|
||||
if (node.op, node.target) == ('call_function', pipe_split):
|
||||
if (node.op, node.target) == ("call_function", pipe_split):
|
||||
submodule.graph.erase_node(node)
|
||||
submodule.recompile()
|
||||
split_submodules.append(submodule)
|
||||
@@ -200,13 +199,12 @@ def split_module_for_gpt2_test(
|
||||
|
||||
_gen_all_ancestors_set(node)
|
||||
for n in list(all_ancestors):
|
||||
if n.op != 'placeholder' and n._fx_partition > partition_name:
|
||||
if n.op != "placeholder" and n._fx_partition > partition_name:
|
||||
n._fx_partition = partition_name
|
||||
|
||||
def record_cross_partition_use(def_node: torch.fx.node.Node,
|
||||
use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def_partition_name = getattr(def_node, '_fx_partition', None)
|
||||
use_partition_name = getattr(use_node, '_fx_partition', None)
|
||||
def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def_partition_name = getattr(def_node, "_fx_partition", None)
|
||||
use_partition_name = getattr(use_node, "_fx_partition", None)
|
||||
if def_partition_name != use_partition_name:
|
||||
# if 'tensor_meta' in def_node.meta:
|
||||
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
|
||||
@@ -237,7 +235,7 @@ def split_module_for_gpt2_test(
|
||||
|
||||
if node.op in ["placeholder"]:
|
||||
continue
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
# partition_name = str(split_callback(node))
|
||||
# def _set_output_args_partition(n, partition_name):
|
||||
# n._fx_partition = partition_name
|
||||
@@ -252,12 +250,12 @@ def split_module_for_gpt2_test(
|
||||
partitions[partition_name] = partition = Partition(partition_name)
|
||||
|
||||
partition.node_names.append(node.name)
|
||||
origin_partition_name = getattr(node, '_fx_partition', None)
|
||||
origin_partition_name = getattr(node, "_fx_partition", None)
|
||||
if origin_partition_name is None:
|
||||
node._fx_partition = partition_name
|
||||
|
||||
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
|
||||
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
|
||||
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
|
||||
|
||||
# find partitions with no dependencies
|
||||
root_partitions: List[str] = []
|
||||
@@ -287,7 +285,7 @@ def split_module_for_gpt2_test(
|
||||
|
||||
# Transform nodes and collect targets for partition's submodule
|
||||
for node in m.graph.nodes:
|
||||
if hasattr(node, '_fx_partition'):
|
||||
if hasattr(node, "_fx_partition"):
|
||||
partition = partitions[node._fx_partition]
|
||||
|
||||
# swap out old graph nodes in kw/args with references to new nodes in this submodule
|
||||
@@ -295,26 +293,24 @@ def split_module_for_gpt2_test(
|
||||
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
|
||||
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
|
||||
|
||||
if node.op not in ['call_module', 'get_attr']:
|
||||
if node.op not in ["call_module", "get_attr"]:
|
||||
target = node.target
|
||||
else:
|
||||
target_atoms = node.target.split('.')
|
||||
target_atoms = node.target.split(".")
|
||||
target_attr = m
|
||||
for atom in target_atoms:
|
||||
if not hasattr(target_attr, atom):
|
||||
raise RuntimeError(f'Operator target {node.target} not found!')
|
||||
raise RuntimeError(f"Operator target {node.target} not found!")
|
||||
target_attr = getattr(target_attr, atom)
|
||||
# target = target_atoms[-1]
|
||||
target = '_'.join(target_atoms)
|
||||
target = "_".join(target_atoms)
|
||||
partition.targets[target] = target_attr
|
||||
|
||||
assert isinstance(gathered_args, tuple)
|
||||
assert isinstance(gathered_kwargs, dict)
|
||||
new_node = partition.graph.create_node(op=node.op,
|
||||
target=target,
|
||||
args=gathered_args,
|
||||
kwargs=gathered_kwargs,
|
||||
name=node.name)
|
||||
new_node = partition.graph.create_node(
|
||||
op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name
|
||||
)
|
||||
new_node.meta = node.meta.copy()
|
||||
partition.environment[node] = new_node
|
||||
|
||||
@@ -323,14 +319,14 @@ def split_module_for_gpt2_test(
|
||||
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
||||
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
|
||||
for node in m.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if version.parse(torch.__version__) < version.parse('1.11.0'):
|
||||
if node.op == "placeholder":
|
||||
if version.parse(torch.__version__) < version.parse("1.11.0"):
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
|
||||
else:
|
||||
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
|
||||
type_expr=node.type,
|
||||
default_value=default_value)
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(
|
||||
node.name, type_expr=node.type, default_value=default_value
|
||||
)
|
||||
base_mod_env[node.name].meta = node.meta.copy()
|
||||
|
||||
# Do some things iterating over the partitions in topological order again:
|
||||
@@ -344,13 +340,14 @@ def split_module_for_gpt2_test(
|
||||
|
||||
# Set correct output values
|
||||
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
|
||||
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
|
||||
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
|
||||
partition.graph.output(output_vals)
|
||||
|
||||
# Construct GraphModule for this partition
|
||||
submod_name = f'submod_{partition_name}'
|
||||
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
|
||||
partition.graph) # noqa: B950
|
||||
submod_name = f"submod_{partition_name}"
|
||||
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
|
||||
partition.targets, partition.graph
|
||||
) # noqa: B950
|
||||
|
||||
# Emit call in base graph to this submodule
|
||||
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
|
||||
@@ -358,14 +355,14 @@ def split_module_for_gpt2_test(
|
||||
# Unpack multiple return values from submodule
|
||||
output_val_proxy = torch.fx.proxy.Proxy(output_val)
|
||||
for i, output_name in enumerate(partition.outputs):
|
||||
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
|
||||
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
|
||||
else:
|
||||
if not partition.outputs:
|
||||
continue
|
||||
base_mod_env[list(partition.outputs)[0]] = output_val
|
||||
|
||||
for node in m.graph.nodes:
|
||||
if node.op == 'output':
|
||||
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
|
||||
if node.op == "output":
|
||||
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
|
||||
|
||||
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
|
||||
|
@@ -9,8 +9,19 @@ from colossalai.legacy.tensor.distspec import ShardSpec
|
||||
|
||||
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
|
||||
ELEMENTWISE_FUNC_OP = [
|
||||
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
|
||||
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
|
||||
torch.add,
|
||||
operator.add,
|
||||
torch.abs,
|
||||
torch.cos,
|
||||
torch.exp,
|
||||
torch.mul,
|
||||
operator.mul,
|
||||
operator.floordiv,
|
||||
operator.truediv,
|
||||
operator.neg,
|
||||
torch.multiply,
|
||||
torch.nn.functional.relu,
|
||||
torch.nn.functional.dropout,
|
||||
]
|
||||
|
||||
|
||||
@@ -72,7 +83,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
|
||||
# traverse the graph to look for consecutive linear layers
|
||||
is_linear_module = False
|
||||
|
||||
if node.op == 'call_module':
|
||||
if node.op == "call_module":
|
||||
# look for the linear layer
|
||||
module = node.graph.owning_module.get_submodule(node.target)
|
||||
if isinstance(module, nn.Linear):
|
||||
@@ -82,31 +93,31 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
|
||||
# it means the first linear has been found and the current module
|
||||
# is the second linear
|
||||
# set the current linear module to be row-sharded
|
||||
annotation_record['row'] = module
|
||||
annotation_record["row"] = module
|
||||
|
||||
for shard_type, module in annotation_record.items():
|
||||
# add row sharding spec
|
||||
if shard_type == 'row':
|
||||
if shard_type == "row":
|
||||
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
|
||||
comp_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
setattr(module.weight, 'pg', process_group)
|
||||
setattr(module.weight, 'dist_spec', dist_spec)
|
||||
setattr(module.weight, 'comp_spec', comp_spec)
|
||||
elif shard_type == 'col':
|
||||
setattr(module.weight, "pg", process_group)
|
||||
setattr(module.weight, "dist_spec", dist_spec)
|
||||
setattr(module.weight, "comp_spec", comp_spec)
|
||||
elif shard_type == "col":
|
||||
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
|
||||
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
weight_comp_spec.output_replicate = False
|
||||
setattr(module.weight, 'pg', process_group)
|
||||
setattr(module.weight, 'dist_spec', weight_dist_spec)
|
||||
setattr(module.weight, 'comp_spec', weight_comp_spec)
|
||||
setattr(module.weight, "pg", process_group)
|
||||
setattr(module.weight, "dist_spec", weight_dist_spec)
|
||||
setattr(module.weight, "comp_spec", weight_comp_spec)
|
||||
|
||||
if module.bias is not None:
|
||||
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
|
||||
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
bias_comp_spec.output_replicate = False
|
||||
setattr(module.bias, 'pg', process_group)
|
||||
setattr(module.bias, 'dist_spec', bias_dist_spec)
|
||||
setattr(module.bias, 'comp_spec', bias_comp_spec)
|
||||
setattr(module.bias, "pg", process_group)
|
||||
setattr(module.bias, "dist_spec", bias_dist_spec)
|
||||
setattr(module.bias, "comp_spec", bias_comp_spec)
|
||||
start_tracking = False
|
||||
annotation_record.clear()
|
||||
else:
|
||||
@@ -114,16 +125,16 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
|
||||
# it means the current layer is the first linear
|
||||
# set the linear layer to be col-sharded
|
||||
start_tracking = True
|
||||
annotation_record['col'] = module
|
||||
annotation_record["col"] = module
|
||||
|
||||
if start_tracking and not is_linear_module:
|
||||
# check against the white list
|
||||
# if non-element wise op is found, we reset the tracking
|
||||
if node.op == 'call_module':
|
||||
if node.op == "call_module":
|
||||
module = node.graph.owning_module.get_submodule(node.target)
|
||||
if module.__class__ not in ELEMENTWISE_MODULE_OP:
|
||||
start_tracking = False
|
||||
elif node.op == 'call_function' or node.op == 'call_method':
|
||||
elif node.op == "call_function" or node.op == "call_method":
|
||||
if node.target not in ELEMENTWISE_FUNC_OP:
|
||||
start_tracking = False
|
||||
elif len(node.users.keys()) > 1:
|
||||
|
@@ -25,12 +25,14 @@ class Partition:
|
||||
self.targets: Dict[str, Any] = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"name: {self.name},\n" \
|
||||
f" nodes: {self.node_names},\n" \
|
||||
f" inputs: {self.inputs},\n" \
|
||||
f" outputs: {self.outputs},\n" \
|
||||
f" partitions dependent on: {self.partitions_dependent_on},\n" \
|
||||
return (
|
||||
f"name: {self.name},\n"
|
||||
f" nodes: {self.node_names},\n"
|
||||
f" inputs: {self.inputs},\n"
|
||||
f" outputs: {self.outputs},\n"
|
||||
f" partitions dependent on: {self.partitions_dependent_on},\n"
|
||||
f" partition dependents: {self.partition_dependents}"
|
||||
)
|
||||
|
||||
|
||||
# Creates subgraphs out of main graph
|
||||
@@ -117,10 +119,9 @@ def split_module(
|
||||
partitions: Dict[str, Partition] = {}
|
||||
orig_nodes: Dict[str, torch.fx.node.Node] = {}
|
||||
|
||||
def record_cross_partition_use(def_node: torch.fx.node.Node,
|
||||
use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def_partition_name = getattr(def_node, '_fx_partition', None)
|
||||
use_partition_name = getattr(use_node, '_fx_partition', None)
|
||||
def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def_partition_name = getattr(def_node, "_fx_partition", None)
|
||||
use_partition_name = getattr(use_node, "_fx_partition", None)
|
||||
if def_partition_name != use_partition_name:
|
||||
if def_partition_name is not None:
|
||||
def_partition = partitions[def_partition_name]
|
||||
@@ -134,7 +135,7 @@ def split_module(
|
||||
if def_partition_name is not None:
|
||||
use_partition.partitions_dependent_on.setdefault(def_partition_name)
|
||||
|
||||
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def_partition_name = getattr(def_node, "_fx_partition", None)
|
||||
use_partition_name = getattr(use_node, "_fx_partition", None)
|
||||
if def_partition_name != use_partition_name:
|
||||
@@ -161,7 +162,7 @@ def split_module(
|
||||
|
||||
if node.op in ["placeholder"]:
|
||||
continue
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
if merge_output:
|
||||
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
|
||||
else:
|
||||
@@ -178,7 +179,7 @@ def split_module(
|
||||
node._fx_partition = partition_name
|
||||
|
||||
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
|
||||
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
|
||||
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
|
||||
|
||||
# find partitions with no dependencies
|
||||
root_partitions: List[str] = []
|
||||
@@ -208,7 +209,7 @@ def split_module(
|
||||
|
||||
# Transform nodes and collect targets for partition's submodule
|
||||
for node in m.graph.nodes:
|
||||
if hasattr(node, '_fx_partition'):
|
||||
if hasattr(node, "_fx_partition"):
|
||||
partition = partitions[node._fx_partition]
|
||||
|
||||
# swap out old graph nodes in kw/args with references to new nodes in this submodule
|
||||
@@ -216,25 +217,24 @@ def split_module(
|
||||
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
|
||||
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
|
||||
|
||||
if node.op not in ['call_module', 'get_attr']:
|
||||
if node.op not in ["call_module", "get_attr"]:
|
||||
target = node.target
|
||||
else:
|
||||
target_atoms = node.target.split('.')
|
||||
target_atoms = node.target.split(".")
|
||||
target_attr = m
|
||||
for atom in target_atoms:
|
||||
if not hasattr(target_attr, atom):
|
||||
raise RuntimeError(f'Operator target {node.target} not found!')
|
||||
raise RuntimeError(f"Operator target {node.target} not found!")
|
||||
target_attr = getattr(target_attr, atom)
|
||||
# target = target_atoms[-1]
|
||||
target = '_'.join(target_atoms)
|
||||
target = "_".join(target_atoms)
|
||||
partition.targets[target] = target_attr
|
||||
|
||||
assert isinstance(gathered_args, tuple)
|
||||
assert isinstance(gathered_kwargs, dict)
|
||||
new_node = partition.graph.create_node(op=node.op,
|
||||
target=target,
|
||||
args=gathered_args,
|
||||
kwargs=gathered_kwargs)
|
||||
new_node = partition.graph.create_node(
|
||||
op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
|
||||
)
|
||||
new_node.meta = node.meta.copy()
|
||||
partition.environment[node] = new_node
|
||||
|
||||
@@ -243,14 +243,14 @@ def split_module(
|
||||
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
||||
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
|
||||
for node in m.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if version.parse(torch.__version__) < version.parse('1.11.0'):
|
||||
if node.op == "placeholder":
|
||||
if version.parse(torch.__version__) < version.parse("1.11.0"):
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
|
||||
else:
|
||||
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
|
||||
type_expr=node.type,
|
||||
default_value=default_value)
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(
|
||||
node.target, type_expr=node.type, default_value=default_value
|
||||
)
|
||||
base_mod_env[node.name].meta = node.meta.copy()
|
||||
|
||||
# Do some things iterating over the partitions in topological order again:
|
||||
@@ -264,13 +264,14 @@ def split_module(
|
||||
|
||||
# Set correct output values
|
||||
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
|
||||
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
|
||||
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
|
||||
partition.graph.output(output_vals)
|
||||
|
||||
# Construct GraphModule for this partition
|
||||
submod_name = f'submod_{partition_name}'
|
||||
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
|
||||
partition.graph) # noqa: B950
|
||||
submod_name = f"submod_{partition_name}"
|
||||
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
|
||||
partition.targets, partition.graph
|
||||
) # noqa: B950
|
||||
|
||||
# Emit call in base graph to this submodule
|
||||
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
|
||||
@@ -278,15 +279,15 @@ def split_module(
|
||||
# Unpack multiple return values from submodule
|
||||
output_val_proxy = torch.fx.proxy.Proxy(output_val)
|
||||
for i, output_name in enumerate(partition.outputs):
|
||||
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
|
||||
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
|
||||
else:
|
||||
if not partition.outputs:
|
||||
continue
|
||||
base_mod_env[list(partition.outputs)[0]] = output_val
|
||||
|
||||
for node in m.graph.nodes:
|
||||
if node.op == 'output':
|
||||
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
|
||||
if node.op == "output":
|
||||
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
|
||||
|
||||
for partition_name in sorted_partitions:
|
||||
partition = partitions[partition_name]
|
||||
|
@@ -1,7 +1,9 @@
|
||||
import torch
|
||||
from typing import Dict
|
||||
from torch.fx.node import Node, map_arg
|
||||
|
||||
import torch
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.node import Node, map_arg
|
||||
|
||||
|
||||
def get_comm_size(prev_partition, next_partition):
|
||||
"""
|
||||
@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
|
||||
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
|
||||
for n in input_nodes:
|
||||
if n.name in parent_node_names and n not in visited_nodes:
|
||||
comm_size += n.meta['tensor_meta'].numel
|
||||
comm_size += n.meta["tensor_meta"].numel
|
||||
visited_nodes.add(n)
|
||||
return comm_size
|
||||
|
||||
@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
|
||||
"""
|
||||
input_nodes: Dict[Node, None] = {}
|
||||
for node in graph.nodes:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
map_arg(node.args, lambda n: input_nodes.setdefault(n))
|
||||
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
|
||||
placeholder_nodes = []
|
||||
for node in input_nodes.keys():
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
placeholder_nodes.append(node)
|
||||
for node in placeholder_nodes:
|
||||
input_nodes.pop(node)
|
||||
@@ -60,13 +62,13 @@ def get_top(graph: Graph):
|
||||
"""
|
||||
top_node_list = set()
|
||||
for node in graph.nodes:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
continue
|
||||
is_top = False
|
||||
|
||||
def _get_top(node):
|
||||
nonlocal is_top
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
is_top = True
|
||||
|
||||
map_arg(node.args, lambda n: _get_top(n))
|
||||
@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
|
||||
def get_all_consumers(graph: Graph, node: Node):
|
||||
"""
|
||||
Given a graph and a node of this graph, return all consumers of the node.
|
||||
|
||||
|
||||
Returns:
|
||||
List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``.
|
||||
"""
|
||||
@@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
|
||||
for node in gm.graph.nodes:
|
||||
if hasattr(node, 'bfs_level'):
|
||||
print(node.name, node.bfs_level)
|
||||
|
||||
|
||||
Output:
|
||||
graph():
|
||||
%x : [#users=2] = placeholder[target=x]
|
||||
@@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
|
||||
while nodes_to_process:
|
||||
new_process_list = []
|
||||
for node in nodes_to_process:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
continue
|
||||
node.bfs_level = current_level
|
||||
new_process_list.extend(get_all_consumers(graph, node))
|
||||
@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
|
||||
torch.nn.Module: the module associated with the given node
|
||||
"""
|
||||
|
||||
assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
|
||||
assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
|
||||
assert (
|
||||
node.graph.owning_module is not None
|
||||
), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
|
||||
assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
|
||||
module = node.graph.owning_module.get_submodule(node.target)
|
||||
return module
|
||||
|
||||
|
Reference in New Issue
Block a user