[autoparallel]integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
This commit is contained in:
YuliangLiu0306
2023-04-04 17:40:45 +08:00
committed by GitHub
parent 573af84184
commit ffcdbf0f65
46 changed files with 396 additions and 470 deletions

View File

@@ -1,10 +1,12 @@
import pytest
import torch
import torch.nn.functional as F
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -33,6 +35,7 @@ def recover_narrow(gm, narrow_node):
return gm
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_size_value_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)
@@ -40,14 +43,14 @@ def test_size_value_converting_pass():
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
meta_args = {'x': torch.rand(4, 8).to('meta')}
input = torch.rand(4, 8)
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)
x_node = list(graph.nodes)[0]
x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
setattr(x_node, 'sharding_spec', x_sharding_spec)
gm = ColoGraphModule(model, graph)
gm = insert_narrow(gm, x_node)
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
size = gm(input)
assert size == torch.Size([2, 8])