[autoparallel] support distributed dataloader option (#1906)

* [autoparallel] support distributed dataloader option

* update output handler to support ddp dataloader

* polish code
YuliangLiu0306 committed 2022-11-17 20:11:53 +08:00 (committed by GitHub)
parent 6630d45546, commit 0da1d00399
18 changed files with 257 additions and 61 deletions
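
The hunks below show the user-facing side of the change: SolverOptions is now constructed with its defaults instead of SolverOptions(fast=True), and the placeholder and output node handlers take a new placeholder_option / output_option argument set to either 'distributed' or 'replicated'. Below is a minimal sketch of the option at the handler level, mirroring the updated test_placeholder_handler further down; the 4-device mesh, its (2, 2) shape, and the meta-arg name 'input' are illustrative assumptions, not values read from these hunks.

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoTracer


class PlaceholderModel(nn.Module):

    def forward(self, input):
        return input


# Trace the toy model on meta tensors, exactly as the tests below do.
model = PlaceholderModel()
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')})

# Assumed 4-device mesh arranged as 2 x 2 (the test asserts the sharding
# sequence [S01, R, R, R], which needs a two-axis mesh).
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

placeholder_node = list(graph.nodes)[0]
strategies_vector = StrategiesVector(placeholder_node)

# placeholder_option='distributed' keeps the dataloader output sharded along the
# first (batch) dimension across both mesh axes; 'replicated' restores the old
# fully replicated placeholder.
placeholder_handler = PlacehodlerHandler(node=placeholder_node,
                                         device_mesh=device_mesh,
                                         strategies_vector=strategies_vector,
                                         placeholder_option='distributed')
placeholder_handler.register_strategy(compute_resharding_cost=False)
print([s.name for s in placeholder_handler.strategies_vector])  # expect "Distributed Placeholder"

The OuputHandler is driven the same way through output_option; the updated output test expects the strategy name "Distributed Output" instead of "Replica Output".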

View File

@@ -84,7 +84,7 @@ def check_linear_module(rank, world_size, port):
     gm.recompile()
     node_list = list(graph.nodes)
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     linear_node = node_list[3]
@@ -138,7 +138,7 @@ def check_conv_module(rank, world_size, port):
     node_list = list(graph.nodes)
     conv_node = node_list[3]
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()

View File

@@ -36,7 +36,7 @@ def mem_test_for_node_strategy(rank: int,
         input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
     graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
     gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     target_node = list(graph.nodes)[node_index]

View File

@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn

-from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
-    OuputHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use


 class OutputModel(nn.Module):
@@ -18,7 +18,9 @@ class OutputModel(nn.Module):
         return x, y


-def test_output_handler():
+@parameterize('output_option', ['distributed', 'replicated'])
+@rerun_if_address_is_in_use()
+def test_output_handler(output_option):
     model = OutputModel()
     tracer = ColoTracer()
     # graph():
@@ -37,7 +39,10 @@ def test_output_handler():
     output_strategies_vector = StrategiesVector(output_node)
     # build handler
-    otuput_handler = OuputHandler(node=output_node, device_mesh=device_mesh, strategies_vector=output_strategies_vector)
+    otuput_handler = OuputHandler(node=output_node,
+                                  device_mesh=device_mesh,
+                                  strategies_vector=output_strategies_vector,
+                                  output_option=output_option)
     otuput_handler.register_strategy(compute_resharding_cost=False)

     # check operation data mapping
@@ -49,10 +54,12 @@ def test_output_handler():
         assert op_data.data is not None

     assert mapping['output'].name == "output"
     assert mapping['output'].data.is_meta
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
-    assert "Replica Output" in strategy_name_list
+    if output_option == 'distributed':
+        assert "Distributed Output" in strategy_name_list
+    else:
+        assert "Replica Output" in strategy_name_list


 if __name__ == '__main__':

View File

@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn

-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
-    PlacehodlerHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use


 class PlaceholderModel(nn.Module):
@@ -17,7 +17,9 @@ class PlaceholderModel(nn.Module):
         return input


-def test_placeholder_handler():
+@parameterize('placeholder_option', ['distributed', 'replicated'])
+@rerun_if_address_is_in_use()
+def test_placeholder_handler(placeholder_option):
     model = PlaceholderModel()
     tracer = ColoTracer()
     # graph():
@@ -33,16 +35,25 @@ def test_placeholder_handler():
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     placeholder_node = list(graph.nodes)[0]
     placeholder_strategies_vector = StrategiesVector(placeholder_node)
     # build handler
     placeholder_handler = PlacehodlerHandler(node=placeholder_node,
                                              device_mesh=device_mesh,
-                                             strategies_vector=placeholder_strategies_vector)
+                                             strategies_vector=placeholder_strategies_vector,
+                                             placeholder_option=placeholder_option)
     placeholder_handler.register_strategy(compute_resharding_cost=False)

     # check operation data mapping
     mapping = placeholder_handler.get_operation_data_mapping()
+    strategy = placeholder_strategies_vector[0]
+    strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name)
+    if placeholder_option == 'distributed':
+        assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]'
+    else:
+        assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]'
     for name, op_data in mapping.items():
         op_data: OperationData
         # make sure they have valid values
@@ -53,7 +64,10 @@ def test_placeholder_handler():
     assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64))
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategy_name_list = [val.name for val in placeholder_handler.strategies_vector]
-    assert "Replica Placeholder" in strategy_name_list
+    if placeholder_option == 'replicated':
+        assert "Replica Placeholder" in strategy_name_list
+    else:
+        assert "Distributed Placeholder" in strategy_name_list


 if __name__ == '__main__':

View File

@@ -79,7 +79,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
     graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
     gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     target_node = list(graph.nodes)[node_index]

View File

@@ -79,7 +79,7 @@ def test_linear_module():
     gm.recompile()
     node_list = list(graph.nodes)
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     linear_node = node_list[3]
@@ -117,7 +117,7 @@ def test_conv_module():
     gm.recompile()
     node_list = list(graph.nodes)
     conv_node = node_list[3]
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     _param_resharding_cost_assertion(conv_node)

View File

@@ -138,7 +138,7 @@ def check_apply_bottleneck(rank, world_size, port):
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
@@ -162,7 +162,7 @@ def check_apply_bottleneck(rank, world_size, port):
     output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
     assert output.shape == origin_output.shape
-    assert_close(output, origin_output)
+    assert_close(output, origin_output, rtol=1e-03, atol=1e-05)
     print("*******************backward starting*******************")
     cuda_rng_state = torch.cuda.get_rng_state()
     output.sum().backward()
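
The relaxed tolerances above account for the sharded run accumulating partial results in a different order than the single-device baseline, so bitwise equality is not expected. colossalai.testing.assert_close appears to forward to torch.testing.assert_close, so the check behaves like this small self-contained snippet (the injected drift is illustrative):

import torch
from torch.testing import assert_close  # colossalai.testing.assert_close appears to wrap this

torch.manual_seed(0)
baseline = torch.randn(4, 8)
# Simulate the small relative drift introduced by a different reduction order.
sharded = baseline * (1 + 1e-4 * torch.randn_like(baseline))

# Passes under the relaxed tolerances used in the test above; the much stricter
# default rtol would typically reject this amount of drift.
assert_close(sharded, baseline, rtol=1e-03, atol=1e-05)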

View File

@@ -60,7 +60,7 @@ def check_apply(rank, world_size, port):
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()

View File

@@ -3,8 +3,13 @@ from torch.fx import GraphModule
 from torchvision.models import resnet50

 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
-                                                          StrategiesConstructor)
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -53,7 +58,7 @@ def test_cost_graph():
     gm.recompile()
     graph_analyser = GraphAnalyser(gm)
     liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
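
For context, here is a minimal sketch of the pipeline this hunk exercises, using the default SolverOptions() that the tests now construct. It assumes colossalai (at this commit) and torchvision are installed; the meta input shape and the 2 x 2 mesh are illustrative choices, not values taken from the diff.

import torch
from torch.fx import GraphModule
from torchvision.models import resnet50

from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser, SolverOptions, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer

# Trace the model symbolically on meta tensors.
model = resnet50()
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args={'x': torch.rand(2, 3, 224, 224).to('meta')})
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# Assumed 4-device mesh arranged as 2 x 2.
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()

# Default options replace the old SolverOptions(fast=True) call removed by this commit.
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()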