[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
def _construct_shard_meta_info(
node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec
) -> ShardMetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec)
origin_sharding_spec, target_sharding_spec
)
meta_info = ShardMetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for ShardMetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract user that has _meta_data and extract element length
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
meta_info.memory_cost = mem_cost
# get computation cost for ShardMetaInfo
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
meta_info.compute_cost = TrainCycleItem(
total_cost["forward"] * element_length,
total_cost["backward"] * element_length,
total_cost["total"] * element_length,
)
# get tensor shape for ShardMetaInfo
origin_sharding_spec: ShardingSpec
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
return meta_info
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
user_node_index]
origin_sharding_spec, target_sharding_spec = (
origin_spec_dict[node_index],
sharding_spec_dict[node_index][user_node_index],
)
return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
# this case is for all_reduce, there will be no memory cost
meta_info = ShardMetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost()
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
meta_info.compute_cost = TrainCycleItem(
total_cost["forward"] * element_length,
total_cost["backward"] * element_length,
total_cost["total"] * element_length,
)
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
else:
# this case will be handled by shape consistency manager
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
'tgt_spec']
origin_sharding_spec, target_sharding_spec = (
comm_action.comm_spec["src_spec"],
comm_action.comm_spec["tgt_spec"],
)
meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
comm_actions_dict: Dict) -> GraphModule:
def comm_metainfo_pass(
gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict
) -> GraphModule:
"""
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass
return gm

View File

@@ -21,16 +21,15 @@ def _normalize_tuple(x):
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
def __init__(self, module: GraphModule) -> None:
self.module = module
self.func_dict = {
'placeholder': self.placeholder_handler,
'get_attr': self.get_attr_handler,
'output': self.output_handler,
'call_function': self.node_handler,
'call_module': self.node_handler,
'call_method': self.node_handler,
"placeholder": self.placeholder_handler,
"get_attr": self.get_attr_handler,
"output": self.output_handler,
"call_function": self.node_handler,
"call_module": self.node_handler,
"call_method": self.node_handler,
}
def _set_data_ptr(self, x):
@@ -46,7 +45,7 @@ class MetaInfoProp:
"""
Check if the node is inplace operation.
"""
if node.op == 'call_module':
if node.op == "call_module":
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
@@ -66,7 +65,7 @@ class MetaInfoProp:
Handle the placeholder node.
"""
graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None))
out = _normalize_tuple(getattr(node, "_meta_data", None))
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
@@ -96,7 +95,7 @@ class MetaInfoProp:
"""
Handle other kind of nodes
"""
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_strategy_info
meta_info: ShardMetaInfo
@@ -126,7 +125,8 @@ class MetaInfoProp:
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
)
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr

View File

@@ -1,18 +1,10 @@
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
user_node_index: int):
def runtime_apply_for_iterable_object(
node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int
):
"""
This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
is converted into the user node expected form.
"""
rst = []
for index, (origin_sharding_spec,
target_sharding_spec) in enumerate(zip(origin_dict[node_index],
input_dict[node_index][user_node_index])):
for index, (origin_sharding_spec, target_sharding_spec) in enumerate(
zip(origin_dict[node_index], input_dict[node_index][user_node_index])
):
rst.append(
shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
target_sharding_spec))
shape_consistency_manager.apply_for_autoparallel_runtime(
node[index], origin_sharding_spec, target_sharding_spec
)
)
rst = type(node)(rst)
return rst
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
if isinstance(comm_action.comm_spec, CommSpec):
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
else:
origin_sharding_spec = comm_action.comm_spec['src_spec']
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
origin_sharding_spec = comm_action.comm_spec["src_spec"]
tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
node_to_index_dict = {}
index = 0
for node in nodes:
if node.target == 'sharding_spec_convert_dict':
if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
if node.target == 'origin_node_sharding_spec_dict':
if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
if node.target == 'comm_actions_dict':
if node.target == "comm_actions_dict":
comm_actions_dict_node = node
continue
if not hasattr(node, 'best_strategy'):
if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
if not hasattr(node, 'best_strategy') or node.op == 'output':
if not hasattr(node, "best_strategy") or node.op == "output":
continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance(
node.target_sharding_specs,
(list,
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
node.target_sharding_specs, (list, tuple)
), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
total_difference = 0
for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
node.target_sharding_specs[user_node_index]):
for sharding_spec, target_sharding_spec in zip(
node.sharding_spec, node.target_sharding_specs[user_node_index]
):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply_for_iterable_object,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
shape_consistency_node = mod_graph.create_node(
"call_function",
runtime_apply_for_iterable_object,
args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
)
else:
assert isinstance(node.sharding_spec,
ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
assert isinstance(
node.sharding_spec, ShardingSpec
), "node.sharding_spec should be type of ShardingSpec, tuple or list."
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
if hasattr(user_node.meta['info'], 'activation_checkpoint'):
MetaInfo(shape_consistency_node,
mod_dir=user_node.meta['info'].mod_dir,
activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
shape_consistency_node = mod_graph.create_node(
"call_function",
runtime_apply,
args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
)
if hasattr(user_node.meta["info"], "activation_checkpoint"):
MetaInfo(
shape_consistency_node,
mod_dir=user_node.meta["info"].mod_dir,
activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
)
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or key word argument of user node
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
if not hasattr(node, 'best_strategy') or node.op == 'output':
if not hasattr(node, "best_strategy") or node.op == "output":
continue
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
comm_spec_apply_node = mod_graph.create_node(
"call_function",
runtime_comm_spec_apply,
args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
)
# the origin node may be a positional argument or key word argument of user node
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
comm_spec_apply_node = mod_graph.create_node('call_function',
runtime_comm_spec_apply,
args=(node, comm_actions_dict_node,
node_to_index_dict[node], op_data.name))
comm_spec_apply_node = mod_graph.create_node(
"call_function",
runtime_comm_spec_apply,
args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
)
user_list = list(node.users.keys())
for user in user_list:
if user == comm_spec_apply_node:
@@ -211,10 +212,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
if hasattr(node.meta['info'], 'activation_checkpoint'):
MetaInfo(comm_spec_apply_node,
mod_dir=node.meta['info'].mod_dir,
activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
if hasattr(node.meta["info"], "activation_checkpoint"):
MetaInfo(
comm_spec_apply_node,
mod_dir=node.meta["info"].mod_dir,
activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
)
return gm
@@ -227,21 +230,21 @@ def _act_annotation_pass(gm: torch.fx.GraphModule):
nodes = tuple(mod_graph.nodes)
for node in nodes:
if not hasattr(node.meta, 'activation_checkpoint'):
from .runtime_preparation_pass import size_processing
if not hasattr(node.meta, "activation_checkpoint"):
pass
user_act_annotation = -1
input_act_annotation = -1
for user_node in node.users.keys():
if 'activation_checkpoint' in user_node.meta:
user_act_annotation = user_node.meta['activation_checkpoint']
if "activation_checkpoint" in user_node.meta:
user_act_annotation = user_node.meta["activation_checkpoint"]
break
for input_node in node._input_nodes.keys():
if 'activation_checkpoint' in input_node.meta:
input_act_annotation = input_node.meta['activation_checkpoint']
if "activation_checkpoint" in input_node.meta:
input_act_annotation = input_node.meta["activation_checkpoint"]
break
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
node.meta['activation_checkpoint'] = user_act_annotation
node.meta["activation_checkpoint"] = user_act_annotation
return gm

View File

@@ -1,19 +1,12 @@
import operator
from copy import deepcopy
from typing import Dict, List, Union
import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationDataType,
ShardingStrategy,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
shape_consistency_manager = ShapeConsistencyManager()
def size_processing(size: Union[int, torch.Size],
dim_partition_dict: Dict[int, List[int]],
device_mesh_info: Dict[int, int],
target_dim: int = None,
node_name: str = None):
def size_processing(
size: Union[int, torch.Size],
dim_partition_dict: Dict[int, List[int]],
device_mesh_info: Dict[int, int],
target_dim: int = None,
node_name: str = None,
):
"""
This method will be invoked during runtime to convert size node value depending on distributed information.
"""
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
return size
def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
strategies_constructor: StrategiesConstructor):
def solution_annotation_pass(
gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor
):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
@@ -70,14 +66,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
# stick the solution strategy to the corresponding node
setattr(node, 'best_strategy', strategies_vector[strategy_index])
setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
setattr(node, "best_strategy", strategies_vector[strategy_index])
setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
str(node))
str(node)
)
# attach the corresponding metainfo if node has the attribute `strategies_info`
if hasattr(node, 'strategies_info'):
setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])
if hasattr(node, "strategies_info"):
setattr(node, "best_strategy_info", node.strategies_info[strategy_index])
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
@@ -92,15 +89,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs)
setattr(node, "target_sharding_specs", target_sharding_specs)
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
if node.op == 'get_attr':
assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
if node.op == "get_attr":
assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version."
target_node = node.strategies_vector.successor_nodes[0]
node_name = str(node)
if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
node_name = str(target_node)
target_node = target_node.strategies_vector.successor_nodes[0]
user_strategy = target_node.best_strategy
@@ -122,11 +119,11 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
# add above dicts into graph
for node in nodes:
if node.op != 'placeholder':
if node.op != "placeholder":
with mod_graph.inserting_before(node):
input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
break
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
@@ -148,7 +145,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):
'''
"""
A helper function to extract the target dimension from size node.
There are two usages of torch.Tensor.size:
1. tensor.size()
@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
Otherwise, the output will be in type of torch.Size and this function will return None.
'''
"""
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
return target_dim
def _post_processing(node, size_processing_node):
'''
"""
This function is used to process the dependency between the size node and its users after
inserting the size_process_node.
'''
"""
# store original node and processing node pair in node_pairs dictionary
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
if hasattr(node.meta['info'], 'activation_checkpoint'):
MetaInfo(size_processing_node,
mod_dir=node.meta['info'].mod_dir,
activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
if hasattr(node.meta["info"], "activation_checkpoint"):
MetaInfo(
size_processing_node,
mod_dir=node.meta["info"].mod_dir,
activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
)
user_list = list(node.users.keys())
for user in user_list:
@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
user.kwargs = new_kwargs
def _update_slice_object_args(slice_object):
'''
"""
This function is used to update the slice object argument list.
If the slice object contains the Node argument, then the size node will be replaced with
'''
"""
if isinstance(slice_object, slice):
start = slice_object.start
stop = slice_object.stop
@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
for node in nodes:
if node.op == 'call_method' and node.target == 'size':
if node.op == "call_method" and node.target == "size":
# extract useful information from size node
# dim_partition_dict will instruct the size value on which
# dimension should be enlarged.
@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# insert size_processing node
with mod_graph.inserting_after(node):
size_processing_node = mod_graph.create_node('call_function',
size_processing,
args=(node, dim_partition_dict, device_mesh_info,
target_dim, node.name))
size_processing_node = mod_graph.create_node(
"call_function",
size_processing,
args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
)
_post_processing(node, size_processing_node)
if node.op == 'call_function' and node.target == operator.getitem:
if node.op == "call_function" and node.target == operator.getitem:
getitem_index = node.args[1]
# slice object is quite special in torch.fx graph,
# On one side, we treat slice object same as type of int,
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
nodes = tuple(mod_graph.nodes)
def _extract_info_from_sharding_spec(sharding_spec):
'''
"""
This function is used to extract the dim_partition_dict and device_mesh from
sharding spec instance or a list of sharding spec.
'''
"""
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
assert isinstance(sharding_spec,
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
assert isinstance(
sharding_spec, (tuple, list)
), "sharding_spec should be type of ShardingSpec, tuple, list or None"
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
else:
new_args.append(arg)
else:
assert isinstance(arg,
(int, tuple, list)), 'The argument in view node should be either type of Node or int.'
assert isinstance(
arg, (int, tuple, list)
), "The argument in view node should be either type of Node or int."
if isinstance(arg, (tuple, list)):
new_args.extend(arg)
else:
@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
new_args = _process_node_arguments(node)
if node.op == 'call_method':
if node.op == "call_method":
args_to_process = list(new_args[1:])
else:
args_to_process = list(new_args)
@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
args_to_process = tuple(args_to_process)
if node.op == 'call_method':
if node.op == "call_method":
new_args = (new_args[0],) + args_to_process
else:
new_args = args_to_process
@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
node.args = new_args
def _filter_node_with_shape_args(node):
if node.op == 'call_method':
if node.op == "call_method":
target = getattr(node.args[0]._meta_data.__class__, node.target)
elif node.op == 'call_function':
elif node.op == "call_function":
target = node.target
else:
target = None
@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
for node in nodes:
# skip the placeholder node added in _solution_annotation pass
if not hasattr(node, 'sharding_spec'):
if not hasattr(node, "sharding_spec"):
continue
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
@@ -392,15 +392,21 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param, name=None):
comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action, name):
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
if (
node.op == "call_module"
and op_data.type == OperationDataType.PARAM
and op_data.name == name
and comm_action.comm_type == CommType.HOOK
):
return True
if node.op == 'get_attr' and isinstance(
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
if (
node.op == "get_attr"
and isinstance(node._meta_data, torch.nn.parameter.Parameter)
and comm_action.comm_type == CommType.HOOK
):
return True
return False
@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
setattr(param, "sharding_spec", origin_sharding_spec)
# TODO: build a ColoParameter class to manager the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
shape_consistency_manager.apply_for_autoparallel_runtime(
param.data, param.sharding_spec, target_sharding_spec
)
.detach()
.clone()
)
return param
for node in nodes:
if node.op == 'call_module':
if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
# TODO: we need to do more actions to take care of the shared parameters.
if hasattr(target_module, 'processed') and target_module.processed:
if hasattr(target_module, "processed") and target_module.processed:
continue
setattr(target_module, 'processed', True)
setattr(target_module, "processed", True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
param = _shard_param(param, target_sharding_spec)
@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of buffers
for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
setattr(buffer, 'sharding_spec', origin_sharding_spec)
setattr(buffer, "sharding_spec", origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
sharded_buffer_dict[name] = buffer_sharded
@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
if node.op == 'get_attr':
if node.op == "get_attr":
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
"""
replace the origin kernel into kernel with implicit communication inside.
"""
pass
def runtime_preparation_pass(gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap=False):
def runtime_preparation_pass(
gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap=False,
):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
gm, solution, strategies_constructor)
gm, solution, strategies_constructor
)
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.