[fx] add balanced policy v2 (#1251)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] add balanced policy v2

* add unittest
This commit is contained in:
YuliangLiu0306
2022-07-15 14:54:26 +08:00
committed by GitHub
parent ca2d3f284f
commit e8acf55e8b
3 changed files with 54 additions and 3 deletions

View File

@@ -67,7 +67,6 @@ class MetaInfoProp(torch.fx.Interpreter):
def run_node(self, n: Node) -> Any:
result = super().run_node(n)
found_tensor = False
def extract_tensor_meta(obj):
@@ -83,7 +82,25 @@ class MetaInfoProp(torch.fx.Interpreter):
n.meta['tensor_meta'] = meta
else:
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
# counting the total size of node outputs
total_node_size = 0
if isinstance(n.meta['tensor_meta'], TensorMetadata):
total_node_size += n.meta['tensor_meta'].numel
else:
for element in n.meta['tensor_meta']:
assert isinstance(
element, TensorMetadata
), f"``n.meta['tensor_meta']`` should be either TensorMetadata or a tuple of TensorMetadata."
total_node_size += element.numel
# counting the total size of parameters
total_param_size = 0
if n.op == 'call_module':
target_module = n.graph.owning_module.get_submodule(n.target)
for param in target_module.parameters():
total_param_size += param.numel()
total_node_size += total_param_size
n.node_size = total_node_size
n.meta['type'] = type(result)
return result