diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index 44449ff8e..4a1b8ab26 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -1,9 +1,16 @@
 import torch
+import torch.nn as nn
 import operator
-import colossalai
+from colossalai.tensor import ProcessGroup
+from colossalai.tensor.distspec import shard
+from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
+
+ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
+ELEMENTWISE_FUNC_OP = [
+    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
+    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
+]
-ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d]
-ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d, torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d, torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d]
 
 
 def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
     """weight_split
@@ -21,6 +28,8 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
     else:
         setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
     return weight
+
+
 def column_shard_linear_pass(gm: torch.fx.GraphModule):
     # Split all the linear module with column shard. Currently for testing only.
     mod_graph = gm.graph
@@ -48,43 +57,95 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
     gm.recompile()
     return gm
 
-def transform_mlp_pass(gm: torch.fx.GraphModule):
+
+def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
+    """
+    This IR pass checks for a transformer-MLP-like structure and annotates its linear layers with column and row sharding.
+    """
     #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
-    mod_graph = gm.graph
-    col_shard = True
-    element_op = []
-    all_linear_name = []
-    linear_name = []
-    # Get the name of element wise module(torch.nn.ReLU)
-    # Get the name of all the linear modules and repeated linear modules
-    for name, func in gm.named_children():
-        if not isinstance(func, torch.nn.Linear):
-            for i in ELEMENTWISE_MODULE_OP:
-                if isinstance(func, i):
-                    element_op.append(name)
-                    break
-        else:
-            if name in all_linear_name:
-                if name in linear_name:
-                    linear_name.remove(name)
-            else:
-                all_linear_name.append(name)
-                linear_name.append(name)
-    # If the linear modules is called multiple times, set the dist spec as col shard
-    # If the module is element wise or the function/method is element wise, remains col_shard
-    for node in mod_graph.nodes:
-        if node.target in linear_name:
-            target_module = node.graph.owning_module.get_submodule(node.target)
-            dim = 0 if col_shard else -1
-            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
-            col_shard = not col_shard
-        elif node.target in all_linear_name:
-            target_module = node.graph.owning_module.get_submodule(node.target)
-            dim = 0 if col_shard else -1
-            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
-            col_shard = not col_shard
-        else:
-            if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
-                col_shard = True
-    gm.recompile()
-    return gm
\ No newline at end of file
+    graph = graph_module.graph
+    world_size = process_group.world_size()
+
+    def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
+        # traverse the graph to look for consecutive linear layers
+        is_linear_module = False
+
+        if node.op == 'call_module':
+            # look for the linear layer
+            module = node.graph.owning_module.get_submodule(node.target)
+            if isinstance(module, nn.Linear):
+                is_linear_module = True
+                if start_tracking:
+                    # when start_tracking = True
+                    # it means the first linear has been found and the current module
+                    # is the second linear
+                    # set the current linear module to be row-sharded
+                    annotation_record['row'] = module
+
+                    for shard_type, module in annotation_record.items():
+                        # add row sharding spec
+                        if shard_type == 'row':
+                            dist_spec = shard(dims=[-1], num_partitions=[world_size])
+                            comp_spec = ComputeSpec(ComputePattern.TP1D)
+                            setattr(module.weight, 'pg', process_group)
+                            setattr(module.weight, 'dist_spec', dist_spec)
+                            setattr(module.weight, 'comp_spec', comp_spec)
+                        elif shard_type == 'col':
+                            weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                            weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
+                            weight_comp_spec.output_replicate = False
+                            setattr(module.weight, 'pg', process_group)
+                            setattr(module.weight, 'dist_spec', weight_dist_spec)
+                            setattr(module.weight, 'comp_spec', weight_comp_spec)
+
+                            if module.bias is not None:
+                                bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                                bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
+                                bias_comp_spec.output_replicate = False
+                                setattr(module.bias, 'pg', process_group)
+                                setattr(module.bias, 'dist_spec', bias_dist_spec)
+                                setattr(module.bias, 'comp_spec', bias_comp_spec)
+                    start_tracking = False
+                    annotation_record.clear()
+                else:
+                    # when start_tracking = False
+                    # it means the current layer is the first linear
+                    # set the linear layer to be col-sharded
+                    start_tracking = True
+                    annotation_record['col'] = module
+
+        if start_tracking and not is_linear_module:
+            # check against the white list
+            # if a non-element-wise op is found, we reset the tracking
+            if node.op == 'call_module':
+                module = node.graph.owning_module.get_submodule(node.target)
+                if module.__class__ not in ELEMENTWISE_MODULE_OP:
+                    start_tracking = False
+            elif node.op == 'call_function' or node.op == 'call_method':
+                if node.target not in ELEMENTWISE_FUNC_OP:
+                    start_tracking = False
+            elif len(node.users.keys()) > 1:
+                start_tracking = False
+
+            if not start_tracking:
+                annotation_record.clear()
+
+        # stop tracking for consecutive linear when branch is found
+        # e.g.
+        # out1 = self.linear1(x)
+        # out2 = self.linear2(x)
+        # return out1+out2
+        next_nodes = list(node.users.keys())
+        if len(next_nodes) > 1:
+            start_tracking = False
+            annotation_record.clear()
+
+        # traverse
+        for node in next_nodes:
+            _traverse_and_annotate(node, start_tracking, annotation_record, world_size)
+
+    placeholder_node = list(graph.nodes)[0]
+    annotate_record = {}
+    _traverse_and_annotate(placeholder_node, False, annotate_record, world_size)
+
+    return graph_module
diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py
index 142ced630..057cbd8f5 100644
--- a/colossalai/utils/model/lazy_init_context.py
+++ b/colossalai/utils/model/lazy_init_context.py
@@ -175,7 +175,7 @@ class LazyInitContext():
         self._unpatch_nn_init_funcs()
         self._unpatch_torch_tensor_funcs()
 
-    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
+    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
         """
         Initialize the weights of the meta-tensor model.
 
@@ -205,6 +205,7 @@ class LazyInitContext():
             # get sharding spec
             dist_spec = getattr(tensor, 'dist_spec', None)
             pg = getattr(tensor, 'pg', None)
+            comp_spec = getattr(tensor, 'comp_spec', None)
 
             # convert the tensor from meta to materialized one
             if tensor.is_meta:
@@ -224,14 +225,15 @@ class LazyInitContext():
             else:
                 tensor = ColoTensor.from_torch_tensor(tensor)
 
-            # apply sharding
-            if dist_spec:
-                tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
-
             # override the original tensor
             with torch.no_grad():
                 setattr(module, name, tensor)
 
+            # apply sharding
+            if dist_spec:
+                tensor.process_group = pg
+                tensor.set_tensor_spec(dist_spec, comp_spec)
+
         _init_recursively(model)
 
         return model
diff --git a/tests/test_fx/test_complete_workflow.py b/tests/test_fx/test_complete_workflow.py
new file mode 100644
index 000000000..b17f2cdb6
--- /dev/null
+++ b/tests/test_fx/test_complete_workflow.py
@@ -0,0 +1,77 @@
+import colossalai
+import torch
+import torch.nn as nn
+import pytest
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from colossalai.testing import rerun_if_address_is_in_use
+from functools import partial
+from colossalai.fx import ColoTracer
+from colossalai.utils.model.lazy_init_context import LazyInitContext
+from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
+from colossalai.utils import free_port
+from colossalai.tensor import ProcessGroup
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(dim, dim)
+        self.linear2 = torch.nn.Linear(dim, dim)
+        self.dropout = torch.nn.Dropout(0)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.dropout(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
+
+
+def run_workflow(world_size):
+    # initialization
+    with LazyInitContext() as ctx:
+        model = MLP(16)
+
+    # tracing
+    tracer = ColoTracer()
+    graph = tracer.trace(model)
+    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
+
+    # annotate
+    annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup())
+    annotated_gm.recompile()
+
+    # materialization and sharding
+    ctx.lazy_init_parameters(annotated_gm)
+
+    # check sharding
+    assert list(model.linear1.weight.shape) == [16 // world_size, 16]
+    assert list(model.linear1.bias.shape) == [16 // world_size]
+    assert list(model.linear2.weight.shape) == [16, 16 // world_size]
+
+    # test forward to make sure that the IR transform produces the same results
+    # as ColoTensor would normally
+    data = torch.rand(4, 16)
+    non_fx_out = model(data)
+    fx_out = annotated_gm(data)
+    assert torch.equal(non_fx_out, fx_out)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_workflow(world_size)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_complete_workflow(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_complete_workflow(2)
diff --git a/tests/test_fx/test_transform_mlp_pass.py b/tests/test_fx/test_transform_mlp_pass.py
deleted file mode 100644
index 202c8ce0e..000000000
--- a/tests/test_fx/test_transform_mlp_pass.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import torch
-import torch.nn as nn
-import pytest
-import colossalai
-from colossalai.fx import ColoTracer
-from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass
-CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
-
-class MLP(torch.nn.Module):
-
-    def __init__(self, dim: int):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(dim, dim)
-        self.linear2 = torch.nn.Linear(dim, dim)
-        self.linear3 = torch.nn.Linear(dim, dim)
-        self.linear4 = torch.nn.Linear(dim, dim)
-        self.dropout = torch.nn.Dropout()
-        self.relu = torch.nn.ReLU()
-
-    def forward(self, x):
-        x = self.relu(self.linear1(x))
-        x = self.dropout(self.relu(self.linear2(x)))
-        x = self.linear3(x)
-        x = torch.nn.functional.relu(self.linear4(x))
-        return x
-
-def test_out_acc():
-    model = MLP(16).cuda()
-    model.eval()
-    input_tensor = torch.rand(2, 16).cuda()
-    output = model(input_tensor)
-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
-    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
-    splitted_gm = transform_mlp_pass(gm)
-    new_output = splitted_gm(input_tensor)
-    assert output.equal(new_output)
-
-def test_linear_acc():
-    input_tensor = torch.rand(2, 16).cuda()
-    model = MLP(16).cuda()
-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
-    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
-    splitted_gm = transform_mlp_pass(gm)
-    col_shard = True
-    for node in splitted_gm.graph.nodes:
-        if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), torch.nn.Linear):
-            target_module = node.graph.owning_module.get_submodule(node.target)
-            dim = 0 if col_shard else -1
-            assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs")
-            col_shard = not col_shard
-
-if __name__ == "__main__":
-    torch.manual_seed(1)
-    torch.cuda.manual_seed(1)
-    # colossalai.launch_from_torch(config=CONFIG)
-    test_out_acc()
-    test_linear_acc()
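
Note: the column-then-row annotation applied by transformer_mlp_pass follows the usual Megatron-style 1D tensor parallelism for an MLP block. The first linear's weight is split along dim 0 of nn.Linear.weight (output features, with output_replicate disabled so the activation stays sharded), the second linear's weight along dim -1 (input features), and elementwise ops such as ReLU or Dropout between the two commute with the split, which is why the pass only tracks through the ELEMENTWISE_* whitelists. Below is a minimal single-process sketch of that algebra in plain PyTorch, independent of ColossalAI; the world size of 2 and the tensor sizes are arbitrary choices for illustration.

import torch
import torch.nn as nn

torch.manual_seed(0)
dim, world_size = 16, 2
linear1, linear2 = nn.Linear(dim, dim), nn.Linear(dim, dim)
x = torch.rand(4, dim)

# reference output of the unsharded MLP
ref = linear2(torch.relu(linear1(x)))

# column-shard linear1: split weight and bias along dim 0 (output features)
w1_shards = linear1.weight.chunk(world_size, dim=0)
b1_shards = linear1.bias.chunk(world_size, dim=0)
# row-shard linear2: split weight along dim -1 (input features), keep bias whole
w2_shards = linear2.weight.chunk(world_size, dim=-1)

# each "rank" holds one shard pair and produces a partial output;
# summing the partials plays the role of the all-reduce
partials = [
    torch.relu(x @ w1.t() + b1) @ w2.t()
    for w1, b1, w2 in zip(w1_shards, b1_shards, w2_shards)
]
out = sum(partials) + linear2.bias

assert torch.allclose(out, ref, atol=1e-5)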