[autoparallel] Attach input, buffer and output tensor to MetaInfo class (#2162)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator

* [autoparallel] add binary elementwise metainfo

* [fx] recover profiler

* [autoparallel] fix forward memory calculation

* [autoparallel] modify constants.py

* [autoparallel] remove redundant print

* [autoparallel] add F.conv metainfo

* [autoparallel] linear fix

* [autoparallel] memory estimation for communication actions

* [autoparallel] fix docstring

* [autoparallel] fix variables name

* [autoparallel] attach tensor to metainfo class

* [autoparallel] fix dangerous try except

* [autoparallel] attach memory cost to shape consistency node

* [autoparallel] attach shape consistency node's metainfo to the node

* [autoparallel] remove todo in shape consistency memory estimation

* [autoparallel] fix the annotation

Author: Boyuan Yao
Date: 2022-12-28 13:37:40 +08:00 (committed by GitHub)
Parent: d0bc5a1b34
Commit: 24246f7aa5
11 changed files with 118 additions and 44 deletions


```diff
@@ -407,9 +407,6 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
     def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
         """memory cost of the communication action sequence
-        TODO: Currently we just consider tensor numel in the shape consistency manger,
-        as the manager itself doesn't have the access to tensor dtype, we need to take
-        it into consideration in memory estimation.

         Args:
             comm_action_sequence (List[CommSpec]): list of communication actions
```
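
The removed TODO concerned the fact that costs in this manager are tracked as element counts (numel) rather than bytes, since the manager does not see tensor dtypes. For reference, converting an element count to bytes only needs the dtype's element size:

```python
import torch

# numel-to-bytes conversion requires the dtype, which this manager does not have
t = torch.empty(1024, 1024, dtype=torch.float16)
print(t.numel() * t.element_size())  # 2097152 bytes (2 MiB) for fp16
```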
```diff
@@ -420,9 +417,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         def compute_shape(sharding_spec: ShardingSpec):
             shape = sharding_spec.entire_shape
+            new_shape = []
             for dim, shard in sharding_spec.dim_partition_dict.items():
-                shape[dim] = shape[dim] // len(shard)
-            return shape
+                new_shape.append(shape[dim] // len(shard))
+            return new_shape

         def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_gather memory footprint
```
```diff
@@ -461,7 +459,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # generate a new tensor
             input_shape = compute_shape(comm_spec.sharding_spec)
             input_numel = np.prod(input_shape)
-            output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axes]
+            output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
             alloc_numel += output_numel
             peak_numel = max(peak_numel, alloc_numel)
             if discard_input:
```
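
Besides the `logical_process_axes` to `logical_process_axis` attribute-name fix, this hunk shows the alloc/peak bookkeeping pattern the analysis helpers share. Distilled into a standalone sketch (the helper name is ours, not from the source):

```python
def track_step(output_numel: int, input_numel: int, discard_input: bool,
               alloc_numel: int, peak_numel: int) -> tuple[int, int]:
    """Distilled alloc/peak bookkeeping pattern from the per-op analysis helpers."""
    alloc_numel += output_numel                 # the op materializes its output
    peak_numel = max(peak_numel, alloc_numel)   # sample the peak after every allocation
    if discard_input:
        alloc_numel -= input_numel              # input is freed once the op completes
    return alloc_numel, peak_numel


alloc, peak = track_step(output_numel=512, input_numel=1024,
                         discard_input=True, alloc_numel=1024, peak_numel=1024)
print(alloc, peak)  # 512 1536
```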
```diff
@@ -538,8 +536,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         # analyze memory footprint of forward comm actions sequence
         fwd_alloc_numel = 0
         fwd_peak_numel = 0
-        for idx, fwd_action, comm_spec in enumerate(zip(fwd_actions, comm_action_sequence)):
+        for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
+            fwd_action, comm_spec = action_spec_pair
             if idx == 0:
                 fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
             else:
```
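
The loop fixed here is a classic unpacking bug: `enumerate(zip(a, b))` yields `(index, (a_i, b_i))` pairs, so three flat loop variables raise `ValueError` on the first iteration. A minimal illustration:

```python
actions = ["all_gather", "split"]
specs = ["spec_0", "spec_1"]

# enumerate(zip(...)) yields two items per step: the index and an inner tuple
for idx, action_spec_pair in enumerate(zip(actions, specs)):
    action, spec = action_spec_pair  # the fix: unpack the inner tuple explicitly
    print(idx, action, spec)

# nested unpacking in the loop header is an equivalent idiom
for idx, (action, spec) in enumerate(zip(actions, specs)):
    print(idx, action, spec)
```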
```diff
@@ -548,7 +547,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
-        for idx, bwd_action, comm_spec in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
+        for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
+            bwd_action, comm_spec = action_spec_pair
             bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)

         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
```
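
The closing line splits the measured footprint into what is still allocated when the sequence ends (activation) and the transient overshoot above it (temp = peak minus allocated). A worked example with hypothetical counters:

```python
# hypothetical counters left by a forward comm action sequence
fwd_alloc_numel = 1_000   # elements still allocated when the sequence ends
fwd_peak_numel = 1_500    # highest simultaneous element count observed

activation = fwd_alloc_numel                # 1000 elements persist as activation memory
temp = fwd_peak_numel - fwd_alloc_numel     # 500 elements existed only transiently
print(activation, temp)                     # 1000 500
```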