Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 17:46:42 +00:00
[autoparallel] Hook all meta information on ResNet nodes for auto activation checkpoint (#2248)
* [autoparallel] hook node meta on graph nodes for checkpoint solver
* [autoparallel] polish code
* [autoparallel] restore some node handlers
* colossalai/auto_parallel/passes/meta_info_prop.py
* [autoparallel] remove some unused import
* [autoparallel] hook bwd_mem_out
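The idea behind the change, hooking a per-node meta record onto torch.fx graph nodes so the activation-checkpoint solver can read costs straight off the graph, can be illustrated with a small sketch. Everything below is an assumption-labelled illustration: NodeMetaInfo, its fields, and attach_meta are hypothetical names, not ColossalAI's actual MetaInfo API; only the setattr(node, 'best_metainfo', ...) pattern comes from the diff itself.

# Hypothetical sketch of hooking per-node meta information onto an FX graph.
# NodeMetaInfo / attach_meta are illustrative names, not ColossalAI's MetaInfo API.
from dataclasses import dataclass, field
from typing import List

import torch
import torch.fx


@dataclass
class NodeMetaInfo:
    fwd_compute: float = 0.0          # forward compute cost (arbitrary units)
    bwd_compute: float = 0.0          # backward compute cost
    fwd_mem_out: int = 0              # bytes of activation produced in forward
    bwd_mem_out: int = 0              # bytes produced in backward
    fwd_out: List[torch.Tensor] = field(default_factory=list)


def attach_meta(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Stash a meta record on every node so a downstream solver can read it."""
    for node in gm.graph.nodes:
        meta = NodeMetaInfo()
        tm = node.meta.get('tensor_meta')  # populated by torch.fx ShapeProp, if it was run
        if tm is not None and hasattr(tm, 'shape'):
            numel = 1
            for s in tm.shape:
                numel *= s
            # activation bytes = numel * element size of the output dtype
            meta.fwd_mem_out = numel * torch.tensor([], dtype=tm.dtype).element_size()
        # same mechanism the pass uses for shape-consistency nodes
        setattr(node, 'best_metainfo', meta)
    return gm

A checkpoint solver would then walk gm.graph.nodes and read node.best_metainfo when weighing which activations to recompute.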
@@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
     return rst


-def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
-    """
-    This method is used to construct `MetaInfo` for shape consistency node
-    TODO: Actually we could attain the cost information from resharding cost in node
-    handler, we should modify this part in the future.
-    """
-
-    def compute_shape(sharding_spec: ShardingSpec):
-        shape = sharding_spec.entire_shape
-        new_shape = []
-        for dim, shard in sharding_spec.dim_partition_dict.items():
-            new_shape.append(shape[dim] // len(shard))
-        return new_shape
-
-    meta_info = MetaInfo()
-    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
-        str(node.name))
-    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        origin_sharding_spec, target_sharding_spec)
-
-    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
-    # get mem cost for MetaInfo
-    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
-    element_length = node._meta_data.element_size()
-    mem_cost.fwd.activation *= element_length
-    mem_cost.fwd.temp *= element_length
-    mem_cost.bwd.activation *= element_length
-    mem_cost.bwd.temp *= element_length
-    mem_cost.total.activation *= element_length
-
-    meta_info.memory_cost = mem_cost
-
-    # get computation cost for MetaInfo
-    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
-    meta_info.compute_cost = compute_cost
-
-    # get tensor shape for MetaInfo
-    input_shape = compute_shape(origin_sharding_spec)
-    output_shape = compute_shape(target_sharding_spec)
-
-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
-    meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
-
-    return meta_info
-
-
 def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
     """
     This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
@@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                               runtime_apply,
                                                               args=(node, origin_dict_node, input_dict_node,
                                                                     node_to_index_dict[node], user_node_index))
-                meta_info = construct_meta_info(node, user_node)
-                setattr(shape_consistency_node, 'best_metainfo', meta_info)

                 new_args = list(user_node.args)
                 new_kwargs = dict(user_node.kwargs)
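For context on how the hooked record is consumed: once every node carries a best_metainfo attribute, an activation-checkpoint solver only has to walk the graph and read the costs back. A rough sketch of such a consumer follows; it assumes the field layout visible in the removed helper above (memory_cost.fwd.activation in bytes, compute_cost as a fwd/bwd/total triple) and is not the actual ColossalAI solver.

def collect_checkpoint_candidates(gm):
    """Illustrative only: gather per-node activation bytes and forward compute cost."""
    total_activation_bytes = 0
    candidates = []
    for node in gm.graph.nodes:
        info = getattr(node, 'best_metainfo', None)
        if info is None:
            continue
        act_bytes = info.memory_cost.fwd.activation   # already scaled by element_size() above
        total_activation_bytes += act_bytes
        candidates.append((node, act_bytes, info.compute_cost.fwd))
    # a solver would trade saved activation bytes against forward recompute time
    return total_activation_bytes, candidates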