[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()


-def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
-                               target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
+def _construct_shard_meta_info(
+    node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec
+) -> ShardMetaInfo:
     # get comm_action_sequence and total_cost from shape_consistency_manager
     _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        origin_sharding_spec, target_sharding_spec)
+        origin_sharding_spec, target_sharding_spec
+    )
     meta_info = ShardMetaInfo()

     # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
     # get mem cost for ShardMetaInfo
     mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
     # extract user that has _meta_data and extract element length
-    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+    input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
     element_length = input_node._meta_data.element_size()

     mem_cost.fwd.activation *= element_length
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
     meta_info.memory_cost = mem_cost

     # get computation cost for ShardMetaInfo
-    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
-                                            total_cost['backward'] * element_length,
-                                            total_cost['total'] * element_length)
+    meta_info.compute_cost = TrainCycleItem(
+        total_cost["forward"] * element_length,
+        total_cost["backward"] * element_length,
+        total_cost["total"] * element_length,
+    )

     # get tensor shape for ShardMetaInfo
     origin_sharding_spec: ShardingSpec
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
     input_shape = origin_sharding_spec.get_sharded_shape_per_device()
     output_shape = target_sharding_spec.get_sharded_shape_per_device()

-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+    meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
     meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+    meta_info.fwd_out = [torch.rand(output_shape, device="meta")]

     return meta_info
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
     # extract node index and user node index
     args = node.args
     node_index, user_node_index = args[3], args[4]
-    origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
-        user_node_index]
+    origin_sharding_spec, target_sharding_spec = (
+        origin_spec_dict[node_index],
+        sharding_spec_dict[node_index][user_node_index],
+    )

     return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
         # this case is for all_reduce, there will be no memory cost
         meta_info = ShardMetaInfo()
         meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
-        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+        output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
         element_length = output_node._meta_data.element_size()

         total_cost = comm_action.comm_spec.get_comm_cost()
-        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
-                                                total_cost['backward'] * element_length,
-                                                total_cost['total'] * element_length)
+        meta_info.compute_cost = TrainCycleItem(
+            total_cost["forward"] * element_length,
+            total_cost["backward"] * element_length,
+            total_cost["total"] * element_length,
+        )

         input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
-        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+        meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
         meta_info.fwd_buffer = []
-        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+        meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
     else:
         # this case will be handled by shape consistency manager
-        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
-            'tgt_spec']
+        origin_sharding_spec, target_sharding_spec = (
+            comm_action.comm_spec["src_spec"],
+            comm_action.comm_spec["tgt_spec"],
+        )
         meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)

     return meta_info


-def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
-                       comm_actions_dict: Dict) -> GraphModule:
+def comm_metainfo_pass(
+    gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict
+) -> GraphModule:
     """
     The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
     """
     for node in gm.graph.nodes:
         if node.target == runtime_apply:
-            setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+            setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
         elif node.target == runtime_comm_spec_apply:
-            setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+            setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
         else:
             pass
     return gm
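
Not part of the commit: a minimal sketch of how the reformatted pass above might be driven. The graph module and the three dicts are assumed to come from an earlier auto-parallel runtime preparation step, and the import path is an assumption; only the comm_metainfo_pass signature and the best_strategy_info attribute it sets are taken from the diff.

# Hypothetical usage sketch -- not from this commit.
# `gm`, `sharding_spec_dict`, `origin_spec_dict` and `comm_actions_dict` are assumed to be
# produced earlier in the auto-parallel pipeline; the module path below is an assumption.
from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass

gm = comm_metainfo_pass(gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict)

# After the pass, each runtime_apply / runtime_comm_spec_apply node carries a ShardMetaInfo
# under `best_strategy_info`, holding its compute- and memory-cost estimates.
for node in gm.graph.nodes:
    info = getattr(node, "best_strategy_info", None)
    if info is not None:
        print(node.name, info.compute_cost, info.memory_cost)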