wangbluo 2025-04-28 18:17:12 +08:00
parent 5d167f2148
commit 885210dc27

@@ -1,17 +1,9 @@
-import math
-import warnings
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
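The imports dropped above are transformers' legacy 4D attention-mask helpers. As a rough orientation only (not code from this commit, and not the upstream implementation), `_prepare_4d_causal_attention_mask` conceptually produced an additive causal mask like the sketch below; the shapes and the min-value fill are standard, everything else is an assumption.

```python
import torch

def build_causal_mask(batch_size: int, seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Future positions (column > row) get the dtype's most negative value so
    # softmax drives their attention weight to ~0; allowed positions stay 0.
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    # Broadcast to the [batch, 1, query_len, kv_len] layout attention layers expect.
    return mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)

print(build_causal_mask(2, 4).shape)  # torch.Size([2, 1, 4, 4])
```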
@@ -215,7 +207,6 @@ class FalconPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
             logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
             past_key_values = None
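For context, these lines belong to the guard pattern that pipeline forwards use to force-disable features they do not support. A minimal self-contained sketch of that pattern; the function name `sanitize_pipeline_kwargs` is hypothetical, while the warning messages are taken verbatim from the hunk:

```python
from transformers.utils import logging

logger = logging.get_logger(__name__)

def sanitize_pipeline_kwargs(use_cache, past_key_values):
    # Warn once per process, then override the unsupported arguments.
    if use_cache:
        logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
        use_cache = False
    if past_key_values is not None:
        logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
        past_key_values = None
    return use_cache, past_key_values
```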
@@ -319,7 +310,7 @@ class FalconPipelineForwards:
             hidden_states = outputs[0]
             if use_cache is True:
-                next_decoder_cache = outputs[1]
+                outputs[1]
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
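The `2 if use_cache else 1` index works because each decoder layer returns a positional tuple whose layout depends on the flags: the cache entry, when present, sits between the hidden states and the attention weights. A toy illustration of that indexing (hypothetical tuple contents, not Falcon's actual layer code):

```python
def layer_outputs(use_cache: bool, output_attentions: bool):
    # Stand-in for a decoder layer's positional return tuple.
    outs = ("hidden_states",)
    if use_cache:
        outs += ("present_key_value",)
    if output_attentions:
        outs += ("attention_weights",)
    return outs

for use_cache in (True, False):
    outs = layer_outputs(use_cache, output_attentions=True)
    # Mirrors the indexing in the hunk above.
    assert outs[2 if use_cache else 1] == "attention_weights"
```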