[autoparallel] handled illegal sharding strategy in shape consistency (#1744)

* [autoparallel] handled illegal sharding strategy in shape consistency

* polish code
Frank Lee
2022-10-20 12:06:25 +08:00
committed by GitHub
parent 88a79814fb
commit 993b8875b6
4 changed files with 109 additions and 89 deletions


@@ -1,15 +1,17 @@
-from ast import NodeTransformer
-import torch
-from typing import List
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+import builtins
+import operator
+from ast import NodeTransformer
+from copy import deepcopy
+from typing import List
+import torch
+from torch.fx import symbolic_trace
+from torch.fx.node import Node
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.passes.split_module import split_module
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 shape_consistency_manager = ShapeConsistencyManager()
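
For context, the commit title says illegal sharding strategies are now "handled" during shape consistency. The sketch below is a conceptual illustration of that pattern only, not the actual ColossalAI implementation: a strategy whose resharding conversion raises is treated as illegal and skipped instead of aborting the whole pass. All names (IllegalShardingSpecError, convert_spec, filter_legal_strategies) are hypothetical stand-ins.

# Conceptual sketch (hypothetical names, not ColossalAI's real API):
# skip sharding strategies whose shape-consistency conversion fails.

class IllegalShardingSpecError(Exception):
    """Raised when a resharding path cannot be built for a strategy."""


def convert_spec(src_spec: str, tgt_spec: str) -> str:
    # Stand-in for a shape-consistency conversion; here we pretend that
    # converting to any spec containing "S01" is unsupported.
    if "S01" in tgt_spec:
        raise IllegalShardingSpecError(f"cannot reshard {src_spec} -> {tgt_spec}")
    return f"{src_spec} -> {tgt_spec}"


def filter_legal_strategies(src_spec: str, candidate_specs: list) -> list:
    legal = []
    for tgt in candidate_specs:
        try:
            legal.append(convert_spec(src_spec, tgt))
        except IllegalShardingSpecError:
            # Illegal strategy: drop it and continue with the remaining
            # candidates rather than failing strategy generation outright.
            continue
    return legal


if __name__ == "__main__":
    print(filter_legal_strategies("R,R", ["S0,R", "S01,R", "R,S1"]))
    # prints: ['R,R -> S0,R', 'R,R -> R,S1']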