diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 05e7615d8..20d9d7c38 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -19,5 +19,5 @@ __all__ = [ 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', - 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler' + 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetattrHandler' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py index 489e40daf..d2edfa83c 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py @@ -2,7 +2,9 @@ from typing import Dict, List import torch -from ..sharding_strategy import OperationData, OperationDataType +from colossalai.device.device_mesh import DeviceMesh + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector from .node_handler import NodeHandler from .strategy import OutputGenerator, StrategyGenerator @@ -14,26 +16,37 @@ class OuputHandler(NodeHandler): A OuputHandler which deals with the sharding strategies for Output Node. """ + def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, + output_option: str) -> None: + super().__init__(node, device_mesh, strategies_vector) + self.output_option = output_option + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node)) + generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node, self.output_option)) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - dummy_output = torch.empty(1,).to("meta") - physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=dummy_output) - - mapping = {"output": physical_output} + mapping = {} + output_meta_data = [] for index, input_node in enumerate(self.predecessor_node): - if not hasattr(input_node, "_meta_data"): - print(input_node.name) - physical_inputs = OperationData(name=str(input_node), - type=OperationDataType.ARG, - data=input_node._meta_data) + input_meta_data = input_node._meta_data + physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data) name_key = f'input_{index}' mapping[name_key] = physical_inputs + output_meta_data.append(input_meta_data) + assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.' 
+ if len(output_meta_data) == 1: + output_meta_data = output_meta_data[0] + else: + output_meta_data = tuple(output_meta_data) + + self.node._meta_data = output_meta_data + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping["output"] = physical_output return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py index 88a02428e..c72a5d3bf 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py @@ -1,6 +1,10 @@ from typing import Dict, List -from ..sharding_strategy import OperationData, OperationDataType +from torch.fx.node import Node + +from colossalai.device.device_mesh import DeviceMesh + +from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector from .node_handler import NodeHandler from .strategy import PlaceholderGenerator, StrategyGenerator @@ -12,10 +16,16 @@ class PlacehodlerHandler(NodeHandler): A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node. """ + def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, + placeholder_option: str) -> None: + super().__init__(node, device_mesh, strategies_vector) + self.placeholder_option = placeholder_option + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh)) + generators.append( + PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py index de9dfba67..b9512887c 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -1,6 +1,14 @@ -from typing import List +from typing import Dict, List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from torch.fx import Node + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh from .strategy_generator import OutputStrategyGenerator @@ -12,6 +20,11 @@ class OutputGenerator(OutputStrategyGenerator): OutputGenerator is a generic class to generate strategies for Output Node. """ + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + predecessor_nodes: List[Node], output_option: str): + super().__init__(operation_data_mapping, device_mesh, predecessor_nodes) + self.output_option = output_option + def validate(self) -> bool: return super().validate() @@ -32,7 +45,10 @@ class OutputGenerator(OutputStrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - def collate_strategies(self) -> List[ShardingStrategy]: + def replica_strategy(self) -> List[ShardingStrategy]: + """ + Generate replica strategy for output node. 
+ """ dim_partition_dict_mapping = { "output": {}, } @@ -48,5 +64,47 @@ class OutputGenerator(OutputStrategyGenerator): strategy = self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + return strategy - return [strategy] + def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]: + """ + Generate distributed strategy for output node. + """ + # TODO: need to take care of the case when the first element of output only need to be sharded. + output_op_data = self.op_data['output'] + if isinstance(output_op_data.data, tuple): + length = len(output_op_data.data) + dim_partition_dict_mapping = { + "output": [{ + 0: mesh_list + }] * length, + } + else: + dim_partition_dict_mapping = { + "output": { + 0: mesh_list + }, + } + for index, _ in enumerate(self.predecessor_nodes): + mapping_name = f"input_{index}" + dim_partition_dict_mapping[mapping_name] = {0: mesh_list} + + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Distributed Output' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + return strategy + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + mesh_list = [0, 1] + if self.output_option == 'replicated': + strategy_list.append(self.replica_strategy()) + elif self.output_option == 'distributed': + strategy_list.append(self.distributed_strategy(mesh_list)) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py index 9023ab0fb..779a7ced9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -1,6 +1,12 @@ -from typing import List +from typing import Dict, List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + MemoryCost, + OperationData, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.device.device_mesh import DeviceMesh from .strategy_generator import StrategyGenerator @@ -12,6 +18,11 @@ class PlaceholderGenerator(StrategyGenerator): PlaceholderGenerator is a generic class to generate strategies for placeholder node. """ + def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, + placeholder_option: str): + super().__init__(operation_data_mapping, device_mesh) + self.placeholder_option = placeholder_option + def validate(self) -> bool: return super().validate() @@ -37,7 +48,10 @@ class PlaceholderGenerator(StrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost - def collate_strategies(self) -> List[ShardingStrategy]: + def replica_placeholder(self) -> ShardingStrategy: + """ + Generate replica strategy for placeholder node. 
+ """ dim_partition_dict_mapping = { "output": {}, } @@ -50,4 +64,37 @@ class PlaceholderGenerator(StrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - return [strategy] + return strategy + + def distributed_placeholder(self, mesh_list) -> ShardingStrategy: + """ + Generate distributed strategy for placeholder node. + """ + dim_partition_dict_mapping = { + "output": { + 0: mesh_list + }, + } + communication_action_mapping = {} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + name = 'Distributed Placeholder' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + return strategy + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + if self.placeholder_option == 'distributed': + mesh_list = [0, 1] + distributed_strategy = self.distributed_placeholder(mesh_list) + strategy_list.append(distributed_strategy) + else: + assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported' + replicated_strategy = self.replica_placeholder() + strategy_list.append(replicated_strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index c0f7a33da..d67ef1f49 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -73,7 +73,10 @@ class StrategyGenerator(ABC): for op_data_name, dim_partition_dict in mapping.items(): if op_data_name in self.op_data: op_data = self.op_data[op_data_name] - if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor): + if isinstance(op_data.data, tuple): + for data in op_data.data: + assert isinstance( + data, torch.Tensor), 'We cannot create a ShardingSpec object from a non-tensor object.' 
sharding_spec = [] for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict): dim_size = len(logical_shape) @@ -82,6 +85,9 @@ class StrategyGenerator(ABC): entire_shape=logical_shape, dim_partition_dict=dim_partition_dict_element) else: + assert isinstance( + op_data.data, torch.Tensor + ), f'op_data.data should be a torch.Tensor or Tuple[torch.Tensor], but got {type(op_data.data)}' dim_size = len(op_data.logical_shape) dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) sharding_spec = ShardingSpec(device_mesh=self.device_mesh, diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index 415a1de9e..a70c87a13 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -43,8 +43,11 @@ class OperationData: def __post_init__(self): # if no logical shape is specified, use the data shape as the logical shape - if self.logical_shape is None and isinstance(self.data, torch.Tensor): - self.logical_shape = self.data.shape + if self.logical_shape is None: + if isinstance(self.data, torch.Tensor): + self.logical_shape = self.data.shape + elif isinstance(self.data, tuple): + self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data]) def __repr__(self) -> str: return f'OperationData(name={self.name}, type={self.type})' diff --git a/colossalai/auto_parallel/tensor_shard/solver/options.py b/colossalai/auto_parallel/tensor_shard/solver/options.py index 2d34f5c64..b52e55708 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/options.py +++ b/colossalai/auto_parallel/tensor_shard/solver/options.py @@ -1,11 +1,30 @@ from dataclasses import dataclass +from enum import Enum __all__ = ['SolverOptions'] +class SolverPerference(Enum): + """ + This enum class is to define the solver preference. + """ + STANDARD = 0 + DP = 1 + TP = 2 + + +class DataloaderOption(Enum): + """ + This enum class is to define the dataloader option. + """ + REPLICATED = 0 + DISTRIBUTED = 1 + + @dataclass class SolverOptions: """ SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. 
""" - fast: bool = False + solver_perference: SolverPerference = SolverPerference.STANDARD + dataloader_option: DataloaderOption = DataloaderOption.REPLICATED diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 48035e6b8..b934ef2ea 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -6,15 +6,16 @@ from typing import Dict, List import torch from torch.fx import Graph, Node -from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry -from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector -from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec +from colossalai.auto_parallel.tensor_shard.node_handler import ( + GetattrHandler, + OuputHandler, + PlacehodlerHandler, + operator_registry, +) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec -from .options import SolverOptions +from .options import DataloaderOption, SolverOptions __all__ = ['StrategiesConstructor'] @@ -67,7 +68,15 @@ class StrategiesConstructor: strategies_vector = StrategiesVector(node) # placeholder node if node.op == 'placeholder': - placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector) + if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: + placeholder_option = 'distributed' + else: + assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' + placeholder_option = 'replicated' + placeholder_handler = PlacehodlerHandler(node, + self.device_mesh, + strategies_vector, + placeholder_option=placeholder_option) placeholder_handler.register_strategy() # get_attr node @@ -97,7 +106,12 @@ class StrategiesConstructor: # output node elif node.op == 'output': - output_handler = OuputHandler(node, self.device_mesh, strategies_vector) + if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: + output_option = 'distributed' + else: + assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' + output_option = 'replicated' + output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) output_handler.register_strategy() if len(strategies_vector) <= 0: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index c7c166626..e666cb175 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -84,7 +84,7 @@ def check_linear_module(rank, world_size, port): gm.recompile() node_list = list(graph.nodes) - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, 
solver_options) strategies_constructor.build_strategies_and_cost() linear_node = node_list[3] @@ -138,7 +138,7 @@ def check_conv_module(rank, world_size, port): node_list = list(graph.nodes) conv_node = node_list[3] - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py index 3f0dfdf3f..04d589ab3 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -36,7 +36,7 @@ def mem_test_for_node_strategy(rank: int, input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') graph = tracer.trace(root=model_to_shard, meta_args=input_sample) gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() target_node = list(graph.nodes)[node_index] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py index 27b0af4fb..16eb98300 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -1,11 +1,11 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \ - OuputHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use class OutputModel(nn.Module): @@ -18,7 +18,9 @@ class OutputModel(nn.Module): return x, y -def test_output_handler(): +@parameterize('output_option', ['distributed', 'replicated']) +@rerun_if_address_is_in_use() +def test_output_handler(output_option): model = OutputModel() tracer = ColoTracer() # graph(): @@ -37,7 +39,10 @@ def test_output_handler(): output_strategies_vector = StrategiesVector(output_node) # build handler - otuput_handler = OuputHandler(node=output_node, device_mesh=device_mesh, strategies_vector=output_strategies_vector) + otuput_handler = OuputHandler(node=output_node, + device_mesh=device_mesh, + strategies_vector=output_strategies_vector, + output_option=output_option) otuput_handler.register_strategy(compute_resharding_cost=False) # check operation data mapping @@ -49,10 +54,12 @@ def test_output_handler(): assert op_data.data is not None assert mapping['output'].name == "output" - assert mapping['output'].data.is_meta assert mapping['output'].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in otuput_handler.strategies_vector] - assert "Replica Output" in strategy_name_list + if output_option == 'distributed': + assert 
"Distributed Output" in strategy_name_list + else: + assert "Replica Output" in strategy_name_list if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py index bdec901e9..0aafb9e0b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -1,11 +1,11 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \ - PlacehodlerHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use class PlaceholderModel(nn.Module): @@ -17,7 +17,9 @@ class PlaceholderModel(nn.Module): return input -def test_placeholder_handler(): +@parameterize('placeholder_option', ['distributed', 'replicated']) +@rerun_if_address_is_in_use() +def test_placeholder_handler(placeholder_option): model = PlaceholderModel() tracer = ColoTracer() # graph(): @@ -33,16 +35,25 @@ def test_placeholder_handler(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) placeholder_node = list(graph.nodes)[0] placeholder_strategies_vector = StrategiesVector(placeholder_node) - # build handler placeholder_handler = PlacehodlerHandler(node=placeholder_node, device_mesh=device_mesh, - strategies_vector=placeholder_strategies_vector) + strategies_vector=placeholder_strategies_vector, + placeholder_option=placeholder_option) placeholder_handler.register_strategy(compute_resharding_cost=False) + # check operation data mapping mapping = placeholder_handler.get_operation_data_mapping() + strategy = placeholder_strategies_vector[0] + strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name) + + if placeholder_option == 'distributed': + assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]' + else: + assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]' + for name, op_data in mapping.items(): op_data: OperationData # make sure they have valid values @@ -53,7 +64,10 @@ def test_placeholder_handler(): assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64)) assert mapping['output'].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in placeholder_handler.strategies_vector] - assert "Replica Placeholder" in strategy_name_list + if placeholder_option == 'replicated': + assert "Replica Placeholder" in strategy_name_list + else: + assert "Distributed Placeholder" in strategy_name_list if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py index b39a7b0cc..a89b73958 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -79,7 +79,7 @@ def numerical_test_for_node_strategy(model: 
torch.nn.Module, input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') graph = tracer.trace(root=model_to_shard, meta_args=input_sample) gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() target_node = list(graph.nodes)[node_index] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py index b67641f61..611402fe8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py @@ -79,7 +79,7 @@ def test_linear_module(): gm.recompile() node_list = list(graph.nodes) - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() linear_node = node_list[3] @@ -117,7 +117,7 @@ def test_conv_module(): gm.recompile() node_list = list(graph.nodes) conv_node = node_list[3] - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() _param_resharding_cost_assertion(conv_node) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py index cb8037627..814edd279 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py @@ -138,7 +138,7 @@ def check_apply_bottleneck(rank, world_size, port): graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -162,7 +162,7 @@ def check_apply_bottleneck(rank, world_size, port): output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict) assert output.shape == origin_output.shape - assert_close(output, origin_output) + assert_close(output, origin_output, rtol=1e-03, atol=1e-05) print("*******************backward starting*******************") cuda_rng_state = torch.cuda.get_rng_state() output.sum().backward() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py index 7a1c882f6..66cd3f3f7 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py @@ -60,7 +60,7 @@ def check_apply(rank, world_size, port): graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() diff --git 
a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index 23d866bbe..f4a5ae7ac 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -3,8 +3,13 @@ from torch.fx import GraphModule from torchvision.models import resnet50 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP -from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions, - StrategiesConstructor) +from colossalai.auto_parallel.tensor_shard.solver import ( + CostGraph, + GraphAnalyser, + Solver, + SolverOptions, + StrategiesConstructor, +) from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager @@ -53,7 +58,7 @@ def test_cost_graph(): gm.recompile() graph_analyser = GraphAnalyser(gm) liveness_list = graph_analyser.liveness_analysis() - solver_options = SolverOptions(fast=True) + solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost()
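The hunks above replace the old boolean SolverOptions(fast=True) constructor with enum-driven options and thread a placeholder/output option through the handlers. As a minimal, illustrative sketch (an editor's example, not part of the patch), the new API is exercised roughly as follows, assuming graph and device_mesh have already been built the same way as in the tests above:

    from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
    from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverPerference

    # Default options: standard solver preference and a replicated dataloader,
    # so placeholder/output nodes receive the 'Replica' strategies.
    solver_options = SolverOptions()

    # Opting into a distributed dataloader makes StrategiesConstructor pass
    # placeholder_option='distributed' / output_option='distributed' to the handlers.
    solver_options = SolverOptions(solver_perference=SolverPerference.STANDARD,
                                   dataloader_option=DataloaderOption.DISTRIBUTED)

    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()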