Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-17 23:18:36 +00:00)
[fx] add a symbolic_trace api. (#1812)
* [fx] add a symbolic_trace api.
* [fx] fix import errors.
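The change is mechanical: each call site below previously instantiated a ColoTracer, traced a graph, wrapped it in a GraphModule, and called recompile(); all four steps now collapse into a single symbolic_trace call. A minimal sketch of such a wrapper, inferred from the pattern these hunks replace (the actual implementation in colossalai.fx may differ):

    import torch
    from torch.fx import GraphModule

    from colossalai.fx import ColoTracer


    def symbolic_trace(root: torch.nn.Module, concrete_args=None, meta_args=None) -> GraphModule:
        # Trace with ColoTracer: meta_args resolves shape-dependent control
        # flow against 'meta' tensors, while concrete_args bakes real
        # (or non-tensor) inputs into the graph.
        tracer = ColoTracer()
        graph = tracer.trace(root, concrete_args=concrete_args or {}, meta_args=meta_args or {})
        # Wrap the traced graph in a module and regenerate forward() code.
        gm = GraphModule(root, graph, root.__class__.__name__)
        gm.recompile()
        return gm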
@@ -3,24 +3,19 @@ from numpy import isin
-from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten
 
-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace
 
 
 def trace_model_and_compare_output(model, data_gen):
     # must turn on eval mode to ensure the output is consistent
     model.eval()
 
     # make sure that the model is traceable
-    tracer = ColoTracer()
-
     try:
         kwargs = data_gen()
         meta_args = {k: v.to('meta') for k, v in kwargs.items()}
-        graph = tracer.trace(root=model, meta_args=meta_args)
+        gm = symbolic_trace(model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
 
     # run forward
     inputs = data_gen()
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 2
 SEQ_LENGTH = 16
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 2
 SEQ_LENGTH = 16
@@ -1,10 +1,9 @@
 import pytest
 import torch
-from torch.fx import GraphModule
-from utils import trace_model_and_compare_output
-
 import transformers
-from colossalai.fx import ColoTracer
+from hf_tracer_utils import trace_model_and_compare_output
+
+from colossalai.fx import symbolic_trace
 
 try:
     import diffusers
@@ -32,11 +31,7 @@ def test_vae():
         model = model_cls()
         sample = torch.zeros(LATENTS_SHAPE)
 
-        tracer = ColoTracer()
-        graph = tracer.trace(root=model)
-
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()
@@ -98,11 +93,7 @@ def test_unet():
         model = model_cls()
         sample = torch.zeros(LATENTS_SHAPE)
 
-        tracer = ColoTracer()
-        graph = tracer.trace(root=model)
-
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
 SEQ_LENGTH = 16
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
 SEQ_LENGTH = 16
@@ -1,7 +1,7 @@
 import pytest
 import torch
 import transformers
-from utils import trace_model_and_compare_output
+from hf_tracer_utils import trace_model_and_compare_output
 
 BATCH_SIZE = 1
 SEQ_LENGTH = 16
@@ -1,12 +1,11 @@
 import pytest
 import timm.models as tm
 import torch
-from torch.fx import GraphModule
 
-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace
 
 
-def trace_and_compare(model_cls, tracer, data, meta_args=None):
+def trace_and_compare(model_cls, data, meta_args=None):
     # trace
     model = model_cls()
 
@@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
     # without this statement, the torch.nn.functional.batch_norm will always be in training mode
     model.eval()
 
-    graph = tracer.trace(root=model, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
+    gm = symbolic_trace(model, meta_args=meta_args)
 
     # run forward
     with torch.no_grad():
@@ -49,11 +46,10 @@ def test_timm_models_without_control_flow():
         tm.deit_base_distilled_patch16_224,
     ]
 
-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)
 
    for model_cls in MODEL_LIST:
-        trace_and_compare(model_cls, tracer, data)
+        trace_and_compare(model_cls, data)
 
 
 def test_timm_models_with_control_flow():
@@ -64,13 +60,12 @@ def test_timm_models_with_control_flow():
         tm.swin_transformer.swin_base_patch4_window7_224
     ]
 
-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)
 
     meta_args = {'x': data.to('meta')}
 
     for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
-        trace_and_compare(model_cls, tracer, data, meta_args)
+        trace_and_compare(model_cls, data, meta_args)
 
 
 if __name__ == '__main__':
@@ -1,20 +1,16 @@
 import torch
-from torch.fx import GraphModule, Tracer
 
-from colossalai.fx import ColoTracer
+from colossalai.fx import symbolic_trace
 
 
 def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
     data = data_gen()
     concrete_args = data if need_concrete else {}
     meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
-    tracer = ColoTracer()
 
     model.eval()
 
-    graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
+    gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args)
 
     with torch.no_grad():
         non_fx_out = model(**data)
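The torchaudio helper above routes all three tracing modes through the one entry point. A hypothetical usage sketch of the resulting API (the Linear stand-in and the 'input' argument name are illustrative, not from the commit):

    import torch

    from colossalai.fx import symbolic_trace

    model = torch.nn.Linear(16, 4)    # stand-in module for illustration
    data = torch.rand(2, 16)

    # plain symbolic tracing
    gm = symbolic_trace(model)

    # the traced module should match the eager module, as these tests assert
    with torch.no_grad():
        assert torch.allclose(gm(data), model(data))

    # meta tracing: keys name forward() parameters, values live on 'meta',
    # as in the timm tests for models with control flow
    gm_meta = symbolic_trace(model, meta_args={'input': data.to('meta')})

    # concrete tracing: real inputs specialize the graph, as when
    # need_concrete is set above
    gm_concrete = symbolic_trace(model, concrete_args={'input': data})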
@@ -1,9 +1,7 @@
 import pytest
 import torch
 
-from colossalai.fx.tracer import meta_patch
-from colossalai.fx.tracer.meta_patch.patched_function import python_ops
-from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.fx import symbolic_trace
 
 try:
     from torchrec.models import deepfm
@@ -14,8 +12,6 @@ try:
 except ImportError:
     NOT_TORCHREC = True
 
-from torch.fx import GraphModule
-
 BATCH = 2
 SHAPE = 10
@@ -43,9 +39,6 @@ def test_torchrec_deepfm_models():
     # Dense Features
     features = torch.rand((BATCH, SHAPE))
 
-    # Tracer
-    tracer = ColoTracer()
-
     for model_cls in MODEL_LIST:
         # Initializing model
         if model_cls == deepfm.DenseArch:
@@ -60,9 +53,7 @@ def test_torchrec_deepfm_models():
             model = model_cls(ebc)
 
         # Setup GraphModule
-        graph = tracer.trace(model)
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()
@@ -1,6 +1,6 @@
 import torch
 
-from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.fx import symbolic_trace
 
 try:
     from torchrec.models import dlrm
@@ -12,7 +12,6 @@ except ImportError:
     NOT_TORCHREC = True
 
 import pytest
-from torch.fx import GraphModule
 
 BATCH = 2
 SHAPE = 10
@@ -51,8 +50,6 @@ def test_torchrec_dlrm_models():
 
     # Sparse Features
     sparse_features = torch.rand((BATCH, len(keys), SHAPE))
-    # Tracer
-    tracer = ColoTracer()
 
     for model_cls in MODEL_LIST:
         # Initializing model
@@ -77,12 +74,9 @@ def test_torchrec_dlrm_models():
         # Setup GraphModule
         if model_cls == dlrm.InteractionV2Arch:
             concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
-            graph = tracer.trace(model, concrete_args=concrete_args)
+            gm = symbolic_trace(model, concrete_args=concrete_args)
         else:
-            graph = tracer.trace(model)
-
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+            gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()
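Worth noting in the hunk above: InteractionV2Arch is the only torchrec model traced with concrete_args. The commit does not say why, but a plausible reading is that its forward cannot be handled purely symbolically, so the real dense_features and sparse_features tensors are passed to specialize the graph; every other model falls through to the plain symbolic_trace(model) path.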
@@ -2,8 +2,8 @@ import torch
 import torchvision
 import torchvision.models as tm
 from packaging import version
-from colossalai.fx import ColoTracer
-from torch.fx import GraphModule
+
+from colossalai.fx import symbolic_trace
 
 
 def test_torchvision_models():
@@ -20,7 +20,6 @@ def test_torchvision_models():
 
     torch.backends.cudnn.deterministic = True
 
-    tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)
 
     for model_cls in MODEL_LIST:
@@ -30,10 +29,7 @@ def test_torchvision_models():
         else:
             model = model_cls()
 
-        graph = tracer.trace(root=model)
-
-        gm = GraphModule(model, graph, model.__class__.__name__)
-        gm.recompile()
+        gm = symbolic_trace(model)
 
         model.eval()
         gm.eval()