mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[autoparallel]integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level * unify the profiling method * polish * fix no codegen bug * fix pass bug * fix liveness test * polish
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
from colossalai._analyzer.fx.passes import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
@@ -33,6 +35,7 @@ def recover_narrow(gm, narrow_node):
|
||||
return gm
|
||||
|
||||
|
||||
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
|
||||
def test_size_value_converting_pass():
|
||||
model = TestModule()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
@@ -40,14 +43,14 @@ def test_size_value_converting_pass():
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
meta_args = {'x': torch.rand(4, 8).to('meta')}
|
||||
input = torch.rand(4, 8)
|
||||
tracer = ColoTracer()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
|
||||
x_node = list(graph.nodes)[0]
|
||||
x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
|
||||
setattr(x_node, 'sharding_spec', x_sharding_spec)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm = insert_narrow(gm, x_node)
|
||||
shape_prop_pass(gm, *meta_args.values())
|
||||
gm.recompile()
|
||||
size = gm(input)
|
||||
assert size == torch.Size([2, 8])
|
||||
|
Reference in New Issue
Block a user