[NFC] polish colossalai/fx/passes/split_module.py code style (#3263)

Co-authored-by: csric <richcsr256@gmail.com>
This commit is contained in:
CsRic 2023-03-27 22:03:29 +08:00 committed by binmakeswell
parent 488f37048c
commit 00778abc48

View File

@ -1,9 +1,10 @@
import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional
import torch
from packaging import version
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
@ -133,9 +134,7 @@ def split_module(
if def_partition_name is not None: if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name) use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output( def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None) def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name: if def_partition_name != use_partition_name: