[autoparallel] refactored shape consistency to remove redundancy (#1591)

* [autoparallel] refactored shape consistency to remove redundancy

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2022-09-13 18:30:18 +08:00
committed by GitHub
parent d164449d00
commit 27fe8af60c
13 changed files with 220 additions and 234 deletions

View File

@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@@ -31,7 +30,6 @@ def test_bn_handler():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = BNModel(16)
@@ -77,10 +75,11 @@ def test_bn_handler():
# generate bn strategy
strategies_vector = StrategiesVector(node=nodes[2])
bn_handler = BatchNormHandler(node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
bn_handler = BatchNormHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
bn_handler.register_strategy()
# ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
# 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']

View File

@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@@ -31,7 +30,6 @@ def test_conv_handler():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
@@ -77,10 +75,11 @@ def test_conv_handler():
# generate conv strategy
strategies_vector = StrategiesVector(node=nodes[2])
conv_handler = ConvHandler(node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
conv_handler = ConvHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
conv_handler.register_strategy()
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]

View File

@@ -4,10 +4,7 @@ from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
@@ -37,7 +34,6 @@ def test_cost_graph():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
@@ -55,7 +51,7 @@ def test_cost_graph():
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
# (x, mul):{(0, 0): 0}

View File

@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
@@ -31,7 +30,6 @@ def test_dot_handler():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 8))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = LinearModel(8, 16)
@@ -76,10 +74,11 @@ def test_dot_handler():
# generate dot strategy
strategies_vector = StrategiesVector(node=nodes[2])
dot_handler = DotHandler(node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
dot_handler = DotHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
strategies_vector = dot_handler.register_strategy()
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']

View File

@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.options import SolverOptions
@@ -34,7 +33,6 @@ def test_strategies_constructor():
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
@@ -49,7 +47,7 @@ def test_strategies_constructor():
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
assert strategies_constructor.leaf_strategies == []
assert strategies_constructor.strategy_map == {}