[fx] add a symbolic_trace api. (#1812)

* [fx] add a symbolic_trace api.

* [fx] fix import errors.
This commit is contained in:
Super Daniel
2022-11-08 13:59:20 +08:00
committed by GitHub
parent 350ccc0481
commit 441d584e4a
15 changed files with 90 additions and 73 deletions

View File

@@ -3,24 +3,19 @@ from numpy import isin
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten
from colossalai.fx import ColoTracer
from colossalai.fx import symbolic_trace
def trace_model_and_compare_output(model, data_gen):
# must turn on eval mode to ensure the output is consistent
model.eval()
# make sure that the model is traceable
tracer = ColoTracer()
try:
kwargs = data_gen()
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = symbolic_trace(model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# run forward
inputs = data_gen()

View File

@@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGTH = 16

View File

@@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGTH = 16

View File

@@ -1,10 +1,9 @@
import pytest
import torch
from torch.fx import GraphModule
from utils import trace_model_and_compare_output
import transformers
from colossalai.fx import ColoTracer
from hf_tracer_utils import trace_model_and_compare_output
from colossalai.fx import symbolic_trace
try:
import diffusers
@@ -32,11 +31,7 @@ def test_vae():
model = model_cls()
sample = torch.zeros(LATENTS_SHAPE)
tracer = ColoTracer()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()
@@ -98,11 +93,7 @@ def test_unet():
model = model_cls()
sample = torch.zeros(LATENTS_SHAPE)
tracer = ColoTracer()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()

View File

@@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGTH = 16

View File

@@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGTH = 16

View File

@@ -1,7 +1,7 @@
import pytest
import torch
import transformers
from utils import trace_model_and_compare_output
from hf_tracer_utils import trace_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGTH = 16

View File

@@ -1,12 +1,11 @@
import pytest
import timm.models as tm
import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx import symbolic_trace
def trace_and_compare(model_cls, tracer, data, meta_args=None):
def trace_and_compare(model_cls, data, meta_args=None):
# trace
model = model_cls()
@@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
model.eval()
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model, meta_args=meta_args)
# run forward
with torch.no_grad():
@@ -49,11 +46,10 @@ def test_timm_models_without_control_flow():
tm.deit_base_distilled_patch16_224,
]
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
trace_and_compare(model_cls, tracer, data)
trace_and_compare(model_cls, data)
def test_timm_models_with_control_flow():
@@ -64,13 +60,12 @@ def test_timm_models_with_control_flow():
tm.swin_transformer.swin_base_patch4_window7_224
]
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
meta_args = {'x': data.to('meta')}
for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
trace_and_compare(model_cls, tracer, data, meta_args)
trace_and_compare(model_cls, data, meta_args)
if __name__ == '__main__':

View File

@@ -1,20 +1,16 @@
import torch
from torch.fx import GraphModule, Tracer
from colossalai.fx import ColoTracer
from colossalai.fx import symbolic_trace
def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
data = data_gen()
concrete_args = data if need_concrete else {}
meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
tracer = ColoTracer()
model.eval()
graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args)
with torch.no_grad():
non_fx_out = model(**data)

View File

@@ -1,9 +1,7 @@
import pytest
import torch
from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx import symbolic_trace
try:
from torchrec.models import deepfm
@@ -14,8 +12,6 @@ try:
except ImportError:
NOT_TORCHREC = True
from torch.fx import GraphModule
BATCH = 2
SHAPE = 10
@@ -43,9 +39,6 @@ def test_torchrec_deepfm_models():
# Dense Features
features = torch.rand((BATCH, SHAPE))
# Tracer
tracer = ColoTracer()
for model_cls in MODEL_LIST:
# Initializing model
if model_cls == deepfm.DenseArch:
@@ -60,9 +53,7 @@ def test_torchrec_deepfm_models():
model = model_cls(ebc)
# Setup GraphModule
graph = tracer.trace(model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()

View File

@@ -1,6 +1,6 @@
import torch
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx import symbolic_trace
try:
from torchrec.models import dlrm
@@ -12,7 +12,6 @@ except ImportError:
NOT_TORCHREC = True
import pytest
from torch.fx import GraphModule
BATCH = 2
SHAPE = 10
@@ -51,8 +50,6 @@ def test_torchrec_dlrm_models():
# Sparse Features
sparse_features = torch.rand((BATCH, len(keys), SHAPE))
# Tracer
tracer = ColoTracer()
for model_cls in MODEL_LIST:
# Initializing model
@@ -77,12 +74,9 @@ def test_torchrec_dlrm_models():
# Setup GraphModule
if model_cls == dlrm.InteractionV2Arch:
concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
graph = tracer.trace(model, concrete_args=concrete_args)
gm = symbolic_trace(model, concrete_args=concrete_args)
else:
graph = tracer.trace(model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()

View File

@@ -2,8 +2,8 @@ import torch
import torchvision
import torchvision.models as tm
from packaging import version
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from colossalai.fx import symbolic_trace
def test_torchvision_models():
@@ -20,7 +20,6 @@ def test_torchvision_models():
torch.backends.cudnn.deterministic = True
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
@@ -30,10 +29,7 @@ def test_torchvision_models():
else:
model = model_cls()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
gm = symbolic_trace(model)
model.eval()
gm.eval()