[autoparallel] update CommSpec to CommActions (#1768)

* [autoparallel] update CommSpec to CommActions

* polish code
This commit is contained in:
YuliangLiu0306
2022-10-28 09:57:43 +08:00
committed by GitHub
parent 16b0abf94f
commit b0f7c8bde8
7 changed files with 267 additions and 122 deletions

View File

@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
@@ -109,6 +110,7 @@ def test_linear_module_handler(bias):
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('bias', [True, False])
def test_linear_function_handler(bias):
model = nn.Linear(16, 32, bias=bias).to('meta')