mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[autoparallel] update CommSpec to CommActions (#1768)
* [autoparallel] update CommSpec to CommActions * polish code
This commit is contained in:
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize
|
||||
|
||||
|
||||
@@ -109,6 +110,7 @@ def test_linear_module_handler(bias):
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@parameterize('bias', [True, False])
|
||||
def test_linear_function_handler(bias):
|
||||
model = nn.Linear(16, 32, bias=bias).to('meta')
|
||||
|
Reference in New Issue
Block a user