[shardformer] Support the T5ForTokenClassification model (#5816)

* t5 token, still pytest fail

* Resolve T5 Pytest Failure

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Guangyao Zhang
2024-06-27 16:40:38 +08:00
committed by GitHub
parent 5dfbcd7746
commit d9d5e7ea1f
5 changed files with 166 additions and 11 deletions

View File

@@ -40,6 +40,14 @@ def data_gen_for_t5_model():
return data
def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen_for_encoder_only()
data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
# output transform function
output_transform_fn = lambda x: x
@@ -47,6 +55,7 @@ output_transform_fn = lambda x: x
loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
loss_fn_for_conditional_generation = lambda x: x["loss"]
loss_fn_for_token_classification = lambda x: x["loss"]
# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
@@ -79,3 +88,11 @@ model_zoo.register(
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_t5_for_token_classification",
model_fn=lambda: transformers.T5ForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_token_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)