From eb24fcd914f4c38fb82bc082db84d13d50865572 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Wed, 3 Jul 2024 14:57:57 +0800
Subject: [PATCH] [Hotfix] Fix OPT gradient checkpointing forward

Co-authored-by: Edenzzzz
---
 colossalai/shardformer/modeling/opt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index f10860fef..b250b4976 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -221,7 +221,7 @@ class OPTPipelineForwards:
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if decoder.gradient_checkpointing and decoder.training:
-                layer_outputs = self._gradient_checkpointing_func(
+                layer_outputs = self.decoder._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     causal_attention_mask,