[fx] fixed tracing with apex-based T5 model (#1252)

* [fx] fixed tracing with apex-based T5 model

* polish code

* polish code
This commit is contained in:
Frank Lee
2022-07-12 15:19:25 +08:00
committed by GitHub
parent 7531c6271f
commit 4a09fc0947
2 changed files with 21 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ tracer.py:
import enum
import inspect
import functools
from colossalai.fx.tracer.meta_patch import meta_patched_module
import torch
import torch.nn as nn
from torch import Tensor
@@ -181,7 +182,16 @@ class ColoTracer(Tracer):
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)
module_qualified_name = self.path_of_module(m)
# a leaf module is the torch.nn.Module subclasses starting with `torch.nn`
# which means customized modules are not leaf module by default
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
else:
return forward(*args, **kwargs)
def proxy(self, node) -> Proxy:
"""