mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 22:10:37 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -5,7 +5,6 @@ import torch
|
||||
from packaging import version
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split
|
||||
from colossalai.fx.passes.meta_info_prop import TensorMetadata
|
||||
@@ -13,9 +12,9 @@ from colossalai.fx.passes.split_module import Partition
|
||||
|
||||
|
||||
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
|
||||
'''
|
||||
"""
|
||||
This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future.
|
||||
'''
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
valid_children_size = 0
|
||||
valid_children = []
|
||||
@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti
|
||||
part_index += 1
|
||||
pp_size -= 1
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
split_node = mod_graph.create_node("call_function", pipe_split)
|
||||
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
||||
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
|
||||
'''
|
||||
"""
|
||||
This pass will be used in gpt2 test, only a part of changes may be added into
|
||||
split_with_split_nodes_pass, and it will be deprecated in future.
|
||||
'''
|
||||
"""
|
||||
part_idx = 0
|
||||
|
||||
def eliminate_unused_placeholders(gm):
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
if not len(node.users):
|
||||
gm.graph.erase_node(node)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
|
||||
'''
|
||||
"""
|
||||
This method is used to eliminate the outputs in previous partition which is unused in next partition.
|
||||
In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel.
|
||||
The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it
|
||||
to partition 1 and partition 2. However, in single direction linked list, we need to do so.
|
||||
'''
|
||||
"""
|
||||
output_type = None
|
||||
output_args = []
|
||||
non_output_list = []
|
||||
new_placeholder_list = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
if isinstance(node.args[0], (tuple, list)):
|
||||
output_type = node.args[0].__class__
|
||||
output_args.extend([n.name for n in node.args[0]])
|
||||
@@ -114,7 +113,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||
continue
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
new_placeholder_list.append(node.name)
|
||||
if output_type is not None:
|
||||
gm.graph.output(output_type(output_args))
|
||||
@@ -125,7 +124,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||
|
||||
def split_callback(n: torch.fx.Node):
|
||||
nonlocal part_idx
|
||||
if (n.op, n.target) == ('call_function', pipe_split):
|
||||
if (n.op, n.target) == ("call_function", pipe_split):
|
||||
part_idx += 1
|
||||
return part_idx
|
||||
|
||||
@@ -134,7 +133,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
|
||||
for name, submodule in split_mod.named_modules():
|
||||
if isinstance(submodule, torch.fx.GraphModule):
|
||||
for node in submodule.graph.nodes:
|
||||
if (node.op, node.target) == ('call_function', pipe_split):
|
||||
if (node.op, node.target) == ("call_function", pipe_split):
|
||||
submodule.graph.erase_node(node)
|
||||
submodule.recompile()
|
||||
split_submodules.append(submodule)
|
||||
@@ -200,13 +199,12 @@ def split_module_for_gpt2_test(
|
||||
|
||||
_gen_all_ancestors_set(node)
|
||||
for n in list(all_ancestors):
|
||||
if n.op != 'placeholder' and n._fx_partition > partition_name:
|
||||
if n.op != "placeholder" and n._fx_partition > partition_name:
|
||||
n._fx_partition = partition_name
|
||||
|
||||
def record_cross_partition_use(def_node: torch.fx.node.Node,
|
||||
use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def_partition_name = getattr(def_node, '_fx_partition', None)
|
||||
use_partition_name = getattr(use_node, '_fx_partition', None)
|
||||
def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def_partition_name = getattr(def_node, "_fx_partition", None)
|
||||
use_partition_name = getattr(use_node, "_fx_partition", None)
|
||||
if def_partition_name != use_partition_name:
|
||||
# if 'tensor_meta' in def_node.meta:
|
||||
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
|
||||
@@ -237,7 +235,7 @@ def split_module_for_gpt2_test(
|
||||
|
||||
if node.op in ["placeholder"]:
|
||||
continue
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
# partition_name = str(split_callback(node))
|
||||
# def _set_output_args_partition(n, partition_name):
|
||||
# n._fx_partition = partition_name
|
||||
@@ -252,12 +250,12 @@ def split_module_for_gpt2_test(
|
||||
partitions[partition_name] = partition = Partition(partition_name)
|
||||
|
||||
partition.node_names.append(node.name)
|
||||
origin_partition_name = getattr(node, '_fx_partition', None)
|
||||
origin_partition_name = getattr(node, "_fx_partition", None)
|
||||
if origin_partition_name is None:
|
||||
node._fx_partition = partition_name
|
||||
|
||||
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
|
||||
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
|
||||
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
|
||||
|
||||
# find partitions with no dependencies
|
||||
root_partitions: List[str] = []
|
||||
@@ -287,7 +285,7 @@ def split_module_for_gpt2_test(
|
||||
|
||||
# Transform nodes and collect targets for partition's submodule
|
||||
for node in m.graph.nodes:
|
||||
if hasattr(node, '_fx_partition'):
|
||||
if hasattr(node, "_fx_partition"):
|
||||
partition = partitions[node._fx_partition]
|
||||
|
||||
# swap out old graph nodes in kw/args with references to new nodes in this submodule
|
||||
@@ -295,26 +293,24 @@ def split_module_for_gpt2_test(
|
||||
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
|
||||
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
|
||||
|
||||
if node.op not in ['call_module', 'get_attr']:
|
||||
if node.op not in ["call_module", "get_attr"]:
|
||||
target = node.target
|
||||
else:
|
||||
target_atoms = node.target.split('.')
|
||||
target_atoms = node.target.split(".")
|
||||
target_attr = m
|
||||
for atom in target_atoms:
|
||||
if not hasattr(target_attr, atom):
|
||||
raise RuntimeError(f'Operator target {node.target} not found!')
|
||||
raise RuntimeError(f"Operator target {node.target} not found!")
|
||||
target_attr = getattr(target_attr, atom)
|
||||
# target = target_atoms[-1]
|
||||
target = '_'.join(target_atoms)
|
||||
target = "_".join(target_atoms)
|
||||
partition.targets[target] = target_attr
|
||||
|
||||
assert isinstance(gathered_args, tuple)
|
||||
assert isinstance(gathered_kwargs, dict)
|
||||
new_node = partition.graph.create_node(op=node.op,
|
||||
target=target,
|
||||
args=gathered_args,
|
||||
kwargs=gathered_kwargs,
|
||||
name=node.name)
|
||||
new_node = partition.graph.create_node(
|
||||
op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name
|
||||
)
|
||||
new_node.meta = node.meta.copy()
|
||||
partition.environment[node] = new_node
|
||||
|
||||
@@ -323,14 +319,14 @@ def split_module_for_gpt2_test(
|
||||
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
||||
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
|
||||
for node in m.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if version.parse(torch.__version__) < version.parse('1.11.0'):
|
||||
if node.op == "placeholder":
|
||||
if version.parse(torch.__version__) < version.parse("1.11.0"):
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
|
||||
else:
|
||||
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
|
||||
type_expr=node.type,
|
||||
default_value=default_value)
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(
|
||||
node.name, type_expr=node.type, default_value=default_value
|
||||
)
|
||||
base_mod_env[node.name].meta = node.meta.copy()
|
||||
|
||||
# Do some things iterating over the partitions in topological order again:
|
||||
@@ -344,13 +340,14 @@ def split_module_for_gpt2_test(
|
||||
|
||||
# Set correct output values
|
||||
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
|
||||
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
|
||||
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
|
||||
partition.graph.output(output_vals)
|
||||
|
||||
# Construct GraphModule for this partition
|
||||
submod_name = f'submod_{partition_name}'
|
||||
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
|
||||
partition.graph) # noqa: B950
|
||||
submod_name = f"submod_{partition_name}"
|
||||
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
|
||||
partition.targets, partition.graph
|
||||
) # noqa: B950
|
||||
|
||||
# Emit call in base graph to this submodule
|
||||
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
|
||||
@@ -358,14 +355,14 @@ def split_module_for_gpt2_test(
|
||||
# Unpack multiple return values from submodule
|
||||
output_val_proxy = torch.fx.proxy.Proxy(output_val)
|
||||
for i, output_name in enumerate(partition.outputs):
|
||||
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
|
||||
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
|
||||
else:
|
||||
if not partition.outputs:
|
||||
continue
|
||||
base_mod_env[list(partition.outputs)[0]] = output_val
|
||||
|
||||
for node in m.graph.nodes:
|
||||
if node.op == 'output':
|
||||
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
|
||||
if node.op == "output":
|
||||
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
|
||||
|
||||
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
|
||||
|
Reference in New Issue
Block a user