[Hotfix] Fix OPT gradient checkpointing forward

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
This commit is contained in:
Edenzzzz
2024-07-03 14:57:57 +08:00
committed by GitHub
parent ea94c07b95
commit eb24fcd914

View File

@@ -221,7 +221,7 @@ class OPTPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
layer_outputs = self._gradient_checkpointing_func(
layer_outputs = self.decoder._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,