mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,18 +1,10 @@
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai._analyzer.fx.node_util import MetaInfo
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommAction,
|
||||
CommType,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
|
||||
from colossalai.tensor.comm_spec import CommSpec
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
|
||||
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
|
||||
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
|
||||
user_node_index: int):
|
||||
def runtime_apply_for_iterable_object(
|
||||
node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int
|
||||
):
|
||||
"""
|
||||
This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
|
||||
is converted into the user node expected form.
|
||||
"""
|
||||
rst = []
|
||||
for index, (origin_sharding_spec,
|
||||
target_sharding_spec) in enumerate(zip(origin_dict[node_index],
|
||||
input_dict[node_index][user_node_index])):
|
||||
for index, (origin_sharding_spec, target_sharding_spec) in enumerate(
|
||||
zip(origin_dict[node_index], input_dict[node_index][user_node_index])
|
||||
):
|
||||
rst.append(
|
||||
shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
|
||||
target_sharding_spec))
|
||||
shape_consistency_manager.apply_for_autoparallel_runtime(
|
||||
node[index], origin_sharding_spec, target_sharding_spec
|
||||
)
|
||||
)
|
||||
rst = type(node)(rst)
|
||||
return rst
|
||||
|
||||
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
|
||||
if isinstance(comm_action.comm_spec, CommSpec):
|
||||
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
|
||||
else:
|
||||
origin_sharding_spec = comm_action.comm_spec['src_spec']
|
||||
tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
|
||||
origin_sharding_spec = comm_action.comm_spec["src_spec"]
|
||||
tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
|
||||
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
|
||||
return rst
|
||||
|
||||
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
|
||||
node_to_index_dict = {}
|
||||
index = 0
|
||||
for node in nodes:
|
||||
if node.target == 'sharding_spec_convert_dict':
|
||||
if node.target == "sharding_spec_convert_dict":
|
||||
input_dict_node = node
|
||||
continue
|
||||
if node.target == 'origin_node_sharding_spec_dict':
|
||||
if node.target == "origin_node_sharding_spec_dict":
|
||||
origin_dict_node = node
|
||||
continue
|
||||
if node.target == 'comm_actions_dict':
|
||||
if node.target == "comm_actions_dict":
|
||||
comm_actions_dict_node = node
|
||||
continue
|
||||
if not hasattr(node, 'best_strategy'):
|
||||
if not hasattr(node, "best_strategy"):
|
||||
continue
|
||||
node_to_index_dict[node] = index
|
||||
index += 1
|
||||
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
||||
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
|
||||
|
||||
for node in nodes:
|
||||
if not hasattr(node, 'best_strategy') or node.op == 'output':
|
||||
if not hasattr(node, "best_strategy") or node.op == "output":
|
||||
continue
|
||||
|
||||
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
|
||||
if isinstance(node.sharding_spec, (list, tuple)):
|
||||
assert isinstance(
|
||||
node.target_sharding_specs,
|
||||
(list,
|
||||
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
|
||||
node.target_sharding_specs, (list, tuple)
|
||||
), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
|
||||
total_difference = 0
|
||||
for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
|
||||
node.target_sharding_specs[user_node_index]):
|
||||
for sharding_spec, target_sharding_spec in zip(
|
||||
node.sharding_spec, node.target_sharding_specs[user_node_index]
|
||||
):
|
||||
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
|
||||
if total_difference == 0:
|
||||
continue
|
||||
with mod_graph.inserting_before(user_node):
|
||||
shape_consistency_node = mod_graph.create_node('call_function',
|
||||
runtime_apply_for_iterable_object,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
shape_consistency_node = mod_graph.create_node(
|
||||
"call_function",
|
||||
runtime_apply_for_iterable_object,
|
||||
args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
|
||||
)
|
||||
|
||||
else:
|
||||
assert isinstance(node.sharding_spec,
|
||||
ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
|
||||
assert isinstance(
|
||||
node.sharding_spec, ShardingSpec
|
||||
), "node.sharding_spec should be type of ShardingSpec, tuple or list."
|
||||
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
|
||||
continue
|
||||
with mod_graph.inserting_before(user_node):
|
||||
shape_consistency_node = mod_graph.create_node('call_function',
|
||||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
if hasattr(user_node.meta['info'], 'activation_checkpoint'):
|
||||
MetaInfo(shape_consistency_node,
|
||||
mod_dir=user_node.meta['info'].mod_dir,
|
||||
activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
|
||||
shape_consistency_node = mod_graph.create_node(
|
||||
"call_function",
|
||||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
|
||||
)
|
||||
if hasattr(user_node.meta["info"], "activation_checkpoint"):
|
||||
MetaInfo(
|
||||
shape_consistency_node,
|
||||
mod_dir=user_node.meta["info"].mod_dir,
|
||||
activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
|
||||
)
|
||||
new_args = list(user_node.args)
|
||||
new_kwargs = dict(user_node.kwargs)
|
||||
# the origin node may be a positional argument or key word argument of user node
|
||||
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
|
||||
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
|
||||
|
||||
for node in nodes:
|
||||
if not hasattr(node, 'best_strategy') or node.op == 'output':
|
||||
if not hasattr(node, "best_strategy") or node.op == "output":
|
||||
continue
|
||||
|
||||
comm_actions = node.best_strategy.communication_actions
|
||||
for op_data, comm_action in comm_actions.items():
|
||||
|
||||
if comm_action.comm_type == CommType.HOOK:
|
||||
continue
|
||||
if comm_action.comm_type == CommType.BEFORE:
|
||||
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
|
||||
else:
|
||||
comm_object = node.args[comm_action.arg_index]
|
||||
with mod_graph.inserting_before(node):
|
||||
comm_spec_apply_node = mod_graph.create_node('call_function',
|
||||
runtime_comm_spec_apply,
|
||||
args=(comm_object, comm_actions_dict_node,
|
||||
node_to_index_dict[node], op_data.name))
|
||||
comm_spec_apply_node = mod_graph.create_node(
|
||||
"call_function",
|
||||
runtime_comm_spec_apply,
|
||||
args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
|
||||
)
|
||||
# the origin node may be a positional argument or key word argument of user node
|
||||
if comm_action.key_for_kwarg is not None:
|
||||
# substitute the origin node with comm_spec_apply_node
|
||||
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
|
||||
|
||||
elif comm_action.comm_type == CommType.AFTER:
|
||||
with mod_graph.inserting_after(node):
|
||||
comm_spec_apply_node = mod_graph.create_node('call_function',
|
||||
runtime_comm_spec_apply,
|
||||
args=(node, comm_actions_dict_node,
|
||||
node_to_index_dict[node], op_data.name))
|
||||
comm_spec_apply_node = mod_graph.create_node(
|
||||
"call_function",
|
||||
runtime_comm_spec_apply,
|
||||
args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
|
||||
)
|
||||
user_list = list(node.users.keys())
|
||||
for user in user_list:
|
||||
if user == comm_spec_apply_node:
|
||||
@@ -211,10 +212,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
|
||||
# substitute the origin node with comm_spec_apply_node
|
||||
new_kwargs[str(node)] = comm_spec_apply_node
|
||||
user.kwargs = new_kwargs
|
||||
if hasattr(node.meta['info'], 'activation_checkpoint'):
|
||||
MetaInfo(comm_spec_apply_node,
|
||||
mod_dir=node.meta['info'].mod_dir,
|
||||
activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
|
||||
if hasattr(node.meta["info"], "activation_checkpoint"):
|
||||
MetaInfo(
|
||||
comm_spec_apply_node,
|
||||
mod_dir=node.meta["info"].mod_dir,
|
||||
activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
|
||||
)
|
||||
|
||||
return gm
|
||||
|
||||
@@ -227,21 +230,21 @@ def _act_annotation_pass(gm: torch.fx.GraphModule):
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
|
||||
for node in nodes:
|
||||
if not hasattr(node.meta, 'activation_checkpoint'):
|
||||
from .runtime_preparation_pass import size_processing
|
||||
if not hasattr(node.meta, "activation_checkpoint"):
|
||||
pass
|
||||
|
||||
user_act_annotation = -1
|
||||
input_act_annotation = -1
|
||||
for user_node in node.users.keys():
|
||||
if 'activation_checkpoint' in user_node.meta:
|
||||
user_act_annotation = user_node.meta['activation_checkpoint']
|
||||
if "activation_checkpoint" in user_node.meta:
|
||||
user_act_annotation = user_node.meta["activation_checkpoint"]
|
||||
break
|
||||
for input_node in node._input_nodes.keys():
|
||||
if 'activation_checkpoint' in input_node.meta:
|
||||
input_act_annotation = input_node.meta['activation_checkpoint']
|
||||
if "activation_checkpoint" in input_node.meta:
|
||||
input_act_annotation = input_node.meta["activation_checkpoint"]
|
||||
break
|
||||
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
|
||||
node.meta['activation_checkpoint'] = user_act_annotation
|
||||
node.meta["activation_checkpoint"] = user_act_annotation
|
||||
|
||||
return gm
|
||||
|
||||
|
Reference in New Issue
Block a user