[format] applied code formatting on changed files in pull request 5510 (#5517)

Co-authored-by: github-actions <github-actions@github.com>
This commit is contained in:
github-actions[bot]
2024-03-27 11:21:03 +08:00
committed by GitHub
parent 19e1a5cf16
commit e6707a6e8d
4 changed files with 14 additions and 7 deletions

View File

@@ -265,12 +265,18 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output})
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
)
],
)
}
if self.shard_config.parallel_output:
new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
policy.update(new_item)
if self.pipeline_stage_manager: