From ca56b93d8352cc493722626b9a44a8ad3d9f2b18 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 7 May 2024 07:07:07 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/shardformer/modeling/opt.py | 10 ++++++----
 colossalai/shardformer/policies/opt.py | 14 +++++---------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index 227042480..1cde61914 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -21,7 +21,9 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
+
 from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)
 
@@ -351,7 +353,7 @@ class OPTPipelineForwards:
                 loss_fct = CrossEntropyLoss()
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
-
+
             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output
@@ -987,8 +989,8 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
             )
-            #loss_fct = CrossEntropyLoss()
-            #loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+            # loss_fct = CrossEntropyLoss()
+            # loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1002,4 +1004,4 @@
             attentions=outputs.attentions,
         )
 
-    return forward
\ No newline at end of file
+    return forward
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index bb094d25a..524d2b8cd 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -21,9 +21,9 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
 from ..modeling.opt import (
     OPTPipelineForwards,
     get_jit_fused_opt_decoder_layer_forward,
+    get_lm_forward_with_dist_cross_entropy,
     get_opt_decoder_forward_for_flash_attention,
     get_opt_flash_attention_forward,
-    get_lm_forward_with_dist_cross_entropy
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -270,21 +270,17 @@ class OPTForCausalLMPolicy(OPTPolicy):
                     suffix="lm_head",
                     target_module=VocabParallelLMHead1D,
                     kwargs=dict(
-                        gather_output=not self.shard_config.parallel_output,
-                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
                     ),
                 ),
                 policy=policy,
                 target_key=OPTForCausalLM,
             )
             if self.shard_config.parallel_output:
-                method_replacement = {
-                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-                }
+                method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
                 self.append_or_create_method_replacement(
-                    description=method_replacement,
-                    policy=policy,
-                    target_key=OPTForCausalLM
+                    description=method_replacement, policy=policy, target_key=OPTForCausalLM
                 )
         else:
             self.append_or_create_submodule_replacement(
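
Note: the patch above only reformats code, but the functions it touches implement distributed (vocab-parallel) cross entropy. With shard_config.parallel_output enabled, the policy keeps lm_head's output sharded across tensor-parallel ranks (gather_output=not parallel_output) and swaps forward for get_lm_forward_with_dist_cross_entropy, which calls cross_entropy_1d on the sharded logits. The following is a minimal sketch of that general technique only; it is not ColossalAI's cross_entropy_1d, and the function name, tensor shapes, and two-rank torchrun launch are assumptions made for the demo.

# Sketch: cross entropy when each tensor-parallel rank holds only a vocab shard
# of the logits. Illustrative only; not the ColossalAI implementation.
import torch
import torch.distributed as dist
import torch.nn.functional as F


def vocab_parallel_cross_entropy(shard_logits, targets, vocab_start, vocab_end):
    # shard_logits: (tokens, shard_vocab) slice of the full logits held by this rank
    # targets: (tokens,) global vocab ids, identical on every rank

    # 1) Global max over the full vocab for numerical stability.
    global_max = shard_logits.max(dim=-1).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX)
    shifted = shard_logits - global_max.unsqueeze(-1)

    # 2) Global denominator: sum of exp contributions from every shard.
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM)

    # 3) Numerator: only the rank that owns the target id contributes its logit.
    in_shard = (targets >= vocab_start) & (targets < vocab_end)
    local_idx = (targets - vocab_start).clamp(0, shard_logits.size(-1) - 1)
    picked = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    picked = torch.where(in_shard, picked, torch.zeros_like(picked))
    dist.all_reduce(picked, op=dist.ReduceOp.SUM)

    # loss = logsumexp(logits) - logit[target], averaged over tokens
    return (sum_exp.log() - picked).mean()


if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 this_script.py (hypothetical name)
    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.manual_seed(0)                      # same tensors on every rank
    full_logits = torch.randn(4, 8)           # 4 tokens, vocab size 8
    targets = torch.randint(0, 8, (4,))
    shard = 8 // world
    start, end = rank * shard, (rank + 1) * shard
    loss = vocab_parallel_cross_entropy(full_logits[:, start:end], targets, start, end)
    if rank == 0:
        # Should match the single-device reference loss.
        print("parallel:", loss.item(), "reference:", F.cross_entropy(full_logits, targets).item())

Because every rank ends the all-reduces with the same scalar loss, the backward pass stays consistent across the tensor-parallel group without ever gathering the full vocab-sized logits.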