mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 15:32:22 +00:00
[NFC] polish colossalai/fx/passes/split_module.py code style (#3263)
Co-authored-by: csric <richcsr256@gmail.com>
This commit is contained in:
parent
488f37048c
commit
00778abc48
@ -1,9 +1,10 @@
|
||||
import torch
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from typing import Callable, List, Dict, Any, Optional
|
||||
from torch.fx._compatibility import compatibility
|
||||
from packaging import version
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.graph_module import GraphModule
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@ -133,9 +134,7 @@ def split_module(
|
||||
if def_partition_name is not None:
|
||||
use_partition.partitions_dependent_on.setdefault(def_partition_name)
|
||||
|
||||
def record_output(
|
||||
def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
|
||||
): # noqa: B950
|
||||
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||
def_partition_name = getattr(def_node, "_fx_partition", None)
|
||||
use_partition_name = getattr(use_node, "_fx_partition", None)
|
||||
if def_partition_name != use_partition_name:
|
||||
|
Loading…
Reference in New Issue
Block a user