mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
[fx] fixed tracing with apex-based T5 model (#1252)
* [fx] fixed tracing with apex-based T5 model * polish code * polish code
This commit is contained in:
parent
7531c6271f
commit
4a09fc0947
@ -7,6 +7,7 @@ tracer.py:
|
|||||||
import enum
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
import functools
|
import functools
|
||||||
|
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@ -181,7 +182,16 @@ class ColoTracer(Tracer):
|
|||||||
|
|
||||||
def call_module(self, m, forward, args, kwargs):
|
def call_module(self, m, forward, args, kwargs):
|
||||||
self.orig_forward = forward
|
self.orig_forward = forward
|
||||||
return super().call_module(m, forward, args, kwargs)
|
module_qualified_name = self.path_of_module(m)
|
||||||
|
|
||||||
|
# a leaf module is the torch.nn.Module subclasses starting with `torch.nn`
|
||||||
|
# which means customized modules are not leaf module by default
|
||||||
|
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
|
||||||
|
# we should treat it as leaf module as well
|
||||||
|
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
|
||||||
|
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
|
||||||
|
else:
|
||||||
|
return forward(*args, **kwargs)
|
||||||
|
|
||||||
def proxy(self, node) -> Proxy:
|
def proxy(self, node) -> Proxy:
|
||||||
"""
|
"""
|
||||||
|
@ -1,8 +1,18 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import transformers
|
import transformers
|
||||||
import torch
|
import torch
|
||||||
|
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||||
from utils import trace_model_and_compare_output
|
from utils import trace_model_and_compare_output
|
||||||
|
|
||||||
|
try:
|
||||||
|
import apex
|
||||||
|
|
||||||
|
@meta_patched_module.register(apex.normalization.FusedRMSNorm)
|
||||||
|
def apex_fused_layernorm(self, input):
|
||||||
|
return torch.empty(input.shape, device='meta')
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
BATCH_SIZE = 1
|
BATCH_SIZE = 1
|
||||||
SEQ_LENGHT = 16
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user