[FX] refactor experimental tracer and adapt it with hf models (#3157)

* pass gpt trace and meta_prop

* pass t5 trace and meta_prop

* [FX] refactor experimental tracer and adapt it with hf models

* pass all mainstream model zoo

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* skip tests

* fix CI

* using packaging version

* polish
This commit is contained in:
YuliangLiu0306
2023-03-22 10:40:33 +08:00
committed by GitHub
parent b429529365
commit f57d34958b
28 changed files with 1014 additions and 863 deletions

View File

@@ -1,5 +1,4 @@
from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
from .registry import model_zoo
__all__ = ['model_zoo']

View File

@@ -17,6 +17,14 @@ def data_gen():
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def seq_classification_data_gen():
# batch sizes should be 1 if no padding token is defined.
input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
output_transform_fn = lambda x: x
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
@@ -44,6 +52,6 @@ model_zoo.register(name='transformers_gpt_for_token_classification',
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
data_gen_fn=data_gen,
data_gen_fn=seq_classification_data_gen,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))