mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
[shardformer] Align bert value (#3907)
* add bert align test, fix dist loss bug * forward and backward align * add ignore index * add shardformer CI * add gather_output optional for user in shardconfig * update readme with optional gather_ouput * add dist crossentropy loss test, remove unused files * remove unused file * remove unused file * rename the file * polish code
This commit is contained in:
@@ -65,6 +65,8 @@ class ModelSharder(object):
|
||||
BertForMaskedLM.forward -> BertForMaskedLM_.forward
|
||||
"""
|
||||
inject_policy = self.policy.inject_policy()
|
||||
if inject_policy is None:
|
||||
return
|
||||
|
||||
if inject_policy is None:
|
||||
return
|
||||
@@ -148,7 +150,7 @@ class ModelSharder(object):
|
||||
n_cast = policy_layer.n_cast
|
||||
reversed = policy_layer.reversed
|
||||
if policy_layer.__class__.__name__ == "Col_Layer":
|
||||
gather_output = policy_layer.gather_output
|
||||
gather_output = policy_layer.gather_output and self.shard_config.gather_output
|
||||
|
||||
if weight_attr is not None:
|
||||
if hasattr_(org_layer, weight_attr):
|
||||
|
Reference in New Issue
Block a user