add parallel_output for the opt model

Author: wangbluo
Date:   2024-05-03 08:58:00 +00:00
Parent: 88f057ce7c
Commit: 108ddfb795
2 changed files with 174 additions and 4 deletions


@@ -23,6 +23,7 @@ from ..modeling.opt import (
     get_jit_fused_opt_decoder_layer_forward,
     get_opt_decoder_forward_for_flash_attention,
     get_opt_flash_attention_forward,
+    get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
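
Judging by its name and the gather_output toggle in the hunk below, the newly imported get_lm_forward_with_dist_cross_entropy supplies a replacement forward that computes the loss directly on vocab-sharded logits instead of gathering the full-vocab logits first. For orientation only, a minimal sketch of that technique (the function below is illustrative, not the ColossalAI source):

# Illustrative sketch only -- NOT the ColossalAI implementation. Each rank
# holds a slice of the logits along the vocab dim; the loss is assembled
# with all-reduces instead of an all-gather of the full logits. A real
# training version would wrap this in a custom autograd.Function so
# gradients flow through the collectives.
import torch
import torch.distributed as dist

def dist_cross_entropy(logits, labels, vocab_start, vocab_end, group=None):
    # logits: (N, vocab_shard) local slice; labels: (N,) global token ids.
    local_max = logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)  # global max
    logits = logits - local_max                                    # stability

    sum_exp = logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)    # global Z

    # Only the rank that owns a label's vocab slice contributes its logit;
    # the others contribute zero, so the SUM reduce recovers the target logit.
    owned = (labels >= vocab_start) & (labels < vocab_end)
    local_ids = (labels - vocab_start).clamp(0, logits.size(-1) - 1)
    target = logits.gather(-1, local_ids.unsqueeze(-1)).squeeze(-1)
    target = torch.where(owned, target, torch.zeros_like(target))
    dist.all_reduce(target, op=dist.ReduceOp.SUM, group=group)

    return (sum_exp.log() - target).mean()   # mean of -log softmax[target]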
@@ -269,12 +270,22 @@ class OPTForCausalLMPolicy(OPTPolicy):
                     suffix="lm_head",
                     target_module=VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
                     ),
                 ),
                 policy=policy,
                 target_key=OPTForCausalLM,
             )
+            if self.shard_config.parallel_output:
+                method_replacement = {
+                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+                }
+                self.append_or_create_method_replacement(
+                    description=method_replacement,
+                    policy=policy,
+                    target_key=OPTForCausalLM
+                )
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
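
For context, a hypothetical usage sketch (tp_group and model are assumptions, not part of this commit; the entry points follow ShardFormer's usual API). With parallel_output=True the policy above leaves the lm_head output sharded (gather_output=False) and attaches the distributed cross-entropy forward, avoiding an all-gather of (batch, seq, vocab)-sized logits on every step:

# Hypothetical usage sketch -- tp_group and model are assumed to already
# exist; this only shows where the new parallel_output flag plugs in.
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    parallel_output=True,
)
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)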