[hotfix] fix wrong type name in profiler (#1678)

This commit is contained in:
Boyuan Yao 2022-10-05 21:59:05 +08:00 committed by GitHub
parent 132b4306b7
commit d8420f81a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
from functools import partial from functools import partial
from typing import Callable, Any, Dict, Tuple from typing import Callable, Any, Dict, Tuple
import torch import torch
from torch.nn.parameter import Parameter
from torch.fx import Graph, Node from torch.fx import Graph, Node
from torch.fx.node import Argument, Target from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
param_size = 0 param_size = 0
def get_param_size(x): def get_param_size(x):
if isinstance(x, torch.nn.parameter): if isinstance(x, Parameter):
param_size += activation_size(x) param_size += activation_size(x)
tree_map(get_param_size, args) tree_map(get_param_size, args)