[shardformer] Align bert value (#3907)

* add bert align test, fix dist loss bug

* forward and backward align

* add ignore index

* add shardformer CI

* add gather_output optional for user in shardconfig

* update readme with optional gather_ouput

* add dist crossentropy loss test, remove unused files

* remove unused file

* remove unused file

* rename the file

* polish code
This commit is contained in:
FoolPlayer
2023-06-09 14:36:54 +08:00
committed by Frank Lee
parent 79f8d5d54b
commit f1cb5ac6bf
11 changed files with 174 additions and 197 deletions

View File

@@ -65,6 +65,8 @@ class ModelSharder(object):
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
inject_policy = self.policy.inject_policy()
if inject_policy is None:
return
if inject_policy is None:
return
@@ -148,7 +150,7 @@ class ModelSharder(object):
n_cast = policy_layer.n_cast
reversed = policy_layer.reversed
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
gather_output = policy_layer.gather_output and self.shard_config.gather_output
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):