Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 09:38:05 +00:00)
[fx]add split module pass and unit test from pipeline passes (#1242)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [fx]add split module pass and unit test from pipeline passes
* fix MNASNet bug
* polish
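
For context, here is a minimal sketch of how the new pass pair is exercised by the test utilities added below: trace a model with ColoTracer, annotate balanced split points, then split the GraphModule into pipeline stages. The toy nn.Sequential model and the 2-stage split count are illustrative assumptions; the tests in this commit apply the same calls to Hugging Face, timm, and torchvision models.

import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass

# toy model standing in for the HF/timm/torchvision models used in the tests (assumption)
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
model.eval()

# trace the model into an FX graph and rebuild it as a GraphModule
tracer = ColoTracer()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# annotate split points for 2 pipeline stages, then split into submodules
annotated = balanced_split_pass(gm, 2)
split_gm, split_submodules = split_with_split_nodes_pass(annotated)

# each child of the split module is one pipeline stage; chaining the stages
# should reproduce the original output, which is what the tests assert
stage0, stage1 = list(split_gm.children())
data = torch.rand(2, 8)
assert model(data).equal(stage1(stage0(data)))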
tests/test_fx/test_pipeline/test_hf_model/hf_utils.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import torch
from torch.fx import symbolic_trace
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import inspect
import random
import numpy as np

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)


def split_model_and_compare_output(model, data_gen):
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # get origin output and rng state
    cpu_rng_state = torch.get_rng_state()
    output = model(**kwargs)

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)

    # get split model
    model_part0 = list(split_model.children())[0]
    model_part1 = list(split_model.children())[1]

    # set rng state and compute output of split model
    torch.set_rng_state(cpu_rng_state)
    output_part0 = model_part0(**kwargs)
    sig = inspect.signature(model_part1.forward)
    if isinstance(output_part0, torch.Tensor):
        output_part1 = model_part1(output_part0)
    else:
        if len(output_part0) > len(sig.parameters):
            output_part0 = output_part0[:len(sig.parameters)]
        output_part1 = model_part1(*output_part0)

    # get output tensor from HFOutput datastructure
    if 'logits' in output:
        output_to_compare = output['logits']
    elif 'prediction_logits' in output:
        output_to_compare = output['prediction_logits']
    else:
        output_to_compare = output['last_hidden_state']

    # compare output
    if isinstance(output_part1, torch.Tensor):
        assert output_to_compare.equal(output_part1)
    elif isinstance(output_part1, (tuple, list)):
        assert output_to_compare.equal(output_part1[0])
    else:
        assert False
tests/test_fx/test_pipeline/test_hf_model/test_albert.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 2
SEQ_LENGHT = 16


def test_single_sentence_albert():
    MODEL_LIST = [
        transformers.AlbertModel,
        transformers.AlbertForPreTraining,
        transformers.AlbertForMaskedLM,
        transformers.AlbertForSequenceClassification,
        transformers.AlbertForTokenClassification,
    ]

    config = transformers.AlbertConfig(vocab_size=100,
                                       embedding_size=128,
                                       hidden_size=128,
                                       num_hidden_layers=2,
                                       num_attention_heads=4,
                                       intermediate_size=256)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return meta_args

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)


if __name__ == '__main__':
    test_single_sentence_albert()
tests/test_fx/test_pipeline/test_hf_model/test_bert.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 2
SEQ_LENGHT = 16


def test_single_sentence_bert():
    MODEL_LIST = [
        transformers.BertModel,
        transformers.BertForPreTraining,
        transformers.BertLMHeadModel,
        transformers.BertForMaskedLM,
        transformers.BertForSequenceClassification,
        transformers.BertForTokenClassification,
    ]

    config = transformers.BertConfig(vocab_size=100,
                                     hidden_size=128,
                                     num_hidden_layers=4,
                                     num_attention_heads=4,
                                     intermediate_size=256)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return meta_args

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)


if __name__ == '__main__':
    test_single_sentence_bert()
tests/test_fx/test_pipeline/test_hf_model/test_gpt.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import transformers
import torch
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 64
SEQ_LENGHT = 16
NUM_EPOCHS = 2
NUM_CHUNKS = 1


def test_gpt():
    MODEL_LIST = [
        transformers.GPT2Model,
        transformers.GPT2LMHeadModel,
        transformers.GPT2DoubleHeadsModel,
        transformers.GPT2ForTokenClassification,
        # transformers.GPT2ForSequenceClassification, # not supported yet
    ]
    config = transformers.GPT2Config(n_positions=64, n_layer=4, n_head=8)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)


if __name__ == '__main__':
    test_gpt()
tests/test_fx/test_pipeline/test_hf_model/test_opt.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import pytest
import transformers
import torch
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 1
SEQ_LENGHT = 16


def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
        transformers.OPTForCausalLM,
    ]

    config = transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        split_model_and_compare_output(model, data_gen)


if __name__ == '__main__':
    test_opt()
tests/test_fx/test_pipeline/test_hf_model/test_t5.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import pytest
import transformers
import torch
from hf_utils import split_model_and_compare_output

BATCH_SIZE = 1
SEQ_LENGHT = 16


@pytest.mark.skip('tracing failed')
def test_t5():
    MODEL_LIST = [
        transformers.T5Model,
        transformers.T5ForConditionalGeneration,
        transformers.T5EncoderModel,
    ]

    config = transformers.T5Config(d_model=128, num_layers=2)

    def data_gen():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        return kwargs

    def data_gen_for_encoder_only():
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)

        if isinstance(model, transformers.T5EncoderModel):
            data_gen_func = data_gen_for_encoder_only
        else:
            data_gen_func = data_gen

        split_model_and_compare_output(model, data_gen_func)


if __name__ == '__main__':
    test_t5()
tests/test_fx/test_pipeline/test_timm_model/test_timm.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import torch
import pytest
try:
    import timm.models as tm
except:
    pass
from timm_utils import split_model_and_compare_output


@pytest.mark.skip('skip as timm is required')
def test_timm_models_without_control_flow():

    MODEL_LIST = [
        tm.resnest.resnest50d,
        tm.beit.beit_base_patch16_224,
        tm.cait.cait_s24_224,
        tm.convmixer.convmixer_768_32,
        tm.efficientnet.efficientnetv2_m,
        tm.resmlp_12_224,
        tm.vision_transformer.vit_base_patch16_224,
        tm.deit_base_distilled_patch16_224,
    ]

    data = torch.rand(2, 3, 224, 224)

    for model_cls in MODEL_LIST:
        model = model_cls()
        split_model_and_compare_output(model, data)


@pytest.mark.skip('skip as timm is required')
def test_timm_models_with_control_flow():
    torch.backends.cudnn.deterministic = True

    MODEL_LIST_WITH_CONTROL_FLOW = [
        tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
        tm.swin_transformer.swin_base_patch4_window7_224
    ]

    data = torch.rand(2, 3, 224, 224)

    meta_args = {'x': data.to('meta')}

    for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
        model = model_cls()
        split_model_and_compare_output(model, data, meta_args)


if __name__ == '__main__':
    test_timm_models_without_control_flow()
    test_timm_models_with_control_flow()
tests/test_fx/test_pipeline/test_timm_model/timm_utils.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import torch
from torch.fx import symbolic_trace
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
import inspect
import random
import numpy as np

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True


def split_model_and_compare_output(model, data, meta_args=None):
    model.eval()

    # get origin output and rng state
    cpu_rng_state = torch.get_rng_state()
    output = model(data)

    # tracing model
    tracer = ColoTracer()
    try:
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, 2)
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)

    # get split model
    model_part0 = list(split_model.children())[0]
    model_part1 = list(split_model.children())[1]

    # set rng state and compute output of split model
    torch.set_rng_state(cpu_rng_state)
    output_part0 = model_part0(data)
    sig = inspect.signature(model_part1.forward)
    if isinstance(output_part0, torch.Tensor):
        output_part1 = model_part1(output_part0)
    else:
        if len(output_part0) > len(sig.parameters):
            output_part0 = output_part0[:len(sig.parameters)]
        output_part1 = model_part1(*output_part0)
    assert output.equal(output_part1)
@@ -0,0 +1,62 @@
import torch
import pytest
try:
    import torchvision.models as tm
except:
    pass
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule

import random
import numpy as np
import inspect

MANUAL_SEED = 0
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True


@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models():
    MODEL_LIST = [
        tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
        tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
    ]

    tracer = ColoTracer()
    data = torch.rand(2, 3, 224, 224)

    for model_cls in MODEL_LIST:
        model = model_cls()
        model.eval()
        cpu_rng_state = torch.get_rng_state()
        output = model(data)
        graph = tracer.trace(root=model)
        gm = GraphModule(model, graph, model.__class__.__name__)
        gm.recompile()

        # apply transform passes
        annotated_model = balanced_split_pass(gm, 2)
        split_model, split_submodules = split_with_split_nodes_pass(annotated_model)

        # get split model
        model_part0 = list(split_model.children())[0]
        model_part1 = list(split_model.children())[1]

        # set rng state and compute output of split model
        torch.set_rng_state(cpu_rng_state)
        output_part0 = model_part0(data)
        sig = inspect.signature(model_part1.forward)
        if isinstance(output_part0, torch.Tensor):
            output_part1 = model_part1(output_part0)
        else:
            if len(output_part0) > len(sig.parameters):
                output_part0 = output_part0[:len(sig.parameters)]
            output_part1 = model_part1(*output_part0)
        assert output.equal(output_part1)


if __name__ == '__main__':
    test_torchvision_models()