[autoparallel] support original activation ckpt on autoparallel system (#2468)

YuliangLiu0306 2023-01-16 16:25:13 +08:00 committed by GitHub
parent 3a21485ead
commit 67e1912b59
4 changed files with 111 additions and 5 deletions

View File

@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                            runtime_apply,
                                                            args=(node, origin_dict_node, input_dict_node,
                                                                  node_to_index_dict[node], user_node_index))
+            if 'activation_checkpoint' in user_node.meta:
+                shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
 
             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
             # substitute the origin node with comm_spec_apply_node
             new_kwargs[str(node)] = comm_spec_apply_node
             user.kwargs = new_kwargs
+
+        if 'activation_checkpoint' in node.meta:
+            comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+
+    return gm
+
+
+def _act_annotataion_pass(gm: torch.fx.GraphModule):
+    """
+    This pass is used to add the activation checkpoint annotation to the newly inserted nodes.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    for node in nodes:
+        if not hasattr(node.meta, 'activation_checkpoint'):
+            from .runtime_preparation_pass import size_processing
+
+            user_act_annotation = -1
+            input_act_annotation = -1
+
+            for user_node in node.users.keys():
+                if 'activation_checkpoint' in user_node.meta:
+                    user_act_annotation = user_node.meta['activation_checkpoint']
+                    break
+
+            for input_node in node._input_nodes.keys():
+                if 'activation_checkpoint' in input_node.meta:
+                    input_act_annotation = input_node.meta['activation_checkpoint']
+                    break
+
+            if user_act_annotation == input_act_annotation and user_act_annotation != -1:
+                node.meta['activation_checkpoint'] = user_act_annotation
+
     return gm
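Note: the hunks in this commit all follow one pattern: whenever a pass inserts a helper node into the traced graph, it copies the 'activation_checkpoint' entry from a neighboring node's meta dict so the helper stays inside the same checkpoint region. Below is a minimal, self-contained sketch of that pattern on a toy torch.fx graph; the module and the block index 0 are illustrative, not from this commit.

import torch
import torch.fx as fx


class Tiny(torch.nn.Module):

    def forward(self, x):
        return x.relu() + 1


gm = fx.symbolic_trace(Tiny())
relu_node = next(n for n in gm.graph.nodes if n.target == 'relu')
# pretend tracing marked this node as belonging to checkpoint block 0
relu_node.meta['activation_checkpoint'] = 0

# insert a helper node after relu, the way the runtime passes insert
# shape-consistency / comm-spec nodes
with gm.graph.inserting_after(relu_node):
    new_node = gm.graph.call_function(torch.clone, args=(relu_node,))

# without this copy the helper would fall outside the checkpoint region
if 'activation_checkpoint' in relu_node.meta:
    new_node.meta['activation_checkpoint'] = relu_node.meta['activation_checkpoint']

assert new_node.meta['activation_checkpoint'] == 0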

View File

@@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # It will be used to replace the original node with processing node in slice object
             node_pairs[node] = size_processing_node
             size_processing_node._meta_data = node._meta_data
+            if 'activation_checkpoint' in node.meta:
+                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
 
             user_list = list(node.users.keys())
             for user in user_list:

View File

@@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec

@@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
     into the forward function.
     '''

-    def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
+    def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
                  origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
         '''
         Args:

@@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
     return strategies_constructor


-def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
+def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
     '''
     This method is used to solve the best solution for the given graph.
     The solution is a list of integers, each integer represents the best strategy index of the corresponding node.

@@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
     return solution


-def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
+def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
                                strategies_constructor: StrategiesConstructor):
     '''
     This method is used to transform the original graph to the sharded graph.

@@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
         solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
         return a series of integers, but return the best strategies.
     '''
-    tracer = ColoTracer()
+    tracer = ColoTracer(trace_act_ckpt=True)
     graph = tracer.trace(root=model, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
     strategies_constructor = build_strategy_constructor(graph, device_mesh)

     if load_solver_solution:
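Note: these changes swap GraphModule for ColoGraphModule and turn on checkpoint-aware tracing, which is what produces the 'activation_checkpoint' annotations the passes above propagate. Below is a hedged usage sketch built only from the calls visible in this diff; the toy module is illustrative.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer


class ToyNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, 16)

    def forward(self, x):
        x = checkpoint(self.fc1, x)    # traced as a checkpointed region
        return self.fc2(x)


model = ToyNet()
tracer = ColoTracer(trace_act_ckpt=True)    # flag taken straight from the diff
graph = tracer.trace(root=model, meta_args={'x': torch.rand(2, 16).to('meta')})
gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# nodes traced inside checkpoint() should now carry the annotation
# the runtime passes look for
for node in graph.nodes:
    print(node.name, node.meta.get('activation_checkpoint'))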

View File

@@ -0,0 +1,70 @@
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from transformers.pytorch_utils import Conv1D
+
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+
+HIDDEN_SIZE = 16
+
+
+class GPT2MLPWithCkpt(nn.Module):
+
+    def __init__(self, intermediate_size, hidden_size):
+        super().__init__()
+        embed_dim = hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = torch.nn.ReLU()
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = checkpoint(self.c_proj, hidden_states)
+        hidden_states = self.act(hidden_states)
+        return hidden_states
+
+
+def check_act_ckpt(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
+    input_sample = {
+        'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
+    }
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    gm = initialize_model(model, input_sample, device_mesh)
+    code = gm.module.graph.python_code('self').src
+    assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
+    assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_mlp_layer():
+    world_size = 4
+    run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_mlp_layer()
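Note: the two asserts pin down the end-to-end behavior: a runtime comm-spec node survives codegen inside the recompiled forward, and the checkpointed region is re-emitted through colossalai's activation_checkpoint wrapper with comm_actions_dict threaded into it. Under pytest the test is skipped unless the AUTO_PARALLEL environment flag is set, and the spawn of four NCCL ranks means four GPUs are required. A sketch of an invocation; the test file's path is not shown in this view, so the one below is hypothetical:

AUTO_PARALLEL=1 pytest tests/test_auto_parallel/test_act_ckpt.py    # hypothetical path; needs 4 GPUs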