[fx] add a symbolic_trace api. (#1812)

* [fx] add a symbolic_trace api.

* [fx] fix import errors.
Author: Super Daniel
Date: 2022-11-08 13:59:20 +08:00
Committed by: GitHub
Parent: 350ccc0481
Commit: 441d584e4a

15 changed files with 90 additions and 73 deletions
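
The new symbolic_trace call folds the previous three-step workflow (construct a ColoTracer, trace the model to a Graph, wrap it in a recompiled GraphModule) into a single function. Below is a minimal sketch of such a wrapper, inferred from how the call sites in this diff change; the exact signature and defaults in colossalai.fx may differ.

import torch
from torch.fx import GraphModule

from colossalai.fx import ColoTracer


def symbolic_trace(root: torch.nn.Module, concrete_args=None, meta_args=None) -> GraphModule:
    # Sketch only: argument names and defaults are assumptions based on the
    # call sites changed in this PR, not the exact colossalai.fx implementation.
    tracer = ColoTracer()
    graph = tracer.trace(root, concrete_args=concrete_args, meta_args=meta_args)
    gm = GraphModule(root, graph, root.__class__.__name__)
    gm.recompile()
    return gm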

View File

@@ -3,24 +3,19 @@ from numpy import isin
-from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten
-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace
 def trace_model_and_compare_output(model, data_gen):
     # must turn on eval mode to ensure the output is consistent
     model.eval()
     # make sure that the model is traceable
-    tracer = ColoTracer()
     try:
         kwargs = data_gen()
         meta_args = {k: v.to('meta') for k, v in kwargs.items()}
-        graph = tracer.trace(root=model, meta_args=meta_args)
+        gm = symbolic_trace(model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
     # run forward
     inputs = data_gen()
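
The helper above shows the pattern repeated throughout this PR: the manual trace-and-rebuild boilerplate collapses into a single call. Side by side, with an arbitrary HuggingFace model standing in for the module under test (the model and meta inputs here are illustrative, not taken from the test suite):

import torch
import transformers
from torch.fx import GraphModule
from colossalai.fx import ColoTracer, symbolic_trace

# illustrative model and meta inputs, not copied from the tests
model = transformers.BertModel(transformers.BertConfig())
meta_args = {'input_ids': torch.zeros((2, 16), dtype=torch.int64).to('meta')}

# before: manual tracer, GraphModule construction and recompile
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
gm_old = GraphModule(model, graph, model.__class__.__name__)
gm_old.recompile()

# after: a single call returns a ready-to-run GraphModule
gm_new = symbolic_trace(model, meta_args=meta_args)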

View File

@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 BATCH_SIZE = 2
 SEQ_LENGTH = 16
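
The remaining HuggingFace test files only swap this import: the shared helper moves from utils to the new hf_tracer_utils module and is driven exactly as before. Roughly, with an illustrative model and input generator that are not copied from the actual tests:

import torch
import transformers
from hf_tracer_utils import trace_model_and_compare_output

BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen():
    # build a dict of keyword inputs; the helper moves them to the meta
    # device for tracing and calls data_gen again for the real forward check
    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    attention_mask = torch.ones((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


model = transformers.BertModel(transformers.BertConfig())
trace_model_and_compare_output(model, data_gen)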

View File

@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 BATCH_SIZE = 2
 SEQ_LENGTH = 16

View File

@@ -1,10 +1,9 @@
 import pytest
 import torch
-from torch.fx import GraphModule
-from utils import trace_model_and_compare_output
 import transformers
-from colossalai.fx import ColoTracer
+from hf_tracer_utils import trace_model_and_compare_output
+from colossalai.fx import symbolic_trace
 try:
     import diffusers
@@ -32,11 +31,7 @@ def test_vae():
     model = model_cls()
     sample = torch.zeros(LATENTS_SHAPE)
-    tracer = ColoTracer()
-    graph = tracer.trace(root=model)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
+    gm = symbolic_trace(model)
     model.eval()
     gm.eval()
@@ -98,11 +93,7 @@ def test_unet():
     model = model_cls()
     sample = torch.zeros(LATENTS_SHAPE)
-    tracer = ColoTracer()
-    graph = tracer.trace(root=model)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
+    gm = symbolic_trace(model)
     model.eval()
     gm.eval()
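
In the diffusers tests, symbolic_trace is called without meta_args, so the VAE and UNet models are traced concretely and the traced module is compared against eager execution in eval mode. A condensed sketch of that pattern; the latent shape, model choice, and output comparison are simplified stand-ins for what the parametrized tests actually do:

import torch
from diffusers import AutoencoderKL
from colossalai.fx import symbolic_trace

LATENTS_SHAPE = (1, 3, 32, 32)  # illustrative shape, not the value used in the tests

model = AutoencoderKL()
sample = torch.zeros(LATENTS_SHAPE)

gm = symbolic_trace(model)  # no meta_args: plain concrete tracing

model.eval()
gm.eval()

with torch.no_grad():
    fx_out = gm(sample)
    non_fx_out = model(sample)
# the real tests flatten both outputs and assert element-wise closeness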

View File

@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 BATCH_SIZE = 1
 SEQ_LENGTH = 16

View File

@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 BATCH_SIZE = 1
 SEQ_LENGTH = 16

View File

@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 BATCH_SIZE = 1
 SEQ_LENGTH = 16