mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator * polish code * polish code * polish code * polish code * polish code * polish code
This commit is contained in:
@@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
@@ -26,6 +27,7 @@ def insert_narrow(gm, x_node):
|
||||
return gm
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def test_node_args_converting_pass():
|
||||
model = TestModule()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
@@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node):
|
||||
|
||||
|
||||
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
|
||||
@clear_cache_before_run()
|
||||
def test_size_value_converting_pass():
|
||||
model = TestModule()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
Reference in New Issue
Block a user