1
0
mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-09-11 13:59:08 +00:00

[hotfix] fix wrong type name in profiler ()

This commit is contained in:
Boyuan Yao
2022-10-05 21:59:05 +08:00
committed by GitHub
parent 132b4306b7
commit d8420f81a4

@@ -1,6 +1,7 @@
from functools import partial from functools import partial
from typing import Callable, Any, Dict, Tuple from typing import Callable, Any, Dict, Tuple
import torch import torch
from torch.nn.parameter import Parameter
from torch.fx import Graph, Node from torch.fx import Graph, Node
from torch.fx.node import Argument, Target from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
@@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
param_size = 0 param_size = 0
def get_param_size(x): def get_param_size(x):
if isinstance(x, torch.nn.parameter): if isinstance(x, Parameter):
param_size += activation_size(x) param_size += activation_size(x)
tree_map(get_param_size, args) tree_map(get_param_size, args)