Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 09:38:05 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,19 +1,12 @@
 import operator
 from copy import deepcopy
 from typing import Dict, List, Union

 import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node

 from colossalai._analyzer.fx.node_util import MetaInfo
 from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    CommAction,
-    CommType,
-    OperationDataType,
-    ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
 shape_consistency_manager = ShapeConsistencyManager()


-def size_processing(size: Union[int, torch.Size],
-                    dim_partition_dict: Dict[int, List[int]],
-                    device_mesh_info: Dict[int, int],
-                    target_dim: int = None,
-                    node_name: str = None):
+def size_processing(
+    size: Union[int, torch.Size],
+    dim_partition_dict: Dict[int, List[int]],
+    device_mesh_info: Dict[int, int],
+    target_dim: int = None,
+    node_name: str = None,
+):
     """
     This method will be invoked during runtime to convert size node value depending on distributed information.
     """
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
     return size


-def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
-                             strategies_constructor: StrategiesConstructor):
+def solution_annotation_pass(
+    gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor
+):
     """
     This method is used to stick the solution strategy to the nodes and add the information
     required in runtime into graph as placeholder nodes.
@@ -70,14 +66,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
     for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
         strategies_vector = node.strategies_vector
         # stick the solution strategy to the corresponding node
-        setattr(node, 'best_strategy', strategies_vector[strategy_index])
-        setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
+        setattr(node, "best_strategy", strategies_vector[strategy_index])
+        setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
         origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
-            str(node))
+            str(node)
+        )

         # attach the corresponding metainfo if node has the attribute `strategies_info`
-        if hasattr(node, 'strategies_info'):
-            setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])
+        if hasattr(node, "strategies_info"):
+            setattr(node, "best_strategy_info", node.strategies_info[strategy_index])

     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
@@ -92,15 +89,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
-        setattr(node, 'target_sharding_specs', target_sharding_specs)
+        setattr(node, "target_sharding_specs", target_sharding_specs)

         # the get_attr node strategy is kind of pending strategy, which means we will change it
         # to the same strategy of the user node.
-        if node.op == 'get_attr':
-            assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+        if node.op == "get_attr":
+            assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version."
             target_node = node.strategies_vector.successor_nodes[0]
             node_name = str(node)
-            if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
+            if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
                 node_name = str(target_node)
                 target_node = target_node.strategies_vector.successor_nodes[0]
             user_strategy = target_node.best_strategy
@@ -122,11 +119,11 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],

     # add above dicts into graph
     for node in nodes:
-        if node.op != 'placeholder':
+        if node.op != "placeholder":
             with mod_graph.inserting_before(node):
-                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
-                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
-                comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
+                input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+                origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
+                comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
         break
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict

@@ -148,7 +145,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
         device_mesh_info[dim] = dim_size

     def _extract_target_dim(node):
-        '''
+        """
         A helper function to extract the target dimension from size node.
         There are two usages of torch.Tensor.size:
         1. tensor.size()
@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh

         If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
         Otherwise, the output will be in type of torch.Size and this function will return None.
-        '''
+        """
         target_dim = None
         if len(node.args) > 1:
             target_dim = node.args[1]
@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
         return target_dim

     def _post_processing(node, size_processing_node):
-        '''
+        """
         This function is used to process the dependency between the size node and its users after
         inserting the size_process_node.
-        '''
+        """
         # store original node and processing node pair in node_pairs dictionary
         # It will be used to replace the original node with processing node in slice object
         node_pairs[node] = size_processing_node
         size_processing_node._meta_data = node._meta_data

-        if hasattr(node.meta['info'], 'activation_checkpoint'):
-            MetaInfo(size_processing_node,
-                     mod_dir=node.meta['info'].mod_dir,
-                     activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+        if hasattr(node.meta["info"], "activation_checkpoint"):
+            MetaInfo(
+                size_processing_node,
+                mod_dir=node.meta["info"].mod_dir,
+                activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+            )

         user_list = list(node.users.keys())
         for user in user_list:
@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
             user.kwargs = new_kwargs

     def _update_slice_object_args(slice_object):
-        '''
+        """
         This function is used to update the slice object argument list.
         If the slice object contains the Node argument, then the size node will be replaced with
-        '''
+        """
         if isinstance(slice_object, slice):
             start = slice_object.start
             stop = slice_object.stop
@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
             raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")

     for node in nodes:
-
-        if node.op == 'call_method' and node.target == 'size':
+        if node.op == "call_method" and node.target == "size":
             # extract useful information from size node
             # dim_partition_dict will instruct the size value on which
             # dimension should be enlarged.
@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh

             # insert size_processing node
             with mod_graph.inserting_after(node):
-                size_processing_node = mod_graph.create_node('call_function',
-                                                             size_processing,
-                                                             args=(node, dim_partition_dict, device_mesh_info,
-                                                                   target_dim, node.name))
+                size_processing_node = mod_graph.create_node(
+                    "call_function",
+                    size_processing,
+                    args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
+                )
                 _post_processing(node, size_processing_node)

-        if node.op == 'call_function' and node.target == operator.getitem:
-
+        if node.op == "call_function" and node.target == operator.getitem:
             getitem_index = node.args[1]
             # slice object is quite special in torch.fx graph,
             # On one side, we treat slice object same as type of int,
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
     nodes = tuple(mod_graph.nodes)

     def _extract_info_from_sharding_spec(sharding_spec):
-        '''
+        """
         This function is used to extract the dim_partition_dict and device_mesh from
         sharding spec instance or a list of sharding spec.
-        '''
+        """
         if isinstance(sharding_spec, ShardingSpec):
             dim_partition_dict = sharding_spec.dim_partition_dict
             device_mesh = sharding_spec.device_mesh
             return dim_partition_dict, device_mesh
         if sharding_spec is None:
             return None, None
-        assert isinstance(sharding_spec,
-                          (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+        assert isinstance(
+            sharding_spec, (tuple, list)
+        ), "sharding_spec should be type of ShardingSpec, tuple, list or None"

         device_mesh = sharding_spec[0].device_mesh
         dim_partition_dict = []
@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
                 else:
                     new_args.append(arg)
             else:
-                assert isinstance(arg,
-                                  (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
+                assert isinstance(
+                    arg, (int, tuple, list)
+                ), "The argument in view node should be either type of Node or int."
                 if isinstance(arg, (tuple, list)):
                     new_args.extend(arg)
                 else:
@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)

     def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
         new_args = _process_node_arguments(node)
-        if node.op == 'call_method':
+        if node.op == "call_method":
             args_to_process = list(new_args[1:])
         else:
             args_to_process = list(new_args)
@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)

         args_to_process = tuple(args_to_process)

-        if node.op == 'call_method':
+        if node.op == "call_method":
             new_args = (new_args[0],) + args_to_process
         else:
             new_args = args_to_process
@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
         node.args = new_args

     def _filter_node_with_shape_args(node):
-        if node.op == 'call_method':
+        if node.op == "call_method":
             target = getattr(node.args[0]._meta_data.__class__, node.target)
-        elif node.op == 'call_function':
+        elif node.op == "call_function":
             target = node.target
         else:
             target = None
@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)

     for node in nodes:
         # skip the placeholder node added in _solution_annotation pass
-        if not hasattr(node, 'sharding_spec'):
+        if not hasattr(node, "sharding_spec"):
             continue

         output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
@@ -392,15 +392,21 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
     reduction_stream = torch.cuda.Stream()

     def _add_hook_for_grad_communication(node, param, name=None):
-
         comm_actions = node.best_strategy.communication_actions

         def _filter_param_to_hook(node, op_data, comm_action, name):
-
-            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
+            if (
+                node.op == "call_module"
+                and op_data.type == OperationDataType.PARAM
+                and op_data.name == name
+                and comm_action.comm_type == CommType.HOOK
+            ):
                 return True
-            if node.op == 'get_attr' and isinstance(
-                    node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+            if (
+                node.op == "get_attr"
+                and isinstance(node._meta_data, torch.nn.parameter.Parameter)
+                and comm_action.comm_type == CommType.HOOK
+            ):
                 return True
             return False

@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             if _filter_param_to_hook(node, operation_data, comm_action, name=name):

                 def wrapper(param, comm_spec, stream, overlap):
-
                     def hook_fn(grad):
                         if overlap:
                             with torch.cuda.stream(stream):
@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
         # apply the sharding spec of parameters
         if target_sharding_spec.dim_partition_dict != {}:
             origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
-            setattr(param, 'sharding_spec', origin_sharding_spec)
+            setattr(param, "sharding_spec", origin_sharding_spec)
             # TODO: build a ColoParameter class to manager the distributed parameters
             # we could use .data here, because all the operations just happen before the real training
             # loop, so we don't need to track these operations in the autograd graph.
             param = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
+                shape_consistency_manager.apply_for_autoparallel_runtime(
+                    param.data, param.sharding_spec, target_sharding_spec
+                )
+                .detach()
+                .clone()
+            )
         return param

     for node in nodes:
-        if node.op == 'call_module':
+        if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             # TODO: we need to do more actions to take care of the shared parameters.
-            if hasattr(target_module, 'processed') and target_module.processed:
+            if hasattr(target_module, "processed") and target_module.processed:
                 continue
-            setattr(target_module, 'processed', True)
+            setattr(target_module, "processed", True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 param = _shard_param(param, target_sharding_spec)
@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             # apply the sharding spec of buffers
             for name, buffer in target_module.named_buffers():
                 origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
-                setattr(buffer, 'sharding_spec', origin_sharding_spec)
+                setattr(buffer, "sharding_spec", origin_sharding_spec)
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
                 sharded_buffer_dict[name] = buffer_sharded
@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             for name, buffer_sharded in sharded_buffer_dict.items():
                 setattr(target_module, name, buffer_sharded.detach().clone())

-        if node.op == 'get_attr':
+        if node.op == "get_attr":
             root = node.graph.owning_module
             atoms = node.target.split(".")
             attr_len = len(atoms)
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
     """
     replace the origin kernel into kernel with implicit communication inside.
     """
     pass


-def runtime_preparation_pass(gm: torch.fx.GraphModule,
-                             solution: List[int],
-                             device_mesh: DeviceMesh,
-                             strategies_constructor: StrategiesConstructor,
-                             overlap=False):
+def runtime_preparation_pass(
+    gm: torch.fx.GraphModule,
+    solution: List[int],
+    device_mesh: DeviceMesh,
+    strategies_constructor: StrategiesConstructor,
+    overlap=False,
+):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
-        gm, solution, strategies_constructor)
+        gm, solution, strategies_constructor
+    )
     gm = size_value_converting_pass(gm, device_mesh)
     gm = node_args_converting_pass(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
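Note: the body of size_processing is not touched by this commit, so only its signature and docstring appear in the diff above. As a rough, hypothetical sketch of the behaviour the docstring describes (rescaling a locally sharded size value back to the global size using the device-mesh information), something like the following could be written. The name size_processing_sketch and the example mesh values are illustrative assumptions, not taken from the ColossalAI implementation.

from typing import Dict, List, Union

import torch


def size_processing_sketch(
    size: Union[int, torch.Size],
    dim_partition_dict: Dict[int, List[int]],
    device_mesh_info: Dict[int, int],
    target_dim: int = None,
):
    """Hypothetical sketch: rescale sharded dimensions back to their global size.

    dim_partition_dict maps a tensor dimension to the mesh dimensions it is sharded
    over; device_mesh_info maps a mesh dimension to its size.
    """
    if target_dim is not None:
        # tensor.size(dim) case: the runtime value is a plain int for one dimension
        assert isinstance(size, int)
        for mesh_dim in dim_partition_dict.get(target_dim, []):
            size *= device_mesh_info[mesh_dim]
        return size

    # tensor.size() case: rebuild the full torch.Size with global values
    global_shape = list(size)
    for tensor_dim, mesh_dims in dim_partition_dict.items():
        for mesh_dim in mesh_dims:
            global_shape[tensor_dim] *= device_mesh_info[mesh_dim]
    return torch.Size(global_shape)


# Example: tensor dim 0 sharded over mesh dims 0 and 1 of a 2 x 4 device mesh
local_size = torch.Size([4, 16])
print(size_processing_sketch(local_size, {0: [0, 1]}, {0: 2, 1: 4}))  # torch.Size([32, 16])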