mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 07:22:12 +00:00
[NFC] polish colossalai/fx/passes/split_module.py code style (#3263)
Co-authored-by: csric <richcsr256@gmail.com>
This commit is contained in:
parent
488f37048c
commit
00778abc48
@ -1,9 +1,10 @@
|
|||||||
import torch
|
|
||||||
from torch.fx.graph_module import GraphModule
|
|
||||||
from typing import Callable, List, Dict, Any, Optional
|
|
||||||
from torch.fx._compatibility import compatibility
|
|
||||||
from packaging import version
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
from torch.fx._compatibility import compatibility
|
||||||
|
from torch.fx.graph_module import GraphModule
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
@ -38,7 +39,7 @@ def split_module(
|
|||||||
m: GraphModule,
|
m: GraphModule,
|
||||||
root_m: torch.nn.Module,
|
root_m: torch.nn.Module,
|
||||||
split_callback: Callable[[torch.fx.node.Node], int],
|
split_callback: Callable[[torch.fx.node.Node], int],
|
||||||
merge_output = False,
|
merge_output=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
|
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
|
||||||
@ -132,10 +133,8 @@ def split_module(
|
|||||||
use_partition.inputs.setdefault(def_node.name)
|
use_partition.inputs.setdefault(def_node.name)
|
||||||
if def_partition_name is not None:
|
if def_partition_name is not None:
|
||||||
use_partition.partitions_dependent_on.setdefault(def_partition_name)
|
use_partition.partitions_dependent_on.setdefault(def_partition_name)
|
||||||
|
|
||||||
def record_output(
|
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
|
||||||
def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
|
|
||||||
): # noqa: B950
|
|
||||||
def_partition_name = getattr(def_node, "_fx_partition", None)
|
def_partition_name = getattr(def_node, "_fx_partition", None)
|
||||||
use_partition_name = getattr(use_node, "_fx_partition", None)
|
use_partition_name = getattr(use_node, "_fx_partition", None)
|
||||||
if def_partition_name != use_partition_name:
|
if def_partition_name != use_partition_name:
|
||||||
@ -291,7 +290,7 @@ def split_module(
|
|||||||
|
|
||||||
for partition_name in sorted_partitions:
|
for partition_name in sorted_partitions:
|
||||||
partition = partitions[partition_name]
|
partition = partitions[partition_name]
|
||||||
|
|
||||||
new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
|
new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
|
||||||
|
|
||||||
return new_gm
|
return new_gm
|
||||||
|
Loading…
Reference in New Issue
Block a user