Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
[autoparallel] refactored shape consistency to remove redundancy (#1591)
* [autoparallel] refactored shape consistency to remove redundancy
* polish code
* polish code
* polish code
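Every hunk below makes the same two-part change: the ShapeConsistencyManager import and instantiation are deleted from the test, and the shape_consistency_manager argument is dropped from the handler or StrategiesConstructor call. A plausible reading of "remove redundancy" (an assumption; the diffs only show the call sites) is that ShapeConsistencyManager now behaves as a singleton that components fetch internally, so passing it around adds nothing. A minimal sketch of that pattern, using a hypothetical stand-in class rather than the library's code:

# A minimal sketch (not ColossalAI's actual code) of a singleton metaclass:
# if ShapeConsistencyManager is constructed this way, every call site gets
# the same instance, making it redundant to thread one through constructors.
class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Create the instance once, then always return the cached one.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class ShapeConsistencyManagerSketch(metaclass=SingletonMeta):
    """Hypothetical stand-in for colossalai.tensor.shape_consistency.ShapeConsistencyManager."""


assert ShapeConsistencyManagerSketch() is ShapeConsistencyManagerSketch()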
@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 
 
@@ -31,7 +30,6 @@ def test_bn_handler():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()
 
     tracer = ColoTracer()
     model = BNModel(16)
@@ -77,10 +75,11 @@ def test_bn_handler():
 
     # generate bn strategy
     strategies_vector = StrategiesVector(node=nodes[2])
-    bn_handler = BatchNormHandler(node=nodes[2],
-                                  device_mesh=device_mesh,
-                                  strategies_vector=strategies_vector,
-                                  shape_consistency_manager=shape_consistency_manager)
+    bn_handler = BatchNormHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
     bn_handler.register_strategy()
     # ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
     # 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']
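The commented strategy lists here and in the conv/dot diffs below use ColossalAI's sharding-spec shorthand: each letter position corresponds to one tensor dimension, with R meaning replicated, S0/S1 sharded along device-mesh dimension 0/1, and S01 sharded along both; a name like 'RS0 = RS0 x S0' reads output spec = input spec x weight spec. A hedged sketch of how one such name maps onto a ShardingSpec, reusing the 2 x 2 mesh and activation shape from the test above (the dim_partition_dict layout is an assumption about the constructor these tests import):

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# 4 physical devices arranged as the 2 x 2 logical mesh used in the tests.
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))
entire_shape = torch.Size((4, 16, 64, 64))

# 'S0S1' on this 4-D activation: tensor dim 0 is split over mesh dim 0 and
# tensor dim 1 over mesh dim 1; the remaining dims stay replicated, giving
# the sharding sequence [S0, S1, R, R].
spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict={0: [0], 1: [1]})
print(spec.sharding_sequence)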
@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 
 
@@ -31,7 +30,6 @@ def test_conv_handler():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()
 
     tracer = ColoTracer()
     model = ConvModel(16, 32)
@@ -77,10 +75,11 @@ def test_conv_handler():
 
     # generate conv strategy
     strategies_vector = StrategiesVector(node=nodes[2])
-    conv_handler = ConvHandler(node=nodes[2],
-                               device_mesh=device_mesh,
-                               strategies_vector=strategies_vector,
-                               shape_consistency_manager=shape_consistency_manager)
+    conv_handler = ConvHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
     conv_handler.register_strategy()
     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
     strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]
@@ -4,10 +4,7 @@ from torch.fx import GraphModule
 import torch.nn as nn
 import pytest
 
-from colossalai.fx.proxy import ColoProxy
 from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
 from colossalai.auto_parallel.solver.cost_graph import CostGraph
@@ -37,7 +34,6 @@ def test_cost_graph():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()
 
     tracer = ColoTracer()
     model = ConvModel(16, 32)
@@ -55,7 +51,7 @@ def test_cost_graph():
     gm.recompile()
 
     solver_options = SolverOptions(fast=True)
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
     # (x, mul):{(0, 0): 0}
@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 
 
@@ -31,7 +30,6 @@ def test_dot_handler():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 8))
-    shape_consistency_manager = ShapeConsistencyManager()
 
     tracer = ColoTracer()
     model = LinearModel(8, 16)
@@ -76,10 +74,11 @@ def test_dot_handler():
 
     # generate dot strategy
     strategies_vector = StrategiesVector(node=nodes[2])
-    dot_handler = DotHandler(node=nodes[2],
-                             device_mesh=device_mesh,
-                             strategies_vector=strategies_vector,
-                             shape_consistency_manager=shape_consistency_manager)
+    dot_handler = DotHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
     strategies_vector = dot_handler.register_strategy()
 
     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
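Why a name like 'S0S1 = S0R x RS1' is a communication-free matmul strategy follows from block matrix multiplication: a row shard of the left operand times a column shard of the right operand is exactly one independent block of the output. A self-contained check in plain NumPy (illustration only, not ColossalAI code):

import numpy as np

A = np.random.rand(4, 8)     # left operand, to be row-sharded (S0R)
B = np.random.rand(8, 16)    # right operand, to be column-sharded (RS1)
C = A @ B                    # unsharded reference result

A0, A1 = np.split(A, 2, axis=0)    # S0R: shard rows across mesh dim 0
B0, B1 = np.split(B, 2, axis=1)    # RS1: shard columns across mesh dim 1

# One device holding (A0, B0) computes its output block locally...
block = A0 @ B0
# ...and that block is precisely the (0, 0) tile of C, i.e. C is sharded S0S1.
assert np.allclose(block, C[:2, :8])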
@@ -8,7 +8,6 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
 from colossalai.auto_parallel.solver.options import SolverOptions
@@ -34,7 +33,6 @@ def test_strategies_constructor():
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     entire_shape = torch.Size((4, 16, 64, 64))
-    shape_consistency_manager = ShapeConsistencyManager()
 
     tracer = ColoTracer()
     model = ConvModel(16, 32)
@@ -49,7 +47,7 @@ def test_strategies_constructor():
     gm.recompile()
 
     solver_options = SolverOptions(fast=True)
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
 
     assert strategies_constructor.leaf_strategies == []
     assert strategies_constructor.strategy_map == {}
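Taken together, the hunks converge on one post-refactor setup. A hedged end-to-end sketch assembled from the test code visible above; ConvModel and the meta-tensor tracing call are assumptions mirroring the helper code these tests import, while the three-argument StrategiesConstructor signature comes straight from the diff:

import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.options import SolverOptions


class ConvModel(nn.Module):
    # Assumption: a stand-in mirroring the ConvModel helper used by the tests.
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

tracer = ColoTracer()
model = ConvModel(16, 32)
# Trace on meta tensors so no real computation or device memory is needed.
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

solver_options = SolverOptions(fast=True)
# The three-argument signature after this commit: no explicit
# shape_consistency_manager is passed in.
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()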