[shardformer] Support the T5ForTokenClassification model (#5816)

* t5 token, still pytest fail

* Resolve T5 Pytest Failure

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Guangyao Zhang
2024-06-27 16:40:38 +08:00
committed by GitHub
parent 5dfbcd7746
commit d9d5e7ea1f
5 changed files with 166 additions and 11 deletions

View File

@@ -68,6 +68,9 @@ _POLICY_LIST = {
file_name="t5", class_name="T5ForConditionalGenerationPolicy"
),
"transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
"transformers.models.t5.modeling_t5.T5ForTokenClassification": PolicyLocation(
file_name="t5", class_name="T5ForTokenClassificationPolicy"
),
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(