Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[autoparallel] Attach input, buffer and output tensor to MetaInfo class (#2162)
* [fx] metainfo class for auto parallel
* [fx] add unit test for linear metainfo
* [fx] fix bwd param for linear
* [fx] modify unit test
* [fx] modify unit test
* [fx] modify import
* [fx] modify import
* [fx] modify import
* [fx] move meta profiler to auto parallel
* [fx] add conv metainfo class
* [fx] restore profiler
* [fx] restore meta profiler
* [autoparallel] modify unit test
* [fx] modify unit test
* [autoparallel] add batchnorm metainfo class
* [autoparallel] fix batchnorm unit test function declaration
* [fx] restore profiler
* [fx] add relu metainfo class
* [fx] restore profiler
* [autoparallel] modify metainfo input
* [autoparallel] add pooling metainfo
* [autoparallel] add F.linear metainfo generator
* [autoparallel] add binary elementwise metainfo
* [fx] recover profiler
* [autoparallel] fix forward memory calculation
* [autoparallel] modify constants.py
* [autoparallel] remove redundant print
* [autoparallel] add F.conv metainfo
* [autoparallel] linear fix
* [autoparallel] memory estimation for communication actions
* [autoparallel] fix docstring
* [autoparallel] fix variables name
* [autoparallel] attach tensor to metainfo class
* [autoparallel] fix dangerous try except
* [autoparallel] attach memory cost to shape consistency node
* [autoparallel] attach shape consistency node's metainfo to the node
* [autoparallel] remove todo in shape consistency memory estimation
* [autoparallel] fix the annotation
@@ -407,9 +407,6 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
     def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem:
         """memory cost of the communication action sequence
-        TODO: Currently we just consider tensor numel in the shape consistency manger,
-        as the manager itself doesn't have the access to tensor dtype, we need to take
-        it into consideration in memory estimation.
 
         Args:
             comm_action_sequence (List[CommSpec]): list of communication actions
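The removed TODO pointed out that these estimates count tensor elements (numel) rather than bytes, because the manager itself does not know the tensor dtype. Purely as an illustration (not ColossalAI code, helper name is made up), converting such an estimate to bytes would look like this:

    import torch

    def numel_to_bytes(numel: int, dtype: torch.dtype = torch.float32) -> int:
        # hypothetical helper: element_size() is the per-element size in bytes
        # (4 for float32, 2 for float16), so bytes = numel * element size
        return numel * torch.tensor([], dtype=dtype).element_size()

    assert numel_to_bytes(1024, torch.float16) == 2048
    assert numel_to_bytes(1024, torch.float32) == 4096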
@@ -420,9 +417,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
         def compute_shape(sharding_spec: ShardingSpec):
             shape = sharding_spec.entire_shape
+            new_shape = []
             for dim, shard in sharding_spec.dim_partition_dict.items():
-                shape[dim] = shape[dim] // len(shard)
-            return shape
+                new_shape.append(shape[dim] // len(shard))
+            return new_shape
 
         def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int):
             """analyze all_gather memory footprint
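A standalone sketch (plain Python, not ColossalAI code) of what this hunk changes: the old helper wrote the divided sizes back into `shape`, which mutates the spec's shape if `entire_shape` is a plain list (and raises a TypeError if it is an immutable torch.Size), while the new helper collects the per-dimension sharded sizes into a fresh list and leaves the original shape untouched:

    entire_shape = [8, 16]       # stand-in for sharding_spec.entire_shape
    dim_partition = {0: [0, 1]}  # dim 0 sharded across two mesh axes

    # old behaviour: writes through to the shared shape object
    shape = entire_shape
    for dim, shard in dim_partition.items():
        shape[dim] = shape[dim] // len(shard)
    print(entire_shape)          # [4, 16] -- the original shape changed

    # new behaviour: a fresh list of per-dim sharded sizes, original untouched
    entire_shape = [8, 16]
    new_shape = [entire_shape[dim] // len(shard) for dim, shard in dim_partition.items()]
    print(entire_shape, new_shape)   # [8, 16] [4]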
@@ -461,7 +459,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # generate a new tensor
             input_shape = compute_shape(comm_spec.sharding_spec)
             input_numel = np.prod(input_shape)
-            output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axes]
+            output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
             alloc_numel += output_numel
             peak_numel = max(peak_numel, alloc_numel)
             if discard_input:
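For context, the surrounding lines follow a simple allocator-style bookkeeping pattern: allocate the communication output, update the running peak, then optionally release the input. A minimal sketch of that pattern (the helper and the body of the discard branch are assumptions, not the ColossalAI implementation):

    def account(alloc_numel, peak_numel, input_numel, output_numel, discard_input):
        alloc_numel += output_numel                  # the comm op materializes its output
        peak_numel = max(peak_numel, alloc_numel)    # track the high-water mark
        if discard_input:
            alloc_numel -= input_numel               # assumed: input freed once consumed
        return alloc_numel, peak_numel

    # input of 4 elements already live, communication output of 16 elements
    print(account(alloc_numel=4, peak_numel=4, input_numel=4, output_numel=16, discard_input=True))
    # (16, 20): 16 elements remain allocated, but 20 were live at the peak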
@@ -538,8 +536,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         # analyze memory footprint of forward comm actions sequence
         fwd_alloc_numel = 0
         fwd_peak_numel = 0
-        for idx, fwd_action, comm_spec in enumerate(zip(fwd_actions, comm_action_sequence)):
+        for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
             # the first forward comm action will not discard input
+            fwd_action, comm_spec = action_spec_pair
             if idx == 0:
                 fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel)
             else:
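This hunk fixes an unpacking bug: `enumerate` over a `zip` yields two-element `(index, pair)` tuples, so the old three-target loop header raises a ValueError as soon as the sequences are non-empty. A minimal reproduction with made-up action names:

    fwd_actions = ['gather', 'split']          # hypothetical stand-ins
    comm_action_sequence = ['spec0', 'spec1']

    try:
        for idx, fwd_action, comm_spec in enumerate(zip(fwd_actions, comm_action_sequence)):
            pass
    except ValueError as err:
        print(err)   # not enough values to unpack (expected 3, got 2)

    # fixed form: unpack the pair inside the loop body
    for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)):
        fwd_action, comm_spec = action_spec_pair
        print(idx, fwd_action, comm_spec)

The same fix is applied to the backward loop in the next hunk; unpacking directly in the header as `for idx, (fwd_action, comm_spec) in enumerate(...)` would work equally well, the commit simply introduces an intermediate variable.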
@@ -548,7 +547,8 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         # analyze memory footprint for backward comm actions sequence
         bwd_alloc_numel = 0
         bwd_peak_numel = 0
-        for idx, bwd_action, comm_spec in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
+        for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))):
+            bwd_action, comm_spec = action_spec_pair
             bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel)
 
         fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel)
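One possible reading of the MemoryCost construction on the last context line: whatever is still allocated once the forward sequence finishes is reported as activation memory, and the gap between the observed peak and that residue as temporary memory. With made-up numbers:

    fwd_alloc_numel, fwd_peak_numel = 16, 20
    activation = fwd_alloc_numel                 # 16 elements stay live after the sequence
    temp = fwd_peak_numel - fwd_alloc_numel      # 4 elements were only transiently needed
    print(activation, temp)                      # 16 4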