[fx] add split module pass and unit test from pipeline passes (#1242)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] add split module pass and unit test from pipeline passes

* fix MNASNet bug

* polish
Author: YuliangLiu0306
Date: 2022-07-12 13:45:01 +08:00 (committed by GitHub)
Parent: 762905da68
Commit: 30b4fc0eb0
11 changed files with 702 additions and 3 deletions
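For context, the new tests below all exercise the same pattern: trace a model with ColoTracer, apply balanced_split_pass to annotate split points, materialize the partitions with split_with_split_nodes_pass, and check that running the partitions in sequence reproduces the original output. A minimal sketch of that pattern, assuming a toy TinyMLP module that is not part of this commit:

import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass

class TinyMLP(nn.Module):
    """Toy two-layer model used only to illustrate the pass (not part of this commit)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = TinyMLP().eval()
data = torch.rand(4, 16)
output = model(data)

# trace the model into an FX graph and rebuild a GraphModule
graph = ColoTracer().trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# annotate the graph with split points for 2 balanced partitions,
# then materialize each partition as a child submodule
annotated = balanced_split_pass(gm, 2)
split_gm, split_submodules = split_with_split_nodes_pass(annotated)

# run the two partitions back to back; for this simple model the
# chained result should match the original forward pass
part0, part1 = list(split_gm.children())
assert part1(part0(data)).equal(output)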

@@ -0,0 +1,69 @@
import torch
from torch.fx import symbolic_trace
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import inspect
import random
import numpy as np
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
def split_model_and_compare_output(model, data_gen):
model.eval()
# generate input sample
kwargs = data_gen()
# get origin output and rng state
cpu_rng_state = torch.get_rng_state()
output = model(**kwargs)
# tracing model
tracer = ColoTracer()
try:
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# apply transform passes
annotated_model = balanced_split_pass(gm, 2)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
# get split model
model_part0 = list(split_model.children())[0]
model_part1 = list(split_model.children())[1]
# set rng state and compute output of split model
torch.set_rng_state(cpu_rng_state)
output_part0 = model_part0(**kwargs)
sig = inspect.signature(model_part1.forward)
if isinstance(output_part0, torch.Tensor):
output_part1 = model_part1(output_part0)
else:
if len(output_part0) > len(sig.parameters):
output_part0 = output_part0[:len(sig.parameters)]
output_part1 = model_part1(*output_part0)
# get output tensor from HFOutput datastructure
if 'logits' in output:
output_to_compare = output['logits']
elif 'prediction_logits' in output:
output_to_compare = output['prediction_logits']
else:
output_to_compare = output['last_hidden_state']
# compare output
if isinstance(output_part1, torch.Tensor):
assert output_to_compare.equal(output_part1)
elif isinstance(output_part1, (tuple, list)):
assert output_to_compare.equal(output_part1[0])
else:
assert False

@@ -0,0 +1,38 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGHT = 16
def test_single_sentence_albert():
MODEL_LIST = [
transformers.AlbertModel,
transformers.AlbertForPreTraining,
transformers.AlbertForMaskedLM,
transformers.AlbertForSequenceClassification,
transformers.AlbertForTokenClassification,
]
config = transformers.AlbertConfig(vocab_size=100,
embedding_size=128,
hidden_size=128,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=256)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return meta_args
for model_cls in MODEL_LIST:
model = model_cls(config=config)
split_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_single_sentence_albert()

@@ -0,0 +1,38 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGHT = 16
def test_single_sentence_bert():
MODEL_LIST = [
transformers.BertModel,
transformers.BertForPreTraining,
transformers.BertLMHeadModel,
transformers.BertForMaskedLM,
transformers.BertForSequenceClassification,
transformers.BertForTokenClassification,
]
config = transformers.BertConfig(vocab_size=100,
hidden_size=128,
num_hidden_layers=4,
num_attention_heads=4,
intermediate_size=256)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return meta_args
for model_cls in MODEL_LIST:
model = model_cls(config=config)
split_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_single_sentence_bert()

@@ -0,0 +1,34 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 64
SEQ_LENGHT = 16
NUM_EPOCHS = 2
NUM_CHUNKS = 1
def test_gpt():
MODEL_LIST = [
transformers.GPT2Model,
transformers.GPT2LMHeadModel,
transformers.GPT2DoubleHeadsModel,
transformers.GPT2ForTokenClassification,
# transformers.GPT2ForSequenceClassification, # not supported yet
]
    config = transformers.GPT2Config(n_positions=64, n_layer=4, n_head=8)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return kwargs
for model_cls in MODEL_LIST:
model = model_cls(config=config)
split_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_gpt()

@@ -0,0 +1,30 @@
import pytest
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGHT = 16
def test_opt():
MODEL_LIST = [
transformers.OPTModel,
transformers.OPTForCausalLM,
]
config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
for model_cls in MODEL_LIST:
model = model_cls(config=config)
split_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_opt()

@@ -0,0 +1,43 @@
import pytest
import transformers
import torch
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGHT = 16
@pytest.mark.skip('tracing failed')
def test_t5():
MODEL_LIST = [
transformers.T5Model,
transformers.T5ForConditionalGeneration,
transformers.T5EncoderModel,
]
config = transformers.T5Config(d_model=128, num_layers=2)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
return kwargs
def data_gen_for_encoder_only():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids)
return kwargs
for model_cls in MODEL_LIST:
model = model_cls(config=config)
if isinstance(model, transformers.T5EncoderModel):
data_gen_func = data_gen_for_encoder_only
else:
data_gen_func = data_gen
split_model_and_compare_output(model, data_gen_func)
if __name__ == '__main__':
test_t5()

@@ -0,0 +1,51 @@
import torch
import pytest
try:
import timm.models as tm
except ImportError:
pass
from timm_utils import split_model_and_compare_output
@pytest.mark.skip('skip as timm is required')
def test_timm_models_without_control_flow():
MODEL_LIST = [
tm.resnest.resnest50d,
tm.beit.beit_base_patch16_224,
tm.cait.cait_s24_224,
tm.convmixer.convmixer_768_32,
tm.efficientnet.efficientnetv2_m,
tm.resmlp_12_224,
tm.vision_transformer.vit_base_patch16_224,
tm.deit_base_distilled_patch16_224,
]
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
model = model_cls()
split_model_and_compare_output(model, data)
@pytest.mark.skip('skip as timm is required')
def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True
MODEL_LIST_WITH_CONTROL_FLOW = [
tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
tm.swin_transformer.swin_base_patch4_window7_224
]
data = torch.rand(2, 3, 224, 224)
meta_args = {'x': data.to('meta')}
for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
model = model_cls()
split_model_and_compare_output(model, data, meta_args)
if __name__ == '__main__':
test_timm_models_without_control_flow()
test_timm_models_with_control_flow()

@@ -0,0 +1,51 @@
import torch
from torch.fx import symbolic_trace
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import inspect
import random
import numpy as np
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True
def split_model_and_compare_output(model, data, meta_args=None):
model.eval()
# get origin output and rng state
cpu_rng_state = torch.get_rng_state()
output = model(data)
# tracing model
tracer = ColoTracer()
try:
graph = tracer.trace(root=model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# apply transform passes
annotated_model = balanced_split_pass(gm, 2)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
# get split model
model_part0 = list(split_model.children())[0]
model_part1 = list(split_model.children())[1]
# set rng state and compute output of split model
torch.set_rng_state(cpu_rng_state)
output_part0 = model_part0(data)
sig = inspect.signature(model_part1.forward)
if isinstance(output_part0, torch.Tensor):
output_part1 = model_part1(output_part0)
else:
if len(output_part0) > len(sig.parameters):
output_part0 = output_part0[:len(sig.parameters)]
output_part1 = model_part1(*output_part0)
assert output.equal(output_part1)

@@ -0,0 +1,62 @@
import pytest
import torch
try:
    import torchvision.models as tm
except ImportError:
    pass
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule
import random
import numpy as np
import inspect
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True
@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
]
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
model = model_cls()
model.eval()
cpu_rng_state = torch.get_rng_state()
output = model(data)
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# apply transform passes
annotated_model = balanced_split_pass(gm, 2)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
# get split model
model_part0 = list(split_model.children())[0]
model_part1 = list(split_model.children())[1]
# set rng state and compute output of split model
torch.set_rng_state(cpu_rng_state)
output_part0 = model_part0(data)
sig = inspect.signature(model_part1.forward)
if isinstance(output_part0, torch.Tensor):
output_part1 = model_part1(output_part0)
else:
if len(output_part0) > len(sig.parameters):
output_part0 = output_part0[:len(sig.parameters)]
output_part1 = model_part1(*output_part0)
assert output.equal(output_part1)
if __name__ == '__main__':
test_torchvision_models()