[shardformer] Align bert value (#3907)

* add bert align test, fix dist loss bug

* forward and backward align

* add ignore index

* add shardformer CI

* add an optional gather_output setting for users in shardconfig

* update readme with optional gather_output

* add dist crossentropy loss test, remove unused files

* remove unused file

* remove unused file

* rename the file

* polish code
This commit is contained in:
FoolPlayer
2023-06-09 14:36:54 +08:00
committed by Frank Lee
parent 79f8d5d54b
commit f1cb5ac6bf
11 changed files with 174 additions and 197 deletions

View File

@@ -141,7 +141,7 @@ class BertPolicy(Policy):
weight="decoder.weight",
bias="decoder.bias",
replace_layer=col_nn.Linear1D_Col,
# gather_output=True,
gather_output=True,
)
]
@@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
return (BertForMaskedLM, BertForMaskedLM_)
# return (BertForMaskedLM, BertForMaskedLM_)
return None
class BertForSequenceClassificationPolicy(BertPolicy):