[shardformer] add util functions for shardformer tests/fix sync_shared_param (#4366)

* add util functions for shardformer tests & rewrite gpt2 test

* fix shared_params & embedding/merging

* fix precision

This commit is contained in:
Baizhou Zhang
2023-08-03 17:50:15 +08:00
committed by Hongxin Liu
parent 5c6f183192
commit b1feeced8e
4 changed files with 189 additions and 113 deletions

View File

@@ -72,7 +72,9 @@ config = transformers.GPT2Config(n_layer=2,
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0)
hidden_dropout=0,
problem_type="single_label_classification",
pad_token_id=50256)
config_for_token_classification = copy.deepcopy(config)
config_for_token_classification.num_labels = 2