Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()


-def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
-                               target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
+def _construct_shard_meta_info(
+    node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec
+) -> ShardMetaInfo:
     # get comm_action_sequence and total_cost from shape_consistency_manager
     _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        origin_sharding_spec, target_sharding_spec)
+        origin_sharding_spec, target_sharding_spec
+    )

     meta_info = ShardMetaInfo()
     # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
     # get mem cost for ShardMetaInfo
     mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
     # extract user that has _meta_data and extract element length
-    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+    input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
     element_length = input_node._meta_data.element_size()

     mem_cost.fwd.activation *= element_length
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
     meta_info.memory_cost = mem_cost

     # get computation cost for ShardMetaInfo
-    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
-                                            total_cost['backward'] * element_length,
-                                            total_cost['total'] * element_length)
+    meta_info.compute_cost = TrainCycleItem(
+        total_cost["forward"] * element_length,
+        total_cost["backward"] * element_length,
+        total_cost["total"] * element_length,
+    )

     # get tensor shape for ShardMetaInfo
     origin_sharding_spec: ShardingSpec
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
     input_shape = origin_sharding_spec.get_sharded_shape_per_device()
     output_shape = target_sharding_spec.get_sharded_shape_per_device()

-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+    meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
     meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+    meta_info.fwd_out = [torch.rand(output_shape, device="meta")]

     return meta_info
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
     # extract node index and user node index
     args = node.args
     node_index, user_node_index = args[3], args[4]
-    origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
-        user_node_index]
+    origin_sharding_spec, target_sharding_spec = (
+        origin_spec_dict[node_index],
+        sharding_spec_dict[node_index][user_node_index],
+    )

     return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
         # this case is for all_reduce, there will be no memory cost
         meta_info = ShardMetaInfo()
         meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
-        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+        output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
         element_length = output_node._meta_data.element_size()

         total_cost = comm_action.comm_spec.get_comm_cost()
-        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
-                                                total_cost['backward'] * element_length,
-                                                total_cost['total'] * element_length)
+        meta_info.compute_cost = TrainCycleItem(
+            total_cost["forward"] * element_length,
+            total_cost["backward"] * element_length,
+            total_cost["total"] * element_length,
+        )

         input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
-        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+        meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
         meta_info.fwd_buffer = []
-        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+        meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
     else:
         # this case will be handled by shape consistency manager
-        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
-            'tgt_spec']
+        origin_sharding_spec, target_sharding_spec = (
+            comm_action.comm_spec["src_spec"],
+            comm_action.comm_spec["tgt_spec"],
+        )
         meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)

     return meta_info


-def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
-                       comm_actions_dict: Dict) -> GraphModule:
+def comm_metainfo_pass(
+    gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict
+) -> GraphModule:
     """
     The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
     """
     for node in gm.graph.nodes:
         if node.target == runtime_apply:
-            setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+            setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
         elif node.target == runtime_comm_spec_apply:
-            setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+            setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
         else:
             pass
     return gm
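Aside from the quote and line-wrapping changes, the hunks above leave `_construct_shard_meta_info` doing the same thing: the shape-consistency manager reports memory and compute costs as element counts (numel), and the pass scales them into bytes using the dtype's `element_size()`. A minimal, self-contained sketch of that conversion (the dict layout mirrors `total_cost` above; `numel_cost_to_bytes` is an illustrative name, not a ColossalAI function):

```python
import torch

def numel_cost_to_bytes(numel_cost: dict, dtype: torch.dtype) -> dict:
    """Scale numel-based communication costs into bytes.

    `numel_cost` is assumed to look like {"forward": ..., "backward": ..., "total": ...},
    mirroring the `total_cost` dict in the diff above.
    """
    element_length = torch.empty((), dtype=dtype).element_size()  # bytes per element
    return {key: count * element_length for key, count in numel_cost.items()}

# Example: 1024 elements moved forward and backward in fp16 (2 bytes each).
print(numel_cost_to_bytes({"forward": 1024, "backward": 1024, "total": 2048}, torch.float16))
# {'forward': 2048, 'backward': 2048, 'total': 4096}
```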
@@ -21,16 +21,15 @@ def _normalize_tuple(x):

 @compatibility(is_backward_compatible=False)
 class MetaInfoProp:
-
     def __init__(self, module: GraphModule) -> None:
         self.module = module
         self.func_dict = {
-            'placeholder': self.placeholder_handler,
-            'get_attr': self.get_attr_handler,
-            'output': self.output_handler,
-            'call_function': self.node_handler,
-            'call_module': self.node_handler,
-            'call_method': self.node_handler,
+            "placeholder": self.placeholder_handler,
+            "get_attr": self.get_attr_handler,
+            "output": self.output_handler,
+            "call_function": self.node_handler,
+            "call_module": self.node_handler,
+            "call_method": self.node_handler,
         }

     def _set_data_ptr(self, x):
@@ -46,7 +45,7 @@ class MetaInfoProp:
         """
         Check if the node is inplace operation.
         """
-        if node.op == 'call_module':
+        if node.op == "call_module":
             return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
         elif node.op == "call_function":
             return node.target in OUTPUT_SAVED_OPS
@@ -66,7 +65,7 @@ class MetaInfoProp:
         Handle the placeholder node.
         """
         graph_info = GraphInfo()
-        out = _normalize_tuple(getattr(node, '_meta_data', None))
+        out = _normalize_tuple(getattr(node, "_meta_data", None))
         graph_info.fwd_out = list(out) if out[0] is not None else []
         node.meta = {**asdict(graph_info)}

@@ -96,7 +95,7 @@ class MetaInfoProp:
         """
         Handle other kind of nodes
         """
-        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
+        assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
         graph_info = GraphInfo()
         meta_info = node.best_strategy_info
         meta_info: ShardMetaInfo
@@ -126,7 +125,8 @@ class MetaInfoProp:
            for tensor in par.meta.get("fwd_out", []):
                 tensor: torch.Tensor
                 target_input_tensor = next(
-                    (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
+                    (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
+                )
                 if target_input_tensor is not None:
                     target_input_tensor.data_ptr = tensor.data_ptr
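`MetaInfoProp` (above) dispatches on `node.op` through a handler dictionary; the commit only switches the keys to double quotes. For readers unfamiliar with the pattern, here is a hedged, stand-alone sketch of the same op-dispatch walk over a `torch.fx` graph — the class and handlers are illustrative stand-ins, not the real `MetaInfoProp` implementation:

```python
import torch
from torch.fx import symbolic_trace

class TinyProp:
    """Minimal op-dispatch walker over a torch.fx graph (illustrative only)."""

    def __init__(self, module: torch.fx.GraphModule) -> None:
        self.module = module
        self.func_dict = {
            "placeholder": self.placeholder_handler,
            "get_attr": self.default_handler,
            "output": self.default_handler,
            "call_function": self.default_handler,
            "call_module": self.default_handler,
            "call_method": self.default_handler,
        }

    def placeholder_handler(self, node):
        # mark graph inputs
        node.meta["seen"] = "input"

    def default_handler(self, node):
        # record which op kind visited this node
        node.meta["seen"] = node.op

    def run(self):
        for node in self.module.graph.nodes:
            self.func_dict[node.op](node)
        return self.module

gm = symbolic_trace(torch.nn.Linear(4, 4))
TinyProp(gm).run()
print([(n.name, n.meta["seen"]) for n in gm.graph.nodes])
```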
@@ -1,18 +1,10 @@
 from copy import deepcopy
 from typing import Dict, List

 import torch
 from torch.fx.node import Node

 from colossalai._analyzer.fx.node_util import MetaInfo
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    CommAction,
-    CommType,
-    OperationData,
-    OperationDataType,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
 from colossalai.tensor.comm_spec import CommSpec
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
     return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)


-def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
-                                      user_node_index: int):
+def runtime_apply_for_iterable_object(
+    node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int
+):
     """
     This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
     is converted into the user node expected form.
     """
     rst = []
-    for index, (origin_sharding_spec,
-                target_sharding_spec) in enumerate(zip(origin_dict[node_index],
-                                                       input_dict[node_index][user_node_index])):
+    for index, (origin_sharding_spec, target_sharding_spec) in enumerate(
+        zip(origin_dict[node_index], input_dict[node_index][user_node_index])
+    ):
         rst.append(
-            shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
-                                                                     target_sharding_spec))
+            shape_consistency_manager.apply_for_autoparallel_runtime(
+                node[index], origin_sharding_spec, target_sharding_spec
+            )
+        )
     rst = type(node)(rst)
     return rst
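`runtime_apply_for_iterable_object` converts each element of a tuple/list activation from its origin sharding spec to the spec its consumer expects, then rebuilds the original container via `type(node)(rst)`. A small illustrative sketch of that element-wise pattern with a stand-in converter (the real converter is `shape_consistency_manager.apply_for_autoparallel_runtime`; all names below are hypothetical):

```python
from typing import Callable, Sequence

def convert_iterable(activations, origin_specs: Sequence, target_specs: Sequence,
                     convert_one: Callable):
    """Apply a per-element converter and keep the container type (tuple stays tuple)."""
    rst = []
    for index, (origin_spec, target_spec) in enumerate(zip(origin_specs, target_specs)):
        rst.append(convert_one(activations[index], origin_spec, target_spec))
    return type(activations)(rst)

# Stand-in converter that just records which specs were applied.
demo = convert_iterable(
    ("act0", "act1"),
    origin_specs=["S0", "S1"],
    target_specs=["R", "R"],
    convert_one=lambda act, src, tgt: f"{act}:{src}->{tgt}",
)
print(demo)  # ('act0:S0->R', 'act1:S1->R')
```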
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
     if isinstance(comm_action.comm_spec, CommSpec):
         rst = comm_action.comm_spec.covert_spec_to_action(tensor)
     else:
-        origin_sharding_spec = comm_action.comm_spec['src_spec']
-        tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
+        origin_sharding_spec = comm_action.comm_spec["src_spec"]
+        tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
         rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
     return rst
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
     node_to_index_dict = {}
     index = 0
     for node in nodes:
-        if node.target == 'sharding_spec_convert_dict':
+        if node.target == "sharding_spec_convert_dict":
             input_dict_node = node
             continue
-        if node.target == 'origin_node_sharding_spec_dict':
+        if node.target == "origin_node_sharding_spec_dict":
             origin_dict_node = node
             continue
-        if node.target == 'comm_actions_dict':
+        if node.target == "comm_actions_dict":
             comm_actions_dict_node = node
             continue
-        if not hasattr(node, 'best_strategy'):
+        if not hasattr(node, "best_strategy"):
             continue
         node_to_index_dict[node] = index
         index += 1
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
     input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)

     for node in nodes:
-        if not hasattr(node, 'best_strategy') or node.op == 'output':
+        if not hasattr(node, "best_strategy") or node.op == "output":
             continue

         for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
             if isinstance(node.sharding_spec, (list, tuple)):
                 assert isinstance(
-                    node.target_sharding_specs,
-                    (list,
-                     tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
+                    node.target_sharding_specs, (list, tuple)
+                ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
                 total_difference = 0
-                for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
-                                                               node.target_sharding_specs[user_node_index]):
+                for sharding_spec, target_sharding_spec in zip(
+                    node.sharding_spec, node.target_sharding_specs[user_node_index]
+                ):
                     total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
                 if total_difference == 0:
                     continue
                 with mod_graph.inserting_before(user_node):
-                    shape_consistency_node = mod_graph.create_node('call_function',
-                                                                   runtime_apply_for_iterable_object,
-                                                                   args=(node, origin_dict_node, input_dict_node,
-                                                                         node_to_index_dict[node], user_node_index))
+                    shape_consistency_node = mod_graph.create_node(
+                        "call_function",
+                        runtime_apply_for_iterable_object,
+                        args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+                    )

             else:
-                assert isinstance(node.sharding_spec,
-                                  ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
+                assert isinstance(
+                    node.sharding_spec, ShardingSpec
+                ), "node.sharding_spec should be type of ShardingSpec, tuple or list."
                 if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
                     continue
                 with mod_graph.inserting_before(user_node):
-                    shape_consistency_node = mod_graph.create_node('call_function',
-                                                                   runtime_apply,
-                                                                   args=(node, origin_dict_node, input_dict_node,
-                                                                         node_to_index_dict[node], user_node_index))
-                if hasattr(user_node.meta['info'], 'activation_checkpoint'):
-                    MetaInfo(shape_consistency_node,
-                             mod_dir=user_node.meta['info'].mod_dir,
-                             activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
+                    shape_consistency_node = mod_graph.create_node(
+                        "call_function",
+                        runtime_apply,
+                        args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+                    )
+                if hasattr(user_node.meta["info"], "activation_checkpoint"):
+                    MetaInfo(
+                        shape_consistency_node,
+                        mod_dir=user_node.meta["info"].mod_dir,
+                        activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
+                    )

             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
             # the origin node may be a positional argument or key word argument of user node
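`_shape_consistency_apply` (above) inserts a `runtime_apply` call in front of every consumer whose expected sharding differs from the producer's, then rewires the consumer's arguments to the inserted node. The graph surgery itself is plain `torch.fx`; the following self-contained sketch shows the idiom with a trivial stand-in for the runtime conversion (module, function, and node names here are illustrative, not ColossalAI's):

```python
import torch
from torch.fx import symbolic_trace

def fake_runtime_apply(x):
    # Stand-in for the real shape-consistency conversion.
    return x + 1

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = symbolic_trace(M())
graph = gm.graph

for node in list(graph.nodes):
    if node.op == "placeholder":
        user = next(iter(node.users))            # the consumer to redirect
        with graph.inserting_before(user):
            conv = graph.create_node("call_function", fake_runtime_apply, args=(node,))
        # replace the producer with the inserted node in the consumer's args
        user.args = tuple(conv if a is node else a for a in user.args)

graph.lint()
gm.recompile()
print(gm(torch.zeros(2)))  # relu(x + 1) -> tensor([1., 1.])
```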
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
     _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)

     for node in nodes:
-        if not hasattr(node, 'best_strategy') or node.op == 'output':
+        if not hasattr(node, "best_strategy") or node.op == "output":
             continue

         comm_actions = node.best_strategy.communication_actions
         for op_data, comm_action in comm_actions.items():
-
             if comm_action.comm_type == CommType.HOOK:
                 continue
             if comm_action.comm_type == CommType.BEFORE:
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                 else:
                     comm_object = node.args[comm_action.arg_index]
                 with mod_graph.inserting_before(node):
-                    comm_spec_apply_node = mod_graph.create_node('call_function',
-                                                                 runtime_comm_spec_apply,
-                                                                 args=(comm_object, comm_actions_dict_node,
-                                                                       node_to_index_dict[node], op_data.name))
+                    comm_spec_apply_node = mod_graph.create_node(
+                        "call_function",
+                        runtime_comm_spec_apply,
+                        args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+                    )
                 # the origin node may be a positional argument or key word argument of user node
                 if comm_action.key_for_kwarg is not None:
                     # substitute the origin node with comm_spec_apply_node
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):

             elif comm_action.comm_type == CommType.AFTER:
                 with mod_graph.inserting_after(node):
-                    comm_spec_apply_node = mod_graph.create_node('call_function',
-                                                                 runtime_comm_spec_apply,
-                                                                 args=(node, comm_actions_dict_node,
-                                                                       node_to_index_dict[node], op_data.name))
+                    comm_spec_apply_node = mod_graph.create_node(
+                        "call_function",
+                        runtime_comm_spec_apply,
+                        args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+                    )
                 user_list = list(node.users.keys())
                 for user in user_list:
                     if user == comm_spec_apply_node:
@@ -211,10 +212,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                         # substitute the origin node with comm_spec_apply_node
                         new_kwargs[str(node)] = comm_spec_apply_node
                         user.kwargs = new_kwargs
-            if hasattr(node.meta['info'], 'activation_checkpoint'):
-                MetaInfo(comm_spec_apply_node,
-                         mod_dir=node.meta['info'].mod_dir,
-                         activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+            if hasattr(node.meta["info"], "activation_checkpoint"):
+                MetaInfo(
+                    comm_spec_apply_node,
+                    mod_dir=node.meta["info"].mod_dir,
+                    activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+                )

     return gm
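For `CommType.AFTER`, `_comm_spec_apply` places the communication node right after the producer and repoints every other user of the producer at it, skipping the freshly inserted node to avoid a cycle. A hedged `torch.fx` sketch of that insert-after-and-redirect pattern, with a stand-in for `runtime_comm_spec_apply` (names are illustrative):

```python
import torch
from torch.fx import symbolic_trace

def fake_comm_apply(x):
    # Stand-in for runtime_comm_spec_apply (e.g. an all-reduce at runtime).
    return x * 2

class M(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return y + 1

gm = symbolic_trace(M())
graph = gm.graph

relu_node = next(n for n in graph.nodes if n.op == "call_function" and n.target == torch.relu)
with graph.inserting_after(relu_node):
    comm_node = graph.create_node("call_function", fake_comm_apply, args=(relu_node,))

# redirect every user of the producer except the freshly inserted comm node
for user in list(relu_node.users):
    if user is comm_node:
        continue
    user.replace_input_with(relu_node, comm_node)

graph.lint()
gm.recompile()
print(gm(torch.ones(2)))  # (relu(1) * 2) + 1 -> tensor([3., 3.])
```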
@@ -227,21 +230,21 @@ def _act_annotation_pass(gm: torch.fx.GraphModule):
     nodes = tuple(mod_graph.nodes)

     for node in nodes:
-        if not hasattr(node.meta, 'activation_checkpoint'):
-            from .runtime_preparation_pass import size_processing
+        if not hasattr(node.meta, "activation_checkpoint"):
+            pass

         user_act_annotation = -1
         input_act_annotation = -1
         for user_node in node.users.keys():
-            if 'activation_checkpoint' in user_node.meta:
-                user_act_annotation = user_node.meta['activation_checkpoint']
+            if "activation_checkpoint" in user_node.meta:
+                user_act_annotation = user_node.meta["activation_checkpoint"]
                 break
         for input_node in node._input_nodes.keys():
-            if 'activation_checkpoint' in input_node.meta:
-                input_act_annotation = input_node.meta['activation_checkpoint']
+            if "activation_checkpoint" in input_node.meta:
+                input_act_annotation = input_node.meta["activation_checkpoint"]
                 break
         if user_act_annotation == input_act_annotation and user_act_annotation != -1:
-            node.meta['activation_checkpoint'] = user_act_annotation
+            node.meta["activation_checkpoint"] = user_act_annotation

     return gm
@@ -1,19 +1,12 @@
 import operator
 from copy import deepcopy
 from typing import Dict, List, Union

 import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node

 from colossalai._analyzer.fx.node_util import MetaInfo
 from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    CommAction,
-    CommType,
-    OperationDataType,
-    ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
 shape_consistency_manager = ShapeConsistencyManager()


-def size_processing(size: Union[int, torch.Size],
-                    dim_partition_dict: Dict[int, List[int]],
-                    device_mesh_info: Dict[int, int],
-                    target_dim: int = None,
-                    node_name: str = None):
+def size_processing(
+    size: Union[int, torch.Size],
+    dim_partition_dict: Dict[int, List[int]],
+    device_mesh_info: Dict[int, int],
+    target_dim: int = None,
+    node_name: str = None,
+):
     """
     This method will be invoked during runtime to convert size node value depending on distributed information.
     """
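`size_processing` rewrites the value a `tensor.size()` call returns at runtime: when a dimension is sharded across device-mesh axes, its local length has to be multiplied by the product of those mesh axis sizes to recover the global length. A simplified re-implementation of that idea (the argument layout follows the signature above; `to_global_size` is an illustration, not the ColossalAI function):

```python
from typing import Dict, List, Union
import torch

def to_global_size(size: Union[int, torch.Size],
                   dim_partition_dict: Dict[int, List[int]],
                   device_mesh_info: Dict[int, int],
                   target_dim: int = None):
    """Scale a locally observed size back to the global (unsharded) size."""
    def scale(dim: int, length: int) -> int:
        factor = 1
        for mesh_dim in dim_partition_dict.get(dim, []):
            factor *= device_mesh_info[mesh_dim]
        return length * factor

    if target_dim is not None:                                        # usage: tensor.size(dim) -> int
        return scale(target_dim, size)
    return torch.Size(scale(d, s) for d, s in enumerate(size))        # usage: tensor.size() -> torch.Size

# dim 0 sharded over mesh axis 0 (size 2) and mesh axis 1 (size 4): local 8 -> global 64
print(to_global_size(torch.Size([8, 16]), {0: [0, 1]}, {0: 2, 1: 4}))  # torch.Size([64, 16])
print(to_global_size(8, {0: [0, 1]}, {0: 2, 1: 4}, target_dim=0))      # 64
```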
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
     return size


-def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
-                             strategies_constructor: StrategiesConstructor):
+def solution_annotation_pass(
+    gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor
+):
     """
     This method is used to stick the solution strategy to the nodes and add the information
     required in runtime into graph as placeholder nodes.
@@ -70,14 +66,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
     for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
         strategies_vector = node.strategies_vector
         # stick the solution strategy to the corresponding node
-        setattr(node, 'best_strategy', strategies_vector[strategy_index])
-        setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
+        setattr(node, "best_strategy", strategies_vector[strategy_index])
+        setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
         origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
-            str(node))
+            str(node)
+        )

         # attach the corresponding metainfo if node has the attribute `strategies_info`
-        if hasattr(node, 'strategies_info'):
-            setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])
+        if hasattr(node, "strategies_info"):
+            setattr(node, "best_strategy_info", node.strategies_info[strategy_index])

     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
@@ -92,15 +89,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
-        setattr(node, 'target_sharding_specs', target_sharding_specs)
+        setattr(node, "target_sharding_specs", target_sharding_specs)

         # the get_attr node strategy is kind of pending strategy, which means we will change it
         # to the same strategy of the user node.
-        if node.op == 'get_attr':
-            assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+        if node.op == "get_attr":
+            assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version."
             target_node = node.strategies_vector.successor_nodes[0]
             node_name = str(node)
-            if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
+            if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
                 node_name = str(target_node)
                 target_node = target_node.strategies_vector.successor_nodes[0]
             user_strategy = target_node.best_strategy
@@ -122,11 +119,11 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],

     # add above dicts into graph
     for node in nodes:
-        if node.op != 'placeholder':
+        if node.op != "placeholder":
             with mod_graph.inserting_before(node):
-                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
-                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
-                comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
+                input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+                origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
+                comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
             break
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
@@ -148,7 +145,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
         device_mesh_info[dim] = dim_size

     def _extract_target_dim(node):
-        '''
+        """
        A helper function to extract the target dimension from size node.
        There are two usages of torch.Tensor.size:
        1. tensor.size()
@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh

        If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
        Otherwise, the output will be in type of torch.Size and this function will return None.
-        '''
+        """
        target_dim = None
        if len(node.args) > 1:
            target_dim = node.args[1]
@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
        return target_dim

     def _post_processing(node, size_processing_node):
-        '''
+        """
        This function is used to process the dependency between the size node and its users after
        inserting the size_process_node.
-        '''
+        """
        # store original node and processing node pair in node_pairs dictionary
        # It will be used to replace the original node with processing node in slice object
        node_pairs[node] = size_processing_node
        size_processing_node._meta_data = node._meta_data

-        if hasattr(node.meta['info'], 'activation_checkpoint'):
-            MetaInfo(size_processing_node,
-                     mod_dir=node.meta['info'].mod_dir,
-                     activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+        if hasattr(node.meta["info"], "activation_checkpoint"):
+            MetaInfo(
+                size_processing_node,
+                mod_dir=node.meta["info"].mod_dir,
+                activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+            )

        user_list = list(node.users.keys())
        for user in user_list:
@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
            user.kwargs = new_kwargs

     def _update_slice_object_args(slice_object):
-        '''
+        """
        This function is used to update the slice object argument list.
        If the slice object contains the Node argument, then the size node will be replaced with
-        '''
+        """
        if isinstance(slice_object, slice):
            start = slice_object.start
            stop = slice_object.stop
@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
        raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")

     for node in nodes:
-
-        if node.op == 'call_method' and node.target == 'size':
+        if node.op == "call_method" and node.target == "size":
             # extract useful information from size node
             # dim_partition_dict will instruct the size value on which
             # dimension should be enlarged.
@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh

             # insert size_processing node
             with mod_graph.inserting_after(node):
-                size_processing_node = mod_graph.create_node('call_function',
-                                                             size_processing,
-                                                             args=(node, dim_partition_dict, device_mesh_info,
-                                                                   target_dim, node.name))
+                size_processing_node = mod_graph.create_node(
+                    "call_function",
+                    size_processing,
+                    args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
+                )
             _post_processing(node, size_processing_node)

-        if node.op == 'call_function' and node.target == operator.getitem:
-
+        if node.op == "call_function" and node.target == operator.getitem:
             getitem_index = node.args[1]
             # slice object is quite special in torch.fx graph,
             # On one side, we treat slice object same as type of int,
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
     nodes = tuple(mod_graph.nodes)

     def _extract_info_from_sharding_spec(sharding_spec):
-        '''
+        """
        This function is used to extract the dim_partition_dict and device_mesh from
        sharding spec instance or a list of sharding spec.
-        '''
+        """
        if isinstance(sharding_spec, ShardingSpec):
            dim_partition_dict = sharding_spec.dim_partition_dict
            device_mesh = sharding_spec.device_mesh
            return dim_partition_dict, device_mesh
        if sharding_spec is None:
            return None, None
-        assert isinstance(sharding_spec,
-                          (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+        assert isinstance(
+            sharding_spec, (tuple, list)
+        ), "sharding_spec should be type of ShardingSpec, tuple, list or None"

        device_mesh = sharding_spec[0].device_mesh
        dim_partition_dict = []
@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
            else:
                new_args.append(arg)
        else:
-            assert isinstance(arg,
-                              (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
+            assert isinstance(
+                arg, (int, tuple, list)
+            ), "The argument in view node should be either type of Node or int."
            if isinstance(arg, (tuple, list)):
                new_args.extend(arg)
            else:
@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)

     def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
        new_args = _process_node_arguments(node)
-        if node.op == 'call_method':
+        if node.op == "call_method":
            args_to_process = list(new_args[1:])
        else:
            args_to_process = list(new_args)
@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)

        args_to_process = tuple(args_to_process)

-        if node.op == 'call_method':
+        if node.op == "call_method":
            new_args = (new_args[0],) + args_to_process
        else:
            new_args = args_to_process
@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
        node.args = new_args

     def _filter_node_with_shape_args(node):
-        if node.op == 'call_method':
+        if node.op == "call_method":
            target = getattr(node.args[0]._meta_data.__class__, node.target)
-        elif node.op == 'call_function':
+        elif node.op == "call_function":
            target = node.target
        else:
            target = None
@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)

     for node in nodes:
         # skip the placeholder node added in _solution_annotation pass
-        if not hasattr(node, 'sharding_spec'):
+        if not hasattr(node, "sharding_spec"):
             continue

         output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
@@ -392,15 +392,21 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
     reduction_stream = torch.cuda.Stream()

     def _add_hook_for_grad_communication(node, param, name=None):
-
         comm_actions = node.best_strategy.communication_actions

         def _filter_param_to_hook(node, op_data, comm_action, name):
-
-            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
+            if (
+                node.op == "call_module"
+                and op_data.type == OperationDataType.PARAM
+                and op_data.name == name
+                and comm_action.comm_type == CommType.HOOK
+            ):
                 return True
-            if node.op == 'get_attr' and isinstance(
-                    node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+            if (
+                node.op == "get_attr"
+                and isinstance(node._meta_data, torch.nn.parameter.Parameter)
+                and comm_action.comm_type == CommType.HOOK
+            ):
                 return True
             return False
@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
            if _filter_param_to_hook(node, operation_data, comm_action, name=name):

                def wrapper(param, comm_spec, stream, overlap):
-
                    def hook_fn(grad):
                        if overlap:
                            with torch.cuda.stream(stream):
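The `wrapper`/`hook_fn` pair in this hunk registers a gradient hook on a sharded parameter so the gradient reduction can run on a separate CUDA stream when `overlap` is enabled. Below is a hedged sketch of the same pattern using plain `torch.distributed` instead of ColossalAI's comm-spec machinery; it assumes a process group has already been initialized, and the caller remains responsible for synchronizing the side stream before the optimizer step:

```python
import torch
import torch.distributed as dist

def add_grad_allreduce_hook(param: torch.nn.Parameter,
                            stream: torch.cuda.Stream = None,
                            overlap: bool = False):
    """Register a hook that all-reduces the incoming gradient (illustrative sketch).

    With overlap=True the collective is launched on a side stream so it can run
    concurrently with the remaining backward computation; synchronize that stream
    before using the gradients.
    """

    def hook_fn(grad: torch.Tensor) -> torch.Tensor:
        if overlap and stream is not None:
            with torch.cuda.stream(stream):
                dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        else:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad

    param.register_hook(hook_fn)

# usage (assumes dist.init_process_group(...) was called and CUDA is available):
# reduction_stream = torch.cuda.Stream()
# for p in model.parameters():
#     add_grad_allreduce_hook(p, stream=reduction_stream, overlap=True)
```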
@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
        # apply the sharding spec of parameters
        if target_sharding_spec.dim_partition_dict != {}:
            origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
-            setattr(param, 'sharding_spec', origin_sharding_spec)
+            setattr(param, "sharding_spec", origin_sharding_spec)
            # TODO: build a ColoParameter class to manager the distributed parameters
            # we could use .data here, because all the operations just happen before the real training
            # loop, so we don't need to track these operations in the autograd graph.
            param = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
+                shape_consistency_manager.apply_for_autoparallel_runtime(
+                    param.data, param.sharding_spec, target_sharding_spec
+                )
+                .detach()
+                .clone()
+            )
        return param

     for node in nodes:
-        if node.op == 'call_module':
+        if node.op == "call_module":
            target_module = node.graph.owning_module.get_submodule(node.target)
            # TODO: we need to do more actions to take care of the shared parameters.
-            if hasattr(target_module, 'processed') and target_module.processed:
+            if hasattr(target_module, "processed") and target_module.processed:
                continue
-            setattr(target_module, 'processed', True)
+            setattr(target_module, "processed", True)
            for name, param in target_module.named_parameters():
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                param = _shard_param(param, target_sharding_spec)
@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
            # apply the sharding spec of buffers
            for name, buffer in target_module.named_buffers():
                origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
-                setattr(buffer, 'sharding_spec', origin_sharding_spec)
+                setattr(buffer, "sharding_spec", origin_sharding_spec)
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
                sharded_buffer_dict[name] = buffer_sharded
@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
            for name, buffer_sharded in sharded_buffer_dict.items():
                setattr(target_module, name, buffer_sharded.detach().clone())

-        if node.op == 'get_attr':
+        if node.op == "get_attr":
            root = node.graph.owning_module
            atoms = node.target.split(".")
            attr_len = len(atoms)
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
    """
    replace the origin kernel into kernel with implicit communication inside.
    """
    pass


-def runtime_preparation_pass(gm: torch.fx.GraphModule,
-                             solution: List[int],
-                             device_mesh: DeviceMesh,
-                             strategies_constructor: StrategiesConstructor,
-                             overlap=False):
+def runtime_preparation_pass(
+    gm: torch.fx.GraphModule,
+    solution: List[int],
+    device_mesh: DeviceMesh,
+    strategies_constructor: StrategiesConstructor,
+    overlap=False,
+):
    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
-        gm, solution, strategies_constructor)
+        gm, solution, strategies_constructor
+    )
    gm = size_value_converting_pass(gm, device_mesh)
    gm = node_args_converting_pass(gm, device_mesh)
    # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.