This commit is contained in:
wangbluo 2025-04-28 18:17:12 +08:00
parent 5d167f2148
commit 885210dc27

View File

@@ -1,17 +1,9 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -215,7 +207,6 @@ class FalconPipelineForwards:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
past_key_values = None
@@ -319,7 +310,7 @@ class FalconPipelineForwards:
hidden_states = outputs[0]
if use_cache is True:
next_decoder_cache = outputs[1]
outputs[1]
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)