Mirror of https://github.com/hpcaitech/ColossalAI.git
fix
commit 885210dc27
parent 5d167f2148
@@ -1,17 +1,9 @@
-import math
-import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -215,7 +207,6 @@ class FalconPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
 
-
             logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
             past_key_values = None
 
@@ -319,7 +310,7 @@ class FalconPipelineForwards:
 
             hidden_states = outputs[0]
             if use_cache is True:
-                next_decoder_cache = outputs[1]
+                outputs[1]
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
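For context, the loop in the last hunk relies on the conventional layout of the tuple each decoder block returns: index 0 holds the updated hidden states, index 1 holds the present key/value cache when use_cache is enabled, and the attention weights therefore sit at index 2 with a cache and at index 1 without one. A minimal sketch of that convention follows; the names run_block, block, and layer_past are hypothetical and not taken from the patched file.

    # Hypothetical sketch of the per-block output layout the changed loop indexes into.
    def run_block(block, hidden_states, layer_past, use_cache, output_attentions):
        outputs = block(
            hidden_states,
            layer_past=layer_past,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = outputs[0]  # always present: updated hidden states
        present = outputs[1] if use_cache else None  # key/value cache only when use_cache=True
        # attention weights shift by one slot depending on whether the cache occupies index 1
        attn = outputs[2 if use_cache else 1] if output_attentions else None
        return hidden_states, present, attn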