mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
[fx] fixed tracing with apex-based T5 model (#1252)
* [fx] fixed tracing with apex-based T5 model * polish code * polish code
This commit is contained in:
@@ -1,8 +1,18 @@
|
||||
import pytest
|
||||
import transformers
|
||||
import torch
|
||||
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||
from utils import trace_model_and_compare_output
|
||||
|
||||
try:
|
||||
import apex
|
||||
|
||||
@meta_patched_module.register(apex.normalization.FusedRMSNorm)
|
||||
def apex_fused_layernorm(self, input):
|
||||
return torch.empty(input.shape, device='meta')
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGHT = 16
|
||||
|
||||
|
Reference in New Issue
Block a user