Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-25 19:55:03 +00:00)
[autoparallel] Attach input, buffer and output tensor to MetaInfo class (#2162)
* [fx] metainfo class for auto parallel
* [fx] add unit test for linear metainfo
* [fx] fix bwd param for linear
* [fx] modify unit test
* [fx] modify unit test
* [fx] modify import
* [fx] modify import
* [fx] modify import
* [fx] move meta profiler to auto parallel
* [fx] add conv metainfo class
* [fx] restore profiler
* [fx] restore meta profiler
* [autoparallel] modify unit test
* [fx] modify unit test
* [autoparallel] add batchnorm metainfo class
* [autoparallel] fix batchnorm unit test function declaration
* [fx] restore profiler
* [fx] add relu metainfo class
* [fx] restore profiler
* [autoparallel] modify metainfo input
* [autoparallel] add pooling metainfo
* [autoparallel] add F.linear metainfo generator
* [autoparallel] add binary elementwise metainfo
* [fx] recover profiler
* [autoparallel] fix forward memory calculation
* [autoparallel] modify constants.py
* [autoparallel] remove redundant print
* [autoparallel] add F.conv metainfo
* [autoparallel] linear fix
* [autoparallel] memory estimation for communication actions
* [autoparallel] fix docstring
* [autoparallel] fix variables name
* [autoparallel] attach tensor to metainfo class
* [autoparallel] fix dangerous try except
* [autoparallel] attach memory cost to shape consistency node
* [autoparallel] attach shape consistency node's metainfo to the node
* [autoparallel] remove todo in shape consistency memory estimation
* [autoparallel] fix the annotation
This commit is contained in:
@@ -59,10 +59,12 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
|
||||
|
||||
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
|
||||
# store_fwd_in
|
||||
fwd_in = [input_tensor]
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = []
|
||||
fwd_buffer = []
|
||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, mem_cost, fwd_in
|
||||
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.MaxPool1d)
|
||||
@@ -122,7 +124,9 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
|
||||
|
||||
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
|
||||
# store_fwd_in
|
||||
fwd_in = [input_tensor]
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
|
||||
fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
|
||||
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
|
||||
|
||||
return compute_cost, mem_cost, fwd_in
|
||||
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
|
||||
|
Reference in New Issue
Block a user