mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 04:33:04 +00:00
[shardformer] Align bert value (#3907)
* add bert align test, fix dist loss bug * forward and backward align * add ignore index * add shardformer CI * add gather_output optional for user in shardconfig * update readme with optional gather_ouput * add dist crossentropy loss test, remove unused files * remove unused file * remove unused file * rename the file * polish code
This commit is contained in:
@@ -141,7 +141,7 @@ class BertPolicy(Policy):
|
||||
weight="decoder.weight",
|
||||
bias="decoder.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
# gather_output=True,
|
||||
gather_output=True,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def inject_policy() -> Tuple[nn.Module, nn.Module]:
|
||||
return (BertForMaskedLM, BertForMaskedLM_)
|
||||
# return (BertForMaskedLM, BertForMaskedLM_)
|
||||
return None
|
||||
|
||||
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
|
Reference in New Issue
Block a user