add parallel_output for the opt model
@@ -23,6 +23,7 @@ from ..modeling.opt import (
     get_jit_fused_opt_decoder_layer_forward,
     get_opt_decoder_forward_for_flash_attention,
     get_opt_flash_attention_forward,
+    get_lm_forward_with_dist_cross_entropy
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -269,12 +270,22 @@ class OPTForCausalLMPolicy(OPTPolicy):
                 suffix="lm_head",
                 target_module=VocabParallelLMHead1D,
                 kwargs=dict(
-                    gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                    gather_output=not self.shard_config.parallel_output,
+                    make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
                 ),
             ),
             policy=policy,
             target_key=OPTForCausalLM,
         )
+        if self.shard_config.parallel_output:
+            method_replacement = {
+                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+            }
+            self.append_or_create_method_replacement(
+                description=method_replacement,
+                policy=policy,
+                target_key=OPTForCausalLM
+            )
         else:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
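Conceptually, get_lm_forward_with_dist_cross_entropy relies on the standard vocab-parallel cross entropy trick (as popularized by Megatron-LM): each tensor-parallel rank holds logits for only a slice of the vocabulary, and the loss is assembled from three all-reduces over per-row scalars instead of an all-gather over the full logits. A self-contained re-derivation of that trick follows; it is an illustration, not ColossalAI's actual implementation.

import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy(sharded_logits, labels, vocab_start, vocab_end, group=None):
    """Cross entropy over logits sharded along the vocab dimension.

    sharded_logits: (N, vocab_end - vocab_start), this rank's vocab slice.
    labels: (N,) global vocabulary ids. Illustrative sketch only.
    """
    # 1) Global max per row for numerical stability.
    row_max = sharded_logits.max(dim=-1).values
    dist.all_reduce(row_max, op=dist.ReduceOp.MAX, group=group)
    shifted = sharded_logits - row_max.unsqueeze(-1)

    # 2) Global denominator: sum of exponentials over the full vocabulary.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, group=group)

    # 3) Numerator: the target logit, contributed only by the rank whose
    #    vocab slice contains the label (all other ranks add zero).
    in_shard = (labels >= vocab_start) & (labels < vocab_end)
    local_idx = (labels - vocab_start).clamp(0, vocab_end - vocab_start - 1)
    target = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target = torch.where(in_shard, target, torch.zeros_like(target))
    dist.all_reduce(target, group=group)

    # loss_i = logsumexp_i - z_{i, y_i}; the row max cancels after the shift.
    return (sum_exp.log() - target).mean()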