[autoparallel] Hook all meta information on ResNet nodes for auto activation checkpoint (#2248)

* [autoparallel] hook node meta on graph nodes for checkpoint solver

* [autoparallel] polish code

* [autoparallel] restore some node handlers

* colossalai/auto_parallel/passes/meta_info_prop.py

* [autoparallel] remove some unused import

* [autoparallel] hook bwd_mem_out
Author: Boyuan Yao
Date: 2023-01-02 16:25:18 +08:00
Committed by: GitHub
Parent: c8c79102f0
Commit: ab38aebace
6 changed files with 132 additions and 76 deletions
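For context before the diff: the "hooking" this commit performs is plain attribute attachment on `torch.fx` graph nodes, so that a later pass (here, the activation checkpoint solver) can read per-node cost data. A minimal, self-contained sketch of that pattern, where `DummyMetaInfo` and `hook_meta_info` are illustrative names rather than ColossalAI APIs:

```python
import torch
import torch.fx


class DummyMetaInfo:
    """Hypothetical stand-in for ColossalAI's MetaInfo container."""

    def __init__(self, compute_cost: float = 0.0, memory_cost: int = 0):
        self.compute_cost = compute_cost
        self.memory_cost = memory_cost


def hook_meta_info(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hang a metadata object off every graph node, mirroring the
    # `setattr(node, 'best_metainfo', ...)` pattern in the diff below.
    for node in gm.graph.nodes:
        setattr(node, 'best_metainfo', DummyMetaInfo())
    return gm


gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
gm = hook_meta_info(gm)
for node in gm.graph.nodes:
    print(node.name, node.best_metainfo.compute_cost)
```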


@@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
     return rst
 
 
-def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
-    """
-    This method is used to construct `MetaInfo` for a shape consistency node.
-    TODO: we could actually obtain the cost information from the resharding cost in the
-    node handler; we should modify this part in the future.
-    """
-
-    def compute_shape(sharding_spec: ShardingSpec):
-        shape = sharding_spec.entire_shape
-        new_shape = []
-        for dim, shard in sharding_spec.dim_partition_dict.items():
-            new_shape.append(shape[dim] // len(shard))
-        return new_shape
-
-    meta_info = MetaInfo()
-    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
-        str(node.name))
-    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        origin_sharding_spec, target_sharding_spec)
-
-    # NOTE: the costs in shape_consistency_manager.mem_cost are counted in numbers of
-    # elements (numel), so scale them by the element size to get bytes for MetaInfo
-    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
-    element_length = node._meta_data.element_size()
-    mem_cost.fwd.activation *= element_length
-    mem_cost.fwd.temp *= element_length
-    mem_cost.bwd.activation *= element_length
-    mem_cost.bwd.temp *= element_length
-    mem_cost.total.activation *= element_length
-    meta_info.memory_cost = mem_cost
-
-    # get computation cost for MetaInfo
-    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
-    meta_info.compute_cost = compute_cost
-
-    # get tensor shapes for MetaInfo
-    input_shape = compute_shape(origin_sharding_spec)
-    output_shape = compute_shape(target_sharding_spec)
-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
-    meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
-
-    return meta_info
-
-
 def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
     """
     This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
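The removed `construct_meta_info` packs three things into `MetaInfo`: sharded (local) tensor shapes, memory costs scaled from element counts to bytes, and a `TrainCycleItem` of compute costs. Below is a hedged, standalone sketch of the first two ideas; `compute_local_shape` is an illustrative name rather than the ColossalAI API, and unlike the helper above it also keeps unsharded dimensions:

```python
import torch


def compute_local_shape(entire_shape, dim_partition_dict):
    # `dim_partition_dict` mirrors the ShardingSpec field used above:
    # it maps a tensor dim to the list of mesh axes sharding that dim.
    new_shape = list(entire_shape)
    for dim, mesh_axes in dim_partition_dict.items():
        # the helper in the diff divides by the number of mesh axes on `dim`
        new_shape[dim] //= len(mesh_axes)
    return new_shape


local_shape = compute_local_shape((64, 128), {0: [0, 1]})  # -> [32, 128]
t = torch.rand(local_shape, device='meta')                 # shape-only meta tensor
activation_bytes = t.numel() * t.element_size()            # fp32: 4 bytes per element
print(local_shape, activation_bytes)                       # [32, 128] 16384
```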
@@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                                    runtime_apply,
                                                                    args=(node, origin_dict_node, input_dict_node,
                                                                          node_to_index_dict[node], user_node_index))
-                    meta_info = construct_meta_info(node, user_node)
-                    setattr(shape_consistency_node, 'best_metainfo', meta_info)
 
                 new_args = list(user_node.args)
                 new_kwargs = dict(user_node.kwargs)
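On the consuming side, the checkpoint solver can read the hooked metadata back off each annotated node. A hedged sketch of that pattern, assuming nodes carry a `best_metainfo` whose `memory_cost.fwd.activation` field is laid out as in the diff above; the function name and aggregation are illustrative, not the ColossalAI solver:

```python
def total_fwd_activation_bytes(gm):
    # Sum forward activation memory across the nodes that were annotated
    # with `best_metainfo`; unannotated nodes contribute nothing.
    total = 0
    for node in gm.graph.nodes:
        meta = getattr(node, 'best_metainfo', None)
        if meta is not None:
            total += meta.memory_cost.fwd.activation
    return total
```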