[autoparallel] Attach input, buffer and output tensor to MetaInfo class (#2162)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator

* [autoparallel] add binary elementwise metainfo

* [fx] recover profiler

* [autoparallel] fix forward memory calculation

* [autoparallel] modify constants.py

* [autoparallel] remove redundant print

* [autoparallel] add F.conv metainfo

* [autoparallel] linear fix

* [autoparallel] memory estimation for communication actions

* [autoparallel] fix docstring

* [autoparallel] fix variables name

* [autoparallel] attach tensor to metainfo class

* [autoparallel] fix dangerous try except

* [autoparallel] attach memory cost to shape consistency node

* [autoparallel] attach shape consistency node's metainfo to the node

* [autoparallel] remove todo in shape consistency memory estimation

* [autoparallel] fix the annotation
Author: Boyuan Yao
Date: 2022-12-28 13:37:40 +08:00
Committed by: GitHub
Parent: d0bc5a1b34
Commit: 24246f7aa5
11 changed files with 118 additions and 44 deletions


@@ -4,11 +4,13 @@ from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    OperationData,
    OperationDataType,
    TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import CommSpec
@@ -45,6 +47,52 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
    return rst

def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
    """
    This method is used to construct the `MetaInfo` for a shape consistency node.
    TODO: We could actually obtain the cost information from the resharding cost
    in the node handler; this part should be refactored in the future.
    """
    def compute_shape(sharding_spec: ShardingSpec):
        shape = sharding_spec.entire_shape
        # keep unsharded dimensions as-is and shrink each sharded dimension
        # by the number of mesh axes it is sharded over
        new_shape = list(shape)
        for dim, shard in sharding_spec.dim_partition_dict.items():
            new_shape[dim] = new_shape[dim] // len(shard)
        return new_shape
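    # Illustrative example (assumed values): entire_shape = [8, 16] with
    # dim_partition_dict = {0: [0, 1]} shards dim 0 over two mesh axes,
    # giving a local shape of [4, 16].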

    meta_info = MetaInfo()
    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.sharding_spec
    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        origin_sharding_spec, target_sharding_spec)

    # get mem cost for MetaInfo
    # NOTE: the costs in shape_consistency_manager.mem_cost are counted in numbers
    # of elements (numel), so they are converted to bytes below via element_size()
    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
    element_length = node._meta_data.element_size()
    mem_cost.fwd.activation *= element_length
    mem_cost.fwd.temp *= element_length
    mem_cost.bwd.activation *= element_length
    mem_cost.bwd.temp *= element_length
    mem_cost.total.activation *= element_length
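    # Illustrative scaling (assumed values): an fp32 tensor has element_size() == 4,
    # so an activation cost of 1,048,576 numel becomes 4 MiB in bytes.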
    meta_info.memory_cost = mem_cost

    # get computation cost for MetaInfo
    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
    meta_info.compute_cost = compute_cost
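    # (TrainCycleItem bundles the forward, backward and total values, mirroring
    # the 'forward' / 'backward' / 'total' keys of total_cost above.)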

    # get tensor shape for MetaInfo
    input_shape = compute_shape(origin_sharding_spec)
    output_shape = compute_shape(target_sharding_spec)
    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
    meta_info.fwd_buffer = []
    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
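    # tensors on the 'meta' device carry only shape and dtype metadata with no
    # storage allocated, so these placeholders are cheap to construct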
    return meta_info

def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
    """
    This method will be invoked at runtime to apply the comm action following the instruction of the comm spec.
@@ -126,6 +174,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                runtime_apply,
                args=(node, origin_dict_node, input_dict_node,
                      node_to_index_dict[node], user_node_index))
            meta_info = construct_meta_info(node, user_node)
            setattr(shape_consistency_node, 'best_metainfo', meta_info)

            new_args = list(user_node.args)
            new_kwargs = dict(user_node.kwargs)
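
For context, a minimal sketch of how the attached metainfo could be consumed downstream. `collect_comm_memory_cost` is a hypothetical helper, not part of this commit; it relies only on the `best_metainfo` attribute and the `memory_cost` field set above.

import torch

def collect_comm_memory_cost(gm: torch.fx.GraphModule) -> float:
    """Sum the forward activation memory (in bytes) attached to shape consistency nodes."""
    total_fwd_activation = 0
    for node in gm.graph.nodes:
        meta_info = getattr(node, 'best_metainfo', None)
        if meta_info is None:
            continue
        # memory_cost was converted from numel counts to bytes in construct_meta_info
        total_fwd_activation += meta_info.memory_cost.fwd.activation
    return total_fwd_activation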