mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 21:49:08 +00:00
[tests] remove T5 test skip decorator (#1271)
This commit is contained in:
parent
de498255b5
commit
01ea68b2e6
@ -2,12 +2,20 @@ import pytest
|
|||||||
import transformers
|
import transformers
|
||||||
import torch
|
import torch
|
||||||
from hf_utils import split_model_and_compare_output
|
from hf_utils import split_model_and_compare_output
|
||||||
|
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||||
|
try:
|
||||||
|
import apex
|
||||||
|
|
||||||
|
@meta_patched_module.register(apex.normalization.FusedRMSNorm)
|
||||||
|
def apex_fused_layernorm(self, input):
|
||||||
|
return torch.empty(input.shape, device='meta')
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
BATCH_SIZE = 1
|
BATCH_SIZE = 1
|
||||||
SEQ_LENGHT = 16
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip('tracing failed')
|
|
||||||
def test_t5():
|
def test_t5():
|
||||||
MODEL_LIST = [
|
MODEL_LIST = [
|
||||||
transformers.T5Model,
|
transformers.T5Model,
|
||||||
@ -15,7 +23,7 @@ def test_t5():
|
|||||||
transformers.T5EncoderModel,
|
transformers.T5EncoderModel,
|
||||||
]
|
]
|
||||||
|
|
||||||
config = transformers.T5Config(d_model=128, num_layers=2)
|
config = transformers.T5Config(vocab_size=100, d_model=128, num_layers=2)
|
||||||
|
|
||||||
def data_gen():
|
def data_gen():
|
||||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||||
|
Loading…
Reference in New Issue
Block a user