From 30b4fc0eb03f7ab60c6b14d6dca1ba78363bbf3e Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 12 Jul 2022 13:45:01 +0800
Subject: [PATCH] [fx]add split module pass and unit test from pipeline passes (#1242)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [fx]add split module pass and unit test from pipeline passes

* fix MNASNet bug

* polish
---
 .../fx/passes/adding_split_node_pass.py          |  12 +-
 colossalai/fx/passes/split_module.py             | 277 ++++++++++++++++++
 .../test_pipeline/test_hf_model/hf_utils.py      |  69 +++++
 .../test_hf_model/test_albert.py                 |  38 +++
 .../test_pipeline/test_hf_model/test_bert.py     |  38 +++
 .../test_pipeline/test_hf_model/test_gpt.py      |  34 +++
 .../test_pipeline/test_hf_model/test_opt.py      |  30 ++
 .../test_pipeline/test_hf_model/test_t5.py       |  43 +++
 .../test_timm_model/test_timm.py                 |  51 ++++
 .../test_timm_model/timm_utils.py                |  51 ++++
 .../test_torchvision/test_torchvision.py         |  63 +++++
 11 files changed, 703 insertions(+), 3 deletions(-)
 create mode 100644 colossalai/fx/passes/split_module.py
 create mode 100644 tests/test_fx/test_pipeline/test_hf_model/hf_utils.py
 create mode 100644 tests/test_fx/test_pipeline/test_hf_model/test_albert.py
 create mode 100644 tests/test_fx/test_pipeline/test_hf_model/test_bert.py
 create mode 100644 tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
 create mode 100644 tests/test_fx/test_pipeline/test_hf_model/test_opt.py
 create mode 100644 tests/test_fx/test_pipeline/test_hf_model/test_t5.py
 create mode 100644 tests/test_fx/test_pipeline/test_timm_model/test_timm.py
 create mode 100644 tests/test_fx/test_pipeline/test_timm_model/timm_utils.py
 create mode 100644 tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py

diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index 91005fe6b..9c77590ff 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -2,7 +2,7 @@
 import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
-from torch.fx.passes.split_module import split_module
+from colossalai.fx.passes.split_module import split_module
 
 
 def pipe_split():
@@ -26,8 +26,14 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
             if accumulate_param_amount >= params_per_partition:
                 accumulate_param_amount = 0
                 pp_size -= 1
-                with mod_graph.inserting_after(node):
-                    split_node = mod_graph.create_node('call_function', pipe_split)
+                # If the next node is the output node, insert the split annotation before
+                # this node to make sure there is at least one node in the last partition.
+                if node.next.op == 'output':
+                    with mod_graph.inserting_before(node):
+                        split_node = mod_graph.create_node('call_function', pipe_split)
+                else:
+                    with mod_graph.inserting_after(node):
+                        split_node = mod_graph.create_node('call_function', pipe_split)
 
     gm.recompile()
     return gm
diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py
new file mode 100644
index 000000000..7b5cbb3cd
--- /dev/null
+++ b/colossalai/fx/passes/split_module.py
@@ -0,0 +1,277 @@
+import torch
+from torch.fx.graph_module import GraphModule
+from typing import Callable, List, Dict, Any, Optional
+from torch.fx._compatibility import compatibility
+import inspect
+
+
+@compatibility(is_backward_compatible=True)
+class Partition:
+    """
+    Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
+    """
+
+    def __init__(self, name: str):
+        self.name: str = name
+        self.node_names: List[str] = []
+        self.inputs: Dict[str, None] = {}
+        self.outputs: Dict[str, None] = {}
+        self.partitions_dependent_on: Dict[str, None] = {}
+        self.partition_dependents: Dict[str, None] = {}
+        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
+        self.environment: Dict[torch.fx.node.Node, torch.fx.node.Node] = {}
+        self.targets: Dict[str, Any] = {}
+
+    def __repr__(self) -> str:
+        return f"name: {self.name},\n" \
+            f" nodes: {self.node_names},\n" \
+            f" inputs: {self.inputs},\n" \
+            f" outputs: {self.outputs},\n" \
+            f" partitions dependent on: {self.partitions_dependent_on},\n" \
+            f" partition dependents: {self.partition_dependents}"
+
+
+# Creates subgraphs out of main graph
+@compatibility(is_backward_compatible=True)
+def split_module(
+    m: GraphModule,
+    root_m: torch.nn.Module,
+    split_callback: Callable[[torch.fx.node.Node], int],
+):
+    """
+    Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
+    Creates subgraphs out of main graph
+
+    Args:
+        m (GraphModule): Graph module to split
+        root_m (torch.nn.Module): root nn module. Not currently used. Included
+            because the root nn module is usually transformed via
+            torch.fx._symbolic_trace.symbolic_trace (see example below)
+        split_callback (Callable[[torch.fx.node.Node], int]): Callable function
+            that maps a given Node instance to a numeric partition identifier.
+            split_module will use this function as the policy for which operations
+            appear in which partitions in the output Module.
+
+    Returns:
+        GraphModule: the module after split.
+
+    Example:
+
+        This is a sample setup:
+
+            import torch
+            from torch.fx.symbolic_trace import symbolic_trace
+            from torch.fx.graph_module import GraphModule
+            from torch.fx.node import Node
+            from colossalai.fx.passes.split_module import split_module
+
+            class MyModule(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.param = torch.nn.Parameter(torch.rand(3, 4))
+                    self.linear = torch.nn.Linear(4, 5)
+
+                def forward(self, x, y):
+                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
+                    w = self.linear(y).clamp(min=0.0, max=1.0)
+                    return z + w
+
+            # symbolically trace model
+            my_module = MyModule()
+            my_module_traced = symbolic_trace(my_module)
+
+            # random mod partitioning
+            partition_counter = 0
+            NPARTITIONS = 3
+
+            def mod_partition(node: Node):
+                global partition_counter
+                partition = partition_counter % NPARTITIONS
+                partition_counter = (partition_counter + 1) % NPARTITIONS
+                return partition
+
+            # split module in module with submodules
+            module_with_submodules = split_module(
+                my_module_traced, my_module, mod_partition
+            )
+
+        Output looks like this. Original graph is broken into partitions
+
+            > print(module_with_submodules)
+            GraphModule(
+              (submod_0): GraphModule(
+                (linear): Linear(in_features=4, out_features=5, bias=True)
+              )
+              (submod_1): GraphModule(
+                (linear): Linear(in_features=4, out_features=5, bias=True)
+              )
+              (submod_2): GraphModule()
+            )
+
+            def forward(self, x, y):
+                param = self.param
+                submod_0 = self.submod_0(x, param, y); x = param = y = None
+                getitem = submod_0[0]
+                getitem_1 = submod_0[1]; submod_0 = None
+                submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
+                getitem_2 = submod_1[0]
+                getitem_3 = submod_1[1]; submod_1 = None
+                submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
+                return submod_2
+
+        Output of split module is the same as output of input traced module.
+        This is an example within a test setting:
+
+            > orig_out = my_module_traced(x, y)
+            > submodules_out = module_with_submodules(x, y)
+            > self.assertEqual(orig_out, submodules_out)
+            True
+    """
+    partitions: Dict[str, Partition] = {}
+    orig_nodes: Dict[str, torch.fx.node.Node] = {}
+
+    def record_cross_partition_use(def_node: torch.fx.node.Node,
+                                   use_node: Optional[torch.fx.node.Node]):    # noqa: B950
+        def_partition_name = getattr(def_node, '_fx_partition', None)
+        use_partition_name = getattr(use_node, '_fx_partition', None)
+        if def_partition_name != use_partition_name:
+            if def_partition_name is not None:
+                def_partition = partitions[def_partition_name]
+                def_partition.outputs.setdefault(def_node.name)
+                if use_partition_name is not None:
+                    def_partition.partition_dependents.setdefault(use_partition_name)
+
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.inputs.setdefault(def_node.name)
+                if def_partition_name is not None:
+                    use_partition.partitions_dependent_on.setdefault(def_partition_name)
+
+    # split nodes into partitions
+    for node in m.graph.nodes:
+        orig_nodes[node.name] = node
+
+        if node.op in ["placeholder"]:
+            continue
+        if node.op == 'output':
+            torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
+            continue
+        partition_name = str(split_callback(node))
+
+        # add node to partitions
+        partition = partitions.get(partition_name)
+        if partition is None:
+            partitions[partition_name] = partition = Partition(partition_name)
+
+        partition.node_names.append(node.name)
+        node._fx_partition = partition_name
+
+        torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
+        torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node))    # noqa: B950
+
+    # find partitions with no dependencies
+    root_partitions: List[str] = []
+    for partition_name, partition in partitions.items():
+        if not len(partition.partitions_dependent_on):
+            root_partitions.append(partition_name)
+
+    # check partitions for circular dependencies and create topological partition ordering
+    sorted_partitions: List[str] = []
+    while root_partitions:
+        root_partition = root_partitions.pop()
+        sorted_partitions.append(root_partition)
+        for dependent in partitions[root_partition].partition_dependents:
+            partitions[dependent].partitions_dependent_on.pop(root_partition)
+            if not partitions[dependent].partitions_dependent_on:
+                root_partitions.append(dependent)
+    if len(sorted_partitions) != len(partitions):
+        raise RuntimeError("cycle exists between partitions!")
+
+    # add placeholders to partitions
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+        for input in partition.inputs:
+            placeholder = partition.graph.placeholder(input)
+            placeholder.meta = orig_nodes[input].meta.copy()
+            partition.environment[orig_nodes[input]] = placeholder
+
+    # Transform nodes and collect targets for partition's submodule
+    for node in m.graph.nodes:
+        if hasattr(node, '_fx_partition'):
+            partition = partitions[node._fx_partition]
+
+            # swap out old graph nodes in kw/args with references to new nodes in this submodule
+            environment = partition.environment
+            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
+            gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
+
+            if node.op not in ['call_module', 'get_attr']:
+                target = node.target
+            else:
+                target_atoms = node.target.split('.')
+                target_attr = m
+                for atom in target_atoms:
+                    if not hasattr(target_attr, atom):
+                        raise RuntimeError(f'Operator target {node.target} not found!')
+                    target_attr = getattr(target_attr, atom)
+                # target = target_atoms[-1]
+                target = '_'.join(target_atoms)
+                partition.targets[target] = target_attr
+
+            assert isinstance(gathered_args, tuple)
+            assert isinstance(gathered_kwargs, dict)
+            new_node = partition.graph.create_node(op=node.op,
+                                                   target=target,
+                                                   args=gathered_args,
+                                                   kwargs=gathered_kwargs)
+            new_node.meta = node.meta.copy()
+            partition.environment[node] = new_node
+
+    # Set up values to construct base module
+    base_mod_env: Dict[str, torch.fx.node.Node] = {}
+    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
+    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
+    for node in m.graph.nodes:
+        if node.op == 'placeholder':
+            default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
+            base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
+                                                                 type_expr=node.type,
+                                                                 default_value=default_value)
+            base_mod_env[node.name].meta = node.meta.copy()
+
+    # Do some things iterating over the partitions in topological order again:
+    # 1) Finish off submodule Graphs by setting corresponding outputs
+    # 2) Construct GraphModules for each submodule
+    # 3) Construct the base graph by emitting calls to those submodules in
+    #    topological order
+
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+
+        # Set correct output values
+        output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
+        output_vals = output_vals[0] if len(output_vals) == 1 else output_vals    # type: ignore[assignment]
+        partition.graph.output(output_vals)
+
+        # Construct GraphModule for this partition
+        submod_name = f'submod_{partition_name}'
+        base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
+                                                                        partition.graph)    # noqa: B950
+
+        # Emit call in base graph to this submodule
+        output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
+        if len(partition.outputs) > 1:
+            # Unpack multiple return values from submodule
+            output_val_proxy = torch.fx.proxy.Proxy(output_val)
+            for i, output_name in enumerate(partition.outputs):
+                base_mod_env[output_name] = output_val_proxy[i].node    # type: ignore[index]
+        else:
+            if not partition.outputs:
+                continue
+            base_mod_env[list(partition.outputs)[0]] = output_val
+
+    for node in m.graph.nodes:
+        if node.op == 'output':
+            base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]))    # noqa: B950
+
+    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
diff --git a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py
new file mode 100644
index 000000000..3afc6c97e
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py
@@ -0,0 +1,69 @@
+import torch
+from torch.fx import symbolic_trace
+from torch.fx import GraphModule
+from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
+from colossalai.fx import ColoTracer
+import inspect
+import random
+import numpy as np
+
+MANUAL_SEED = 0
+random.seed(MANUAL_SEED)
+np.random.seed(MANUAL_SEED)
+torch.manual_seed(MANUAL_SEED)
+
+
+def split_model_and_compare_output(model, data_gen):
+    model.eval()
+
+    # generate input sample
+    kwargs = data_gen()
+
+    # get origin output and rng state
+    cpu_rng_state = torch.get_rng_state()
+    output = model(**kwargs)
+
+    # tracing model
+    tracer = ColoTracer()
+    try:
+        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        graph = tracer.trace(root=model, meta_args=meta_args)
+    except Exception as e:
+        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    # apply transform passes
+    annotated_model = balanced_split_pass(gm, 2)
+    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
+
+    # get split model
+    model_part0 = list(split_model.children())[0]
+    model_part1 = list(split_model.children())[1]
+
+    # set rng state and compute output of split model
+    torch.set_rng_state(cpu_rng_state)
+    output_part0 = model_part0(**kwargs)
+    sig = inspect.signature(model_part1.forward)
+    if isinstance(output_part0, torch.Tensor):
+        output_part1 = model_part1(output_part0)
+    else:
+        if len(output_part0) > len(sig.parameters):
+            output_part0 = output_part0[:len(sig.parameters)]
+        output_part1 = model_part1(*output_part0)
+
+    # get output tensor from the HF output data structure
+    if 'logits' in output:
+        output_to_compare = output['logits']
+    elif 'prediction_logits' in output:
+        output_to_compare = output['prediction_logits']
+    else:
+        output_to_compare = output['last_hidden_state']
+
+    # compare output
+    if isinstance(output_part1, torch.Tensor):
+        assert output_to_compare.equal(output_part1)
+    elif isinstance(output_part1, (tuple, list)):
+        assert output_to_compare.equal(output_part1[0])
+    else:
+        assert False
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py
new file mode 100644
index 000000000..8349ff52b
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py
@@ -0,0 +1,38 @@
+import transformers
+import torch
+from hf_utils import split_model_and_compare_output
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def test_single_sentence_albert():
+    MODEL_LIST = [
+        transformers.AlbertModel,
+        transformers.AlbertForPreTraining,
+        transformers.AlbertForMaskedLM,
+        transformers.AlbertForSequenceClassification,
+        transformers.AlbertForTokenClassification,
+    ]
+
+    config = transformers.AlbertConfig(vocab_size=100,
+                                       embedding_size=128,
+                                       hidden_size=128,
+                                       num_hidden_layers=2,
+                                       num_attention_heads=4,
+                                       intermediate_size=256)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        return meta_args
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        split_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_single_sentence_albert()
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py
new file mode 100644
index 000000000..36fbfcfb3
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py
@@ -0,0 +1,38 @@
+import transformers
+import torch
+from hf_utils import split_model_and_compare_output
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def test_single_sentence_bert():
+    MODEL_LIST = [
+        transformers.BertModel,
+        transformers.BertForPreTraining,
+        transformers.BertLMHeadModel,
+        transformers.BertForMaskedLM,
+        transformers.BertForSequenceClassification,
+        transformers.BertForTokenClassification,
+    ]
+
+    config = transformers.BertConfig(vocab_size=100,
+                                     hidden_size=128,
+                                     num_hidden_layers=4,
+                                     num_attention_heads=4,
+                                     intermediate_size=256)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        return meta_args
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        split_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_single_sentence_bert()
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
new file mode 100644
index 000000000..4a6636f49
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
@@ -0,0 +1,34 @@
+import transformers
+import torch
+from hf_utils import split_model_and_compare_output
+
+BATCH_SIZE = 64
+SEQ_LENGTH = 16
+NUM_EPOCHS = 2
+NUM_CHUNKS = 1
+
+
+def test_gpt():
+    MODEL_LIST = [
+        transformers.GPT2Model,
+        transformers.GPT2LMHeadModel,
+        transformers.GPT2DoubleHeadsModel,
+        transformers.GPT2ForTokenClassification,
+        # transformers.GPT2ForSequenceClassification, # not supported yet
+    ]
+    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        split_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_gpt()
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
new file mode 100644
index 000000000..a55ea54fe
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
@@ -0,0 +1,30 @@
+import pytest
+import transformers
+import torch
+from hf_utils import split_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+def test_opt():
+    MODEL_LIST = [
+        transformers.OPTModel,
+        transformers.OPTForCausalLM,
+    ]
+
+    config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        split_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_opt()
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py
new file mode 100644
index 000000000..d78883c3d
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py
@@ -0,0 +1,43 @@
+import pytest
+import transformers
+import torch
+from hf_utils import split_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+@pytest.mark.skip('tracing failed')
+def test_t5():
+    MODEL_LIST = [
+        transformers.T5Model,
+        transformers.T5ForConditionalGeneration,
+        transformers.T5EncoderModel,
+    ]
+
+    config = transformers.T5Config(d_model=128, num_layers=2)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        return kwargs
+
+    def data_gen_for_encoder_only():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+
+        if isinstance(model, transformers.T5EncoderModel):
+            data_gen_func = data_gen_for_encoder_only
+        else:
+            data_gen_func = data_gen
+
+        split_model_and_compare_output(model, data_gen_func)
+
+
+if __name__ == '__main__':
+    test_t5()
diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
new file mode 100644
index 000000000..bf11cb30a
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
@@ -0,0 +1,51 @@
+import torch
+import pytest
+try:
+    import timm.models as tm
+except:
+    pass
+from timm_utils import split_model_and_compare_output
+
+
+@pytest.mark.skip('skip as timm is required')
+def test_timm_models_without_control_flow():
+
+    MODEL_LIST = [
+        tm.resnest.resnest50d,
+        tm.beit.beit_base_patch16_224,
+        tm.cait.cait_s24_224,
+        tm.convmixer.convmixer_768_32,
+        tm.efficientnet.efficientnetv2_m,
+        tm.resmlp_12_224,
+        tm.vision_transformer.vit_base_patch16_224,
+        tm.deit_base_distilled_patch16_224,
+    ]
+
+    data = torch.rand(2, 3, 224, 224)
+
+    for model_cls in MODEL_LIST:
+        model = model_cls()
+        split_model_and_compare_output(model, data)
+
+
+@pytest.mark.skip('skip as timm is required')
+def test_timm_models_with_control_flow():
+    torch.backends.cudnn.deterministic = True
+
+    MODEL_LIST_WITH_CONTROL_FLOW = [
+        tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
+        tm.swin_transformer.swin_base_patch4_window7_224
+    ]
+
+    data = torch.rand(2, 3, 224, 224)
+
+    meta_args = {'x': data.to('meta')}
+
+    for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
+        model = model_cls()
+        split_model_and_compare_output(model, data, meta_args)
+
+
+if __name__ == '__main__':
+    test_timm_models_without_control_flow()
+    test_timm_models_with_control_flow()
diff --git a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py
new file mode 100644
index 000000000..aa870e5c7
--- /dev/null
+++ b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py
@@ -0,0 +1,51 @@
+import torch
+from torch.fx import symbolic_trace
+from torch.fx import GraphModule
+from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
+from colossalai.fx import ColoTracer
+import inspect
+import random
+import numpy as np
+
+MANUAL_SEED = 0
+random.seed(MANUAL_SEED)
+np.random.seed(MANUAL_SEED)
+torch.manual_seed(MANUAL_SEED)
+torch.backends.cudnn.deterministic = True
+
+
+def split_model_and_compare_output(model, data, meta_args=None):
+    model.eval()
+
+    # get origin output and rng state
+    cpu_rng_state = torch.get_rng_state()
+    output = model(data)
+
+    # tracing model
+    tracer = ColoTracer()
+    try:
+        graph = tracer.trace(root=model, meta_args=meta_args)
+    except Exception as e:
+        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
{e}") + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + + # get split model + model_part0 = list(split_model.children())[0] + model_part1 = list(split_model.children())[1] + + # set rng state and compute output of split model + torch.set_rng_state(cpu_rng_state) + output_part0 = model_part0(data) + sig = inspect.signature(model_part1.forward) + if isinstance(output_part0, torch.Tensor): + output_part1 = model_part1(output_part0) + else: + if len(output_part0) > len(sig.parameters): + output_part0 = output_part0[:len(sig.parameters)] + output_part1 = model_part1(*output_part0) + assert output.equal(output_part1) diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py new file mode 100644 index 000000000..dab485063 --- /dev/null +++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py @@ -0,0 +1,62 @@ +import torch +try: + import torchvision.models as tm +except: + pass +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass +from torch.fx import GraphModule + +import random +import numpy as np +import inspect + +MANUAL_SEED = 0 +random.seed(MANUAL_SEED) +np.random.seed(MANUAL_SEED) +torch.manual_seed(MANUAL_SEED) +torch.backends.cudnn.deterministic = True + + +@pytest.mark.skip('skip as torchvision is required') +def test_torchvision_models(): + MODEL_LIST = [ + tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, + tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5 + ] + + tracer = ColoTracer() + data = torch.rand(2, 3, 224, 224) + + for model_cls in MODEL_LIST: + model = model_cls() + model.eval() + cpu_rng_state = torch.get_rng_state() + output = model(data) + graph = tracer.trace(root=model) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + # apply transform passes + annotated_model = balanced_split_pass(gm, 2) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + + # get split model + model_part0 = list(split_model.children())[0] + model_part1 = list(split_model.children())[1] + + # set rng state and compute output of split model + torch.set_rng_state(cpu_rng_state) + output_part0 = model_part0(data) + sig = inspect.signature(model_part1.forward) + if isinstance(output_part0, torch.Tensor): + output_part1 = model_part1(output_part0) + else: + if len(output_part0) > len(sig.parameters): + output_part0 = output_part0[:len(sig.parameters)] + output_part1 = model_part1(*output_part0) + assert output.equal(output_part1) + + +if __name__ == '__main__': + test_torchvision_models()