diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py index f98fcd686..abc1a089e 100644 --- a/colossalai/fx/passes/passes_for_gpt2_test.py +++ b/colossalai/fx/passes/passes_for_gpt2_test.py @@ -1,15 +1,16 @@ -import torch -from torch.fx.graph_module import GraphModule -from typing import Callable, List, Dict, Any, Optional -from torch.fx._compatibility import compatibility -from packaging import version -from colossalai.fx.passes.meta_info_prop import TensorMetadata import inspect -from typing import List -from colossalai.fx.passes.split_module import Partition -from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass +from typing import Any, Callable, Dict, List, Optional + +import torch +from packaging import version +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule from torch.fx.node import Node +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split +from colossalai.fx.passes.meta_info_prop import TensorMetadata +from colossalai.fx.passes.split_module import Partition + def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]): '''