mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-18 16:00:49 +00:00
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop * pass t5 trace and meta_prop * [FX] refactor experimental tracer and adapt it with hf models * pass all mainstream model zoo * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * skip tests * fix CI * using packaging version * polish
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
|
||||
|
||||
from .registry import model_zoo
|
||||
|
||||
__all__ = ['model_zoo']
|
||||
|
@@ -17,6 +17,14 @@ def data_gen():
|
||||
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
def seq_classification_data_gen():
|
||||
# batch sizes should be 1 if no padding token is defined.
|
||||
input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
|
||||
token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
|
||||
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
|
||||
@@ -44,6 +52,6 @@ model_zoo.register(name='transformers_gpt_for_token_classification',
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_for_sequence_classification',
|
||||
model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
|
||||
data_gen_fn=data_gen,
|
||||
data_gen_fn=seq_classification_data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
Reference in New Issue
Block a user