mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 15:11:20 +00:00
[fx] add a symbolic_trace api. (#1812)
* [fx] add a symbolic_trace api. * [fx] fix import errors.
This commit is contained in:
@@ -3,24 +3,19 @@ from numpy import isin
|
||||
from torch.fx import GraphModule
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx import symbolic_trace
|
||||
|
||||
|
||||
def trace_model_and_compare_output(model, data_gen):
|
||||
# must turn on eval mode to ensure the output is consistent
|
||||
model.eval()
|
||||
|
||||
# make sure that the model is traceable
|
||||
tracer = ColoTracer()
|
||||
|
||||
try:
|
||||
kwargs = data_gen()
|
||||
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
# run forward
|
||||
inputs = data_gen()
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from utils import trace_model_and_compare_output
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
|
||||
BATCH_SIZE = 2
|
||||
SEQ_LENGTH = 16
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from utils import trace_model_and_compare_output
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
|
||||
BATCH_SIZE = 2
|
||||
SEQ_LENGTH = 16
|
||||
|
@@ -1,10 +1,9 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from utils import trace_model_and_compare_output
|
||||
|
||||
import transformers
|
||||
from colossalai.fx import ColoTracer
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
|
||||
from colossalai.fx import symbolic_trace
|
||||
|
||||
try:
|
||||
import diffusers
|
||||
@@ -32,11 +31,7 @@ def test_vae():
|
||||
model = model_cls()
|
||||
sample = torch.zeros(LATENTS_SHAPE)
|
||||
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(root=model)
|
||||
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
gm = symbolic_trace(model)
|
||||
|
||||
model.eval()
|
||||
gm.eval()
|
||||
@@ -98,11 +93,7 @@ def test_unet():
|
||||
model = model_cls()
|
||||
sample = torch.zeros(LATENTS_SHAPE)
|
||||
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(root=model)
|
||||
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
gm = symbolic_trace(model)
|
||||
|
||||
model.eval()
|
||||
gm.eval()
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from utils import trace_model_and_compare_output
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGTH = 16
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from utils import trace_model_and_compare_output
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGTH = 16
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
import transformers
|
||||
from utils import trace_model_and_compare_output
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGTH = 16
|
||||
|
Reference in New Issue
Block a user