Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 13:11:05 +00:00)
[autoparallel] support distributed dataloader option (#1906)

* [autoparallel] support distributed dataloader option
* update output handler to support ddp dataloader
* polish code
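Taken together, the diff adds an `output_option` to the output handler and a `placeholder_option` to the placeholder handler so the auto-parallel pass can model a DDP-style distributed dataloader. Below is a minimal sketch of how the new arguments are exercised by the updated tests; `placeholder_node`, `output_node`, and `device_mesh` are assumed to come from the usual ColoTracer / DeviceMesh setup shown further down and are not defined here.

    from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler
    from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
    from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector

    # assumption: placeholder_node, output_node and device_mesh were built by the
    # tracer / DeviceMesh setup used in the tests below
    placeholder_handler = PlacehodlerHandler(node=placeholder_node,
                                             device_mesh=device_mesh,
                                             strategies_vector=StrategiesVector(placeholder_node),
                                             placeholder_option='distributed')  # or 'replicated'
    placeholder_handler.register_strategy(compute_resharding_cost=False)

    output_handler = OuputHandler(node=output_node,
                                  device_mesh=device_mesh,
                                  strategies_vector=StrategiesVector(output_node),
                                  output_option='distributed')  # or 'replicated'
    output_handler.register_strategy(compute_resharding_cost=False)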
@@ -84,7 +84,7 @@ def check_linear_module(rank, world_size, port):
     gm.recompile()
     node_list = list(graph.nodes)
 
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     linear_node = node_list[3]
@@ -138,7 +138,7 @@ def check_conv_module(rank, world_size, port):
 
     node_list = list(graph.nodes)
     conv_node = node_list[3]
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
@@ -36,7 +36,7 @@ def mem_test_for_node_strategy(rank: int,
         input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
     graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
     gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     target_node = list(graph.nodes)[node_index]
@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
-    OuputHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
 
 
 class OutputModel(nn.Module):
@@ -18,7 +18,9 @@ class OutputModel(nn.Module):
         return x, y
 
 
-def test_output_handler():
+@parameterize('output_option', ['distributed', 'replicated'])
+@rerun_if_address_is_in_use()
+def test_output_handler(output_option):
     model = OutputModel()
     tracer = ColoTracer()
     # graph():
@@ -37,7 +39,10 @@ def test_output_handler():
     output_strategies_vector = StrategiesVector(output_node)
 
     # build handler
-    otuput_handler = OuputHandler(node=output_node, device_mesh=device_mesh, strategies_vector=output_strategies_vector)
+    otuput_handler = OuputHandler(node=output_node,
+                                  device_mesh=device_mesh,
+                                  strategies_vector=output_strategies_vector,
+                                  output_option=output_option)
 
     otuput_handler.register_strategy(compute_resharding_cost=False)
     # check operation data mapping
@@ -49,10 +54,12 @@ def test_output_handler():
         assert op_data.data is not None
 
     assert mapping['output'].name == "output"
     assert mapping['output'].data.is_meta
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
-    assert "Replica Output" in strategy_name_list
+    if output_option == 'distributed':
+        assert "Distributed Output" in strategy_name_list
+    else:
+        assert "Replica Output" in strategy_name_list
 
 
 if __name__ == '__main__':
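For reference, `colossalai.testing.parameterize` appears to run the wrapped function once per listed value, injecting it as a keyword argument, which is how both options get covered by a single test; a small illustrative sketch, with a print standing in for the real assertions and `run_case` being a made-up name:

    from colossalai.testing import parameterize

    @parameterize('output_option', ['distributed', 'replicated'])
    def run_case(output_option):
        # stand-in for the real strategy checks
        print(f"building output strategies with output_option={output_option}")

    run_case()  # executes the body once for each listed option

The added `@rerun_if_address_is_in_use()` decorator simply retries the distributed test if the chosen port is still occupied.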
@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
-    PlacehodlerHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
 
 
 class PlaceholderModel(nn.Module):
@@ -17,7 +17,9 @@ class PlaceholderModel(nn.Module):
         return input
 
 
-def test_placeholder_handler():
+@parameterize('placeholder_option', ['distributed', 'replicated'])
+@rerun_if_address_is_in_use()
+def test_placeholder_handler(placeholder_option):
     model = PlaceholderModel()
     tracer = ColoTracer()
     # graph():
@@ -33,16 +35,25 @@ def test_placeholder_handler():
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     placeholder_node = list(graph.nodes)[0]
     placeholder_strategies_vector = StrategiesVector(placeholder_node)
 
     # build handler
     placeholder_handler = PlacehodlerHandler(node=placeholder_node,
                                              device_mesh=device_mesh,
-                                             strategies_vector=placeholder_strategies_vector)
+                                             strategies_vector=placeholder_strategies_vector,
+                                             placeholder_option=placeholder_option)
 
     placeholder_handler.register_strategy(compute_resharding_cost=False)
 
     # check operation data mapping
     mapping = placeholder_handler.get_operation_data_mapping()
 
+    strategy = placeholder_strategies_vector[0]
+    strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name)
+
+    if placeholder_option == 'distributed':
+        assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]'
+    else:
+        assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]'
+
     for name, op_data in mapping.items():
         op_data: OperationData
         # make sure they have valid values
@@ -53,7 +64,10 @@ def test_placeholder_handler():
     assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64))
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategy_name_list = [val.name for val in placeholder_handler.strategies_vector]
-    assert "Replica Placeholder" in strategy_name_list
+    if placeholder_option == 'replicated':
+        assert "Replica Placeholder" in strategy_name_list
+    else:
+        assert "Distributed Placeholder" in strategy_name_list
 
 
 if __name__ == '__main__':
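The sharding-sequence assertions above read against the device mesh built in the test; a hedged illustration of what the two expected specs mean, assuming the usual 4-device, 2 x 2 logical mesh these auto-parallel tests use (the mesh values below are that assumption, not shown in this diff) and the standard ColossalAI sharding-spec notation where `S01` means sharded over both mesh axes and `R` means replicated:

    import torch
    from colossalai.device.device_mesh import DeviceMesh

    # assumed mesh: 4 devices arranged as a 2 x 2 logical mesh
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    # For the (4, 4, 64, 64) placeholder tensor in the test:
    #   placeholder_option='distributed' -> '[S01, R, R, R]': the batch dimension is
    #       split across both mesh axes (4 shards), the other dimensions stay
    #       replicated -- the layout a DDP-style distributed dataloader would feed.
    #   placeholder_option='replicated'  -> '[R, R, R, R]': every device holds the
    #       full input tensor.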
@@ -79,7 +79,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
     graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
     gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     target_node = list(graph.nodes)[node_index]
@@ -79,7 +79,7 @@ def test_linear_module():
     gm.recompile()
     node_list = list(graph.nodes)
 
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     linear_node = node_list[3]
@@ -117,7 +117,7 @@ def test_conv_module():
     gm.recompile()
     node_list = list(graph.nodes)
     conv_node = node_list[3]
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     _param_resharding_cost_assertion(conv_node)
@@ -138,7 +138,7 @@ def check_apply_bottleneck(rank, world_size, port):
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
@@ -162,7 +162,7 @@ def check_apply_bottleneck(rank, world_size, port):
     output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
 
     assert output.shape == origin_output.shape
-    assert_close(output, origin_output)
+    assert_close(output, origin_output, rtol=1e-03, atol=1e-05)
     print("*******************backward starting*******************")
     cuda_rng_state = torch.cuda.get_rng_state()
     output.sum().backward()
@@ -60,7 +60,7 @@ def check_apply(rank, world_size, port):
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
@@ -3,8 +3,13 @@ from torch.fx import GraphModule
 from torchvision.models import resnet50
 
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
-                                                          StrategiesConstructor)
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -53,7 +58,7 @@ def test_cost_graph():
     gm.recompile()
     graph_analyser = GraphAnalyser(gm)
     liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
 
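Every touched test repeats the same strategy-construction pipeline, now with default `SolverOptions()` instead of `SolverOptions(fast=True)`. A condensed sketch of that shared sequence, assuming a traced `graph`, its `GraphModule` `gm`, and a `device_mesh` already exist (`build_strategies` is a made-up helper name; only calls that appear in this diff are used, and the downstream CostGraph/Solver stages are merely hinted at in a comment):

    from colossalai.auto_parallel.tensor_shard.solver import (
        GraphAnalyser,
        SolverOptions,
        StrategiesConstructor,
    )

    def build_strategies(gm, graph, device_mesh):
        # liveness information consumed by the solver stage later on
        graph_analyser = GraphAnalyser(gm)
        liveness_list = graph_analyser.liveness_analysis()

        # the diff drops fast=True everywhere and relies on the default options
        solver_options = SolverOptions()
        strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
        strategies_constructor.build_strategies_and_cost()

        # the constructed strategies are what CostGraph / Solver consume downstream
        return strategies_constructor, liveness_list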