diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index 7f2aac42b..9d83f1057 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                     runtime_apply,
                     args=(node, origin_dict_node, input_dict_node,
                           node_to_index_dict[node], user_node_index))
+                if 'activation_checkpoint' in user_node.meta:
+                    shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
 
             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
@@ -208,6 +210,35 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                     # substitute the origin node with comm_spec_apply_node
                     new_kwargs[str(node)] = comm_spec_apply_node
                     user.kwargs = new_kwargs
+
+        if 'activation_checkpoint' in node.meta:
+            comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+
+    return gm
+
+
+def _act_annotation_pass(gm: torch.fx.GraphModule):
+    """
+    This pass adds the activation checkpoint annotation to newly inserted nodes.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    for node in nodes:
+        if 'activation_checkpoint' not in node.meta:
+            user_act_annotation = -1
+            input_act_annotation = -1
+            for user_node in node.users.keys():
+                if 'activation_checkpoint' in user_node.meta:
+                    user_act_annotation = user_node.meta['activation_checkpoint']
+                    break
+            for input_node in node._input_nodes.keys():
+                if 'activation_checkpoint' in input_node.meta:
+                    input_act_annotation = input_node.meta['activation_checkpoint']
+                    break
+            if user_act_annotation == input_act_annotation and user_act_annotation != -1:
+                node.meta['activation_checkpoint'] = user_act_annotation
+    return gm
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index f9b890263..1c25e4c94 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # It will be used to replace the original node with processing node in slice object
             node_pairs[node] = size_processing_node
             size_processing_node._meta_data = node._meta_data
+            if 'activation_checkpoint' in node.meta:
+                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
 
             user_list = list(node.users.keys())
             for user in user_list:
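The propagation rule in `_act_annotation_pass` is easiest to see in isolation. Below is a minimal standalone re-statement of the same logic, assuming (as in the pass above) that checkpoint regions are recorded as integers under `node.meta['activation_checkpoint']`; the helper name `propagate_ckpt_annotation` is illustrative, not part of the codebase.

```python
import torch.fx as fx


def propagate_ckpt_annotation(gm: fx.GraphModule) -> fx.GraphModule:
    # Nodes inserted by the runtime passes carry no annotation yet. Annotate
    # such a node only when one of its consumers and one of its producers
    # agree on the same checkpoint region, so region boundaries never widen.
    for node in gm.graph.nodes:
        if 'activation_checkpoint' in node.meta:
            continue
        user_ckpt = next((u.meta['activation_checkpoint']
                          for u in node.users
                          if 'activation_checkpoint' in u.meta), -1)
        input_ckpt = next((i.meta['activation_checkpoint']
                           for i in node.all_input_nodes
                           if 'activation_checkpoint' in i.meta), -1)
        if user_ckpt == input_ckpt and user_ckpt != -1:
            node.meta['activation_checkpoint'] = user_ckpt
    return gm
```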
diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index 8c24c0d7b..387a682a1 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
 
@@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
     into the forward function.
     '''
 
-    def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
+    def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
                  origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
         '''
         Args:
@@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
     return strategies_constructor
 
 
-def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
+def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
     '''
     This method is used to solve the best solution for the given graph.
     The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
@@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
     return solution
 
 
-def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
+def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
                                strategies_constructor: StrategiesConstructor):
     '''
     This method is used to transform the original graph to the sharded graph.
@@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
         solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
         return a series of integers, but return the best strategies.
     '''
-    tracer = ColoTracer()
+    tracer = ColoTracer(trace_act_ckpt=True)
 
     graph = tracer.trace(root=model, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
     strategies_constructor = build_strategy_constructor(graph, device_mesh)
     if load_solver_solution:
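For reference, the updated tracing flow in `initialize_model` boils down to the calls below. This is a sketch only: the toy `nn.Linear` and its `input` meta argument are stand-ins, and the real entry point goes on to build the strategies constructor and run the solver afterwards.

```python
import torch
import torch.nn as nn

from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer

model = nn.Linear(16, 16)
# trace_act_ckpt=True asks the tracer to record which nodes fall inside a
# torch.utils.checkpoint region, so the runtime passes can preserve it.
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(root=model, meta_args={'input': torch.rand(4, 16).to('meta')})
gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile()
```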
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
new file mode 100644
index 000000000..0b42722fe
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
@@ -0,0 +1,70 @@
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from transformers.pytorch_utils import Conv1D
+
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+
+HIDDEN_SIZE = 16
+
+
+class GPT2MLPWithCkpt(nn.Module):
+
+    def __init__(self, intermediate_size, hidden_size):
+        super().__init__()
+        embed_dim = hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = torch.nn.ReLU()
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = checkpoint(self.c_proj, hidden_states)
+        hidden_states = self.act(hidden_states)
+
+        return hidden_states
+
+
+def check_act_ckpt(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
+    input_sample = {
+        'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
+    }
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    gm = initialize_model(model, input_sample, device_mesh)
+    code = gm.module.graph.python_code('self').src
+    assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
+    assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_mlp_layer():
+    world_size = 4
+    run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_mlp_layer()
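The test wraps `self.c_proj` in `torch.utils.checkpoint.checkpoint`, which is why the annotation must survive the runtime passes: everything inside the wrapped segment is recomputed during backward rather than stored, so newly inserted runtime nodes must land consistently inside or outside that region. A plain-PyTorch sketch of the underlying behaviour (toy shapes, no ColossalAI):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

mlp = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
x = torch.randn(2, 16, requires_grad=True)

# The wrapped segment does not keep its intermediate activations alive;
# they are recomputed when backward reaches this node, trading compute
# for memory.
out = checkpoint(mlp, x, use_reentrant=True)
out.sum().backward()
```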