diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index cc2466273..9f95009d9 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -93,7 +93,7 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                 # substitute the origin node with shape_consistency_node
                 origin_index_args = new_args.index(node)
                 new_args[origin_index_args] = shape_consistency_node
-                user_node.args = new_args
+                user_node.args = tuple(new_args)
             elif str(node) in new_kwargs:
                 # substitute the origin node with shape_consistency_node
                 new_kwargs[str(node)] = shape_consistency_node
@@ -118,10 +118,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
 
             comm_actions = node.best_strategy.communication_actions
             for op_data, comm_action in comm_actions.items():
-                if op_data.type == OperationDataType.PARAM:
+                if comm_action.comm_type == CommType.HOOK:
                     continue
                 if comm_action.comm_type == CommType.BEFORE:
-                    if comm_action.key_for_kwarg is not None:
+                    if op_data.type == OperationDataType.OUTPUT:
+                        comm_object = node
+                    elif comm_action.key_for_kwarg is not None:
                         comm_object = node.kwargs[comm_action.key_for_kwarg]
                     else:
                         comm_object = node.args[comm_action.arg_index]
@@ -140,7 +142,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                     # substitute the origin node with comm_spec_apply_node
                     new_args = list(node.args)
                     new_args[comm_action.arg_index] = comm_spec_apply_node
-                    node.args = new_args
+                    node.args = tuple(new_args)
 
                 elif comm_action.comm_type == CommType.AFTER:
                     with mod_graph.inserting_after(node):
@@ -163,7 +165,6 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                             # substitute the origin node with comm_spec_apply_node
                             new_kwargs[str(node)] = comm_spec_apply_node
                             user.kwargs = new_kwargs
-
     return gm
 
 
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 00268e3f5..df2d30cbc 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -5,7 +5,12 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationDataType,
+    ShardingStrategy,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -42,7 +47,32 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
-
+        # the get_attr node strategy is kind of pending strategy, which means we will change it
+        # to the same strategy of the user node.
+        if node.op == 'get_attr':
+            assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+            new_sharding_spec = target_sharding_specs[0]
+            user_node = node.strategies_vector.successor_nodes[0]
+            user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
+            op_data_in_user = user_strategy.get_op_data_by_name(str(node))
+            origin_node_sharding_spec_dict[index] = new_sharding_spec
+            origin_pending_strategy = node.best_strategy
+            origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
+            new_sharding_specs = origin_pending_strategy.sharding_specs
+            new_sharding_specs[origin_op_data] = new_sharding_spec
+            new_communication_actions = {}
+            if op_data_in_user in user_strategy.communication_actions:
+                new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
+                new_communication_action.arg_index = 0
+                new_communication_actions[origin_op_data] = new_communication_action
+            new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
+                                            sharding_specs=new_sharding_specs,
+                                            compute_cost=origin_pending_strategy.compute_cost,
+                                            communication_cost=origin_pending_strategy.communication_cost,
+                                            memory_cost=origin_pending_strategy.memory_cost,
+                                            communication_actions=new_communication_actions)
+            setattr(node, 'best_strategy', new_strategy)
+            setattr(node, 'sharding_spec', new_sharding_spec)
         comm_action_dict = {}
         for op_data, comm_action in node.best_strategy.communication_actions.items():
             comm_action_dict[op_data.name] = comm_action
@@ -111,6 +141,43 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
             for name, buffer_sharded in sharded_buffer_dict.items():
                 setattr(target_module, name, buffer_sharded.detach().clone())
 
+        if node.op == 'get_attr':
+            root = node.graph.owning_module
+            atoms = node.target.split(".")
+            attr_len = len(atoms)
+            if attr_len == 1:
+                target_module = root
+                target = getattr(root, atoms[0])
+            else:
+                target_module = root.get_submodule(atoms[-2])
+                target = getattr(target_module, atoms[-1])
+
+            target_sharding_spec = node.sharding_spec
+            if target_sharding_spec.dim_partition_dict != {}:
+                origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
+                setattr(target, 'sharding_spec', origin_sharding_spec)
+                # TODO: build a ColoParamter class to manager the distributed parameters
+                target_sharded = torch.nn.Parameter(
+                    shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                             target_sharding_spec).detach().clone())
+            else:
+                target_sharded = target
+            setattr(target_module, atoms[-1], target_sharded)
+
+            comm_actions = node.best_strategy.communication_actions
+            for operation_data, comm_action in comm_actions.items():
+                comm_spec_to_use = comm_action.comm_spec
+                # register hook to the parameters
+                if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+
+                    def wrapper(param, comm_spec):
+
+                        def hook_fn(grad):
+                            _all_reduce(grad, comm_spec)
+
+                        param.register_hook(hook_fn)
+
+                    wrapper(target_sharded, comm_spec_to_use)
     return gm
 
 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
index 3c4c05786..d6a06bc15 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
@@ -29,8 +29,15 @@ class ReshapeHandler(NodeHandler):
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
+
+        # check if the input operand is a parameter
+        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
+            data_type = OperationDataType.PARAM
+        else:
+            data_type = OperationDataType.ARG
+
         physical_input_operand = OperationData(name=str(self.node.args[0]),
-                                               type=OperationDataType.ARG,
+                                               type=data_type,
                                                data=self.node.args[0]._meta_data)
         physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
index cbe0f0746..0b3506c27 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
@@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
                                                                   arg_index=0)
                 input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
 
-            else:
+            elif len(total_mesh_dim_list) >= 2:
                 source_spec = sharding_spec_mapping["input"]
                 target_spec = ShardingSpec(device_mesh=self.device_mesh,
                                            entire_shape=source_spec.entire_shape,
@@ -104,7 +104,11 @@ class ReshapeGenerator(FollowingStrategyGenerator):
                 comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
                 input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
 
-            communication_action_mapping["input"] = input_comm_action
+            else:
+                input_comm_action = None
+
+            if input_comm_action is not None:
+                communication_action_mapping["input"] = input_comm_action
             strategy = self.get_sharding_strategy(name=name,
                                                   sharding_spec_mapping=sharding_spec_mapping,
                                                   communication_action_mapping=communication_action_mapping)
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
index 21695f6b5..4b6c82a74 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
@@ -43,7 +43,7 @@ class BiasAdditionConv(BiasAdditionModule):
         bias_shape[0] = -1
         bias_reshape_node_kind = 'call_method'
         bias_reshape_node_target = 'view'
-        bias_reshape_node_args = (self.bias_proxy, bias_shape)
+        bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
         bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
                                                       bias_reshape_node_args, {})
         return bias_reshape_proxy
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
index 493c57023..aba254a80 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
@@ -58,7 +58,7 @@ def torch_bmm(input, mat2, *, out=None):
 
 
 @meta_patched_function.register(torch.nn.functional.linear)
-def torch_linear(input, mat2, *, out=None):
+def torch_linear(input, mat2, bias=None, *, out=None):
     if out is not None:
         raise ValueError("Don't support in-place abs for MetaTensor analysis")
     output_shape = list(input.shape)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
new file mode 100644
index 000000000..c7c166626
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
@@ -0,0 +1,172 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+
+
+class LinearModel(torch.nn.Module):
+
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.linear = torch.nn.Linear(in_features, out_features)
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = x * 2
+
+        return x
+
+
+class ConvModel(torch.nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels=in_channels,
+                                    out_channels=out_channels,
+                                    kernel_size=kernel_size,
+                                    bias=bias)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = x * 2
+
+        return x
+
+
+def check_linear_module(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = LinearModel(4, 8).cuda()
+    input = torch.rand(4, 4).cuda()
+    output_compare = model(input)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    tracer = ColoTracer()
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %linear_weight : [#users=1] = get_attr[target=linear.weight]
+    #     %linear_bias : [#users=1] = get_attr[target=linear.bias]
+    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
+    #     %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
+    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
+    #     return mul
+    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
+    # def forward(self, x : torch.Tensor):
+    #     linear_weight = self.linear.weight
+    #     linear_bias = self.linear.bias
+    #     linear = torch._C._nn.linear(x, linear_weight);  x = linear_weight = None
+    #     add = linear + linear_bias;  linear = linear_bias = None
+    #     mul = add * 2;  add = None
+    #     return mul
+    gm = ColoGraphModule(model, graph)
+    gm.recompile()
+    node_list = list(graph.nodes)
+
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+    linear_node = node_list[3]
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    graph_analyser = GraphAnalyser(gm)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    ret = solver.call_solver_serialized_args()
+    solution = list(ret[0])
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
+
+    gm = runtime_apply_pass(gm)
+    gm.recompile()
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    assert_close(output, output_compare)
+
+
+def check_conv_module(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = ConvModel(3, 6, 2).cuda()
+    input = torch.rand(4, 3, 64, 64).cuda()
+    output_compare = model(input)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    tracer = ColoTracer()
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %conv_weight : [#users=1] = get_attr[target=conv.weight]
+    #     %conv_bias : [#users=1] = get_attr[target=conv.bias]
+    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
+    #     %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
+    #     %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
+    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
+    #     return mul
+    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
+    # def forward(self, x : torch.Tensor):
+    #     conv_weight = self.conv.weight
+    #     conv_bias = self.conv.bias
+    #     conv2d = torch.conv2d(x, conv_weight);  x = conv_weight = None
+    #     view = conv_bias.view([1, -1, 1, 1]);  conv_bias = None
+    #     add = conv2d + view;  conv2d = view = None
+    #     mul = add * 2;  add = None
+    #     return mul
+    gm = ColoGraphModule(model, graph)
+
+    gm.recompile()
+
+    node_list = list(graph.nodes)
+    conv_node = node_list[3]
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    graph_analyser = GraphAnalyser(gm)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    ret = solver.call_solver_serialized_args()
+    solution = list(ret[0])
+
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
+
+    gm = runtime_apply_pass(gm)
+    gm.recompile()
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    assert_close(output, output_compare)
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_bias_addition_module():
+    world_size = 4
+    run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
+    mp.spawn(run_func_linear, nprocs=world_size)
+    run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
+    mp.spawn(run_func_conv, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_bias_addition_module()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
new file mode 100644
index 000000000..1bc556209
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
@@ -0,0 +1,146 @@
+from faulthandler import disable
+from functools import partial
+from xml.dom import WrongDocumentErr
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from typing_extensions import Self
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+    StrategiesVector,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
+
+
+class LinearModule(torch.nn.Module):
+
+    def __init__(self, in_features, out_features, bias):
+        super().__init__()
+        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+
+def check_linear_module_handler(rank, bias, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = LinearModule(16, 32, bias=bias).cuda()
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    input = torch.rand(2, 2, 4, 16).cuda()
+    # the index of linear node in computation graph
+    node_index = 3
+    # strategy number of linear node
+    strategy_number = 10
+    # construct input args
+    input_args = [input]
+    # construct meta arg names
+    meta_arg_names = ['x']
+    numerical_test_for_node_strategy(model=model,
+                                     device_mesh=device_mesh,
+                                     node_index=node_index,
+                                     strategy_number=strategy_number,
+                                     input_args=input_args,
+                                     meta_arg_names=meta_arg_names,
+                                     node_type='bias_module')
+
+    tracer = ColoTracer()
+    graph = tracer.trace(model, meta_args={"x": torch.rand(2, 2, 4, 16).to('meta')})
+    gm = ColoGraphModule(model, graph)
+
+    linear_mod_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(linear_mod_node)
+
+    # build handler
+    handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+    # check operation data mapping
+    mapping = handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.logical_shape is not None
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "x"
+    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([16, 16])
+
+    assert mapping['other'].name == "linear_weight"
+    assert mapping['other'].data.shape == torch.Size([32, 16])
+    assert mapping['other'].type == OperationDataType.PARAM
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
+
+    assert 'bias' not in mapping
+
+    assert mapping['output'].name == "linear"
+    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
+    strategy_name_list = [val.name for val in strategies_vector]
+    # one strategy will be converted to different physical sharding spec
+    assert len(strategy_name_list) > 8
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS= RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
+
+    for strategy in strategies_vector:
+        strategy: ShardingStrategy
+        input_sharding_spec = strategy.get_sharding_spec_by_name('x')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
+
+        # make sure the sharding matches across different operation data
+        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
+        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
+        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_linear_handler(bias=True):
+    world_size = 4
+    run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
+    mp.spawn(run_func_module, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_linear_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
index d59c10707..d871db144 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
@@ -7,6 +7,9 @@ from torch.fx import GraphModule
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import to_global
@@ -56,7 +59,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
                                      strategy_number: int,
                                      input_args: List[torch.Tensor],
                                      meta_arg_names: List[str],
-                                     input_kwargs: Dict[str, torch.Tensor] = {}):
+                                     input_kwargs: Dict[str, torch.Tensor] = {},
+                                     node_type: str = 'normal'):
     for strategy_index in range(strategy_number):
         print(f'#strategy_index: {strategy_index}')
         # We need to copy the model to avoid do backward more than once in same graph
@@ -79,11 +83,21 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
         strategies_constructor.build_strategies_and_cost()
         target_node = list(graph.nodes)[node_index]
-
-        # solution construction
-        solution_len = len(strategies_constructor.leaf_strategies)
-        solution = [0] * solution_len
-        solution[node_index] = strategy_index
+        if node_type == 'normal':
+            solution_len = len(strategies_constructor.leaf_strategies)
+            solution = [0] * solution_len
+            solution[node_index] = strategy_index
+        else:
+            node_vector = strategies_constructor.leaf_strategies[node_index]
+            strategy_to_keep = node_vector[strategy_index]
+            node_vector = [strategy_to_keep]
+            # solution construction
+            cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+            cost_graph.simplify_graph()
+            graph_analyser = GraphAnalyser(gm)
+            solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+            ret = solver.call_solver_serialized_args()
+            solution = list(ret[0])
         gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
             gm, solution, device_mesh)
         gm = runtime_apply_pass(gm)
@@ -110,11 +124,18 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
 
         # extract the strategy used in this iter
         strategy_in_use = target_node.strategies_vector[strategy_index]
-        param_to_shard_dict = dict(model_to_shard.named_parameters())
+        param_to_shard_dict = dict(gm.named_parameters())
        param_to_compare_dict = dict(model_to_compare.named_parameters())
        for name in param_to_shard_dict.keys():
            param_name = name.split('.')[-1]
-            param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
+            if node_type == 'normal':
+                param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
+            else:
+                if 'weight' in name:
+                    param_sharding_spec = list(graph.nodes)[4].sharding_spec
+                elif 'bias' in name:
+                    param_sharding_spec = list(graph.nodes)[5].sharding_spec
+
            grad_sharded = param_to_shard_dict[name].grad
            grad_to_compare = param_to_compare_dict[name].grad
            global_grad = to_global(grad_sharded, param_sharding_spec)