mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[fx] add balanced policy v2 (#1251)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx] add balanced policy v2
* add unittest
This commit is contained in:
@@ -67,7 +67,6 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
def run_node(self, n: Node) -> Any:
|
||||
result = super().run_node(n)
|
||||
|
||||
found_tensor = False
|
||||
|
||||
def extract_tensor_meta(obj):
|
||||
@@ -83,7 +82,25 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
n.meta['tensor_meta'] = meta
|
||||
else:
|
||||
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
|
||||
# counting the total size of node outputs
|
||||
total_node_size = 0
|
||||
if isinstance(n.meta['tensor_meta'], TensorMetadata):
|
||||
total_node_size += n.meta['tensor_meta'].numel
|
||||
else:
|
||||
for element in n.meta['tensor_meta']:
|
||||
assert isinstance(
|
||||
element, TensorMetadata
|
||||
), f"``n.meta['tensor_meta']`` should be either TensorMetadata or a tuple of TensorMetadata."
|
||||
total_node_size += element.numel
|
||||
# counting the total size of parameters
|
||||
total_param_size = 0
|
||||
if n.op == 'call_module':
|
||||
target_module = n.graph.owning_module.get_submodule(n.target)
|
||||
for param in target_module.parameters():
|
||||
total_param_size += param.numel()
|
||||
|
||||
total_node_size += total_param_size
|
||||
n.node_size = total_node_size
|
||||
n.meta['type'] = type(result)
|
||||
return result
|
||||
|
||||
|
Reference in New Issue
Block a user