From 4a09fc0947dd61bf98a98ce5363828dac4875629 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 12 Jul 2022 15:19:25 +0800
Subject: [PATCH] [fx] fixed tracing with apex-based T5 model (#1252)

* [fx] fixed tracing with apex-based T5 model

* polish code

* polish code
---
 colossalai/fx/tracer/tracer.py                       | 12 +++++++++++-
 .../test_fx/test_tracer/test_hf_model/test_hf_t5.py  | 10 ++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 0398dc89f..33e14d57c 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -7,6 +7,7 @@ tracer.py:
 import enum
 import inspect
 import functools
+from colossalai.fx.tracer.meta_patch import meta_patched_module
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -181,7 +182,16 @@ class ColoTracer(Tracer):
 
     def call_module(self, m, forward, args, kwargs):
         self.orig_forward = forward
-        return super().call_module(m, forward, args, kwargs)
+        module_qualified_name = self.path_of_module(m)
+
+        # a leaf module is a torch.nn.Module subclass defined under `torch.nn`,
+        # which means customized modules are not leaf modules by default
+        # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
+        # we should treat it as a leaf module as well
+        if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
+            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+        else:
+            return forward(*args, **kwargs)
 
     def proxy(self, node) -> Proxy:
         """
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
index 0bf765174..3605f986d 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
@@ -1,8 +1,18 @@
 import pytest
 import transformers
 import torch
+from colossalai.fx.tracer.meta_patch import meta_patched_module
 from utils import trace_model_and_compare_output
 
+try:
+    import apex
+
+    @meta_patched_module.register(apex.normalization.FusedRMSNorm)
+    def apex_fused_layernorm(self, input):
+        return torch.empty(input.shape, device='meta')
+except ImportError:
+    pass
+
 BATCH_SIZE = 1
 SEQ_LENGHT = 16
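
The two hunks above are halves of one mechanism: ColoTracer.call_module now asks
meta_patched_module.has(m.__class__) whether a module has a registered meta-forward and,
if so, records it as a leaf via create_proxy instead of recursing into its real forward;
the test registers such a meta-forward for apex.normalization.FusedRMSNorm. The sketch
below, which is not part of this patch, shows how the same pattern could be reused for
another fused module; apex.normalization.FusedLayerNorm is used only as a plausible
example of a shape-preserving module that lives outside torch.nn.

    # Sketch (assumption, not from this patch): register a meta-forward so that
    # ColoTracer traces the module as a single leaf node.
    # Assumes apex is installed and that FusedLayerNorm preserves the input shape.
    import torch
    from colossalai.fx.tracer.meta_patch import meta_patched_module

    try:
        import apex

        @meta_patched_module.register(apex.normalization.FusedLayerNorm)
        def patched_fused_layernorm(self, input):
            # return an empty tensor on the meta device with the same shape as the input
            return torch.empty(input.shape, device='meta')
    except ImportError:
        # apex is optional; without it the module is simply left unpatched
        pass

With such a registration in place, tracing a model that contains the patched module would
produce one call_module node for it, mirroring what the test above does for the
apex-based T5 model's FusedRMSNorm layers.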