[fx] add split module pass and unit test from pipeline passes (#1242)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] add split module pass and unit test from pipeline passes

* fix MNASNet bug

* polish
This commit is contained in:
YuliangLiu0306
2022-07-12 13:45:01 +08:00
committed by GitHub
parent 762905da68
commit 30b4fc0eb0
11 changed files with 702 additions and 3 deletions

View File

@@ -0,0 +1,62 @@
import inspect
import random

import numpy as np
import pytest
import torch
from torch.fx import GraphModule

try:
    import torchvision.models as tm
except:
    pass

from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
# Seed Python's, NumPy's and PyTorch's random number generators with one fixed value.
MANUAL_SEED = 0
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(MANUAL_SEED)
# Turn on cuDNN's deterministic flag.
torch.backends.cudnn.deterministic = True
@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models():
    """Check that a two-way split of each torchvision model reproduces the unsplit output.

    Every constructor in ``MODEL_LIST`` is instantiated, put in eval mode and run
    once on a random batch to get a reference output. The model is then traced with
    ``ColoTracer``, wrapped in a ``GraphModule``, annotated by
    ``balanced_split_pass(gm, 2)`` and split by ``split_with_split_nodes_pass``.
    The first two children of the split module are chained by hand and their final
    output is compared with the reference using ``torch.Tensor.equal``.

    Raises:
        AssertionError: if the chained output differs from the reference output;
            the message names the model constructor that failed.
    """
    MODEL_LIST = [
        tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
        tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
    ]

    tracer = ColoTracer()
    data = torch.rand(2, 3, 224, 224)

    for model_cls in MODEL_LIST:
        model = model_cls()
        model.eval()
        # Snapshot the CPU RNG state so the split run below can start from the same state.
        cpu_rng_state = torch.get_rng_state()
        output = model(data)

        graph = tracer.trace(root=model)
        gm = GraphModule(model, graph, model.__class__.__name__)
        gm.recompile()

        # apply transform passes
        annotated_model = balanced_split_pass(gm, 2)
        # The second return value (the split submodules) is not needed here.
        split_model, _ = split_with_split_nodes_pass(annotated_model)

        # get split model; build the children list once and take the first two stages
        stages = list(split_model.children())
        model_part0, model_part1 = stages[0], stages[1]

        # set rng state and compute output of split model
        torch.set_rng_state(cpu_rng_state)
        output_part0 = model_part0(data)
        if isinstance(output_part0, torch.Tensor):
            output_part1 = model_part1(output_part0)
        else:
            # Stage 0 returned several values: forward only as many as stage 1's
            # ``forward`` declares parameters for. Slicing is a no-op when there
            # are no extras, so no explicit length check is required.
            num_inputs = len(inspect.signature(model_part1.forward).parameters)
            output_part1 = model_part1(*output_part0[:num_inputs])

        assert output.equal(output_part1), \
            f'split model output differs from the original output for {model_cls.__name__}'
# Script entry point: call the test function directly when this file is executed.
if __name__ == '__main__':
    test_torchvision_models()