[fx] fixed tracing with apex-based T5 model (#1252)

* [fx] fixed tracing with apex-based T5 model

* polish code

* polish code
This commit is contained in:
Frank Lee
2022-07-12 15:19:25 +08:00
committed by GitHub
parent 7531c6271f
commit 4a09fc0947
2 changed files with 21 additions and 1 deletions

View File

@@ -1,8 +1,18 @@
import pytest
import transformers
import torch
from colossalai.fx.tracer.meta_patch import meta_patched_module
from utils import trace_model_and_compare_output
try:
import apex
@meta_patched_module.register(apex.normalization.FusedRMSNorm)
def apex_fused_layernorm(self, input):
return torch.empty(input.shape, device='meta')
except ImportError:
pass
BATCH_SIZE = 1
SEQ_LENGHT = 16