Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-24 22:42:15 +00:00

[pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

This commit is contained in:
parent 5480b811c5
commit 4eced5cf8a
@@ -1,17 +1,9 @@
-import math
-import warnings
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
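The hunk above drops imports that are no longer referenced anywhere in the file (math, warnings, the attention-mask helpers, and the cache classes). Removing dead imports is the kind of change an autofixing lint hook makes; this page does not say which hook produced it, so the sketch below is illustrative only, a toy detector for unused top-level imports rather than the repository's tooling:

# Toy sketch: flag top-level imports whose names are never used again.
# Real fixers (ruff --fix, autoflake) do far more; this only illustrates the idea.
import ast


def unused_top_level_imports(source: str) -> list[str]:
    tree = ast.parse(source)
    imported: list[str] = []
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                imported.append(alias.asname or alias.name.split(".")[0])
    # Names referenced anywhere in the module (attribute roots show up as ast.Name).
    used = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}
    return [name for name in imported if name not in used]


sample = "import math\nimport warnings\nwarnings.warn('x')\n"
print(unused_top_level_imports(sample))  # -> ['math']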
@@ -159,7 +151,7 @@ def get_tp_falcon_decoder_layer_forward():
             and self.config.num_ln_in_parallel_attn == 1
         ):
             mlp_layernorm_out = attention_layernorm_out
-
+
         outputs = attn_outputs[1:]

         # MLP.
@@ -215,7 +207,6 @@ class FalconPipelineForwards:
         logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
         use_cache = False

-
         logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
         past_key_values = None

@@ -251,7 +242,7 @@ class FalconPipelineForwards:
         # Compute alibi tensor: check build_alibi_tensor documentation
         alibi = None
         past_key_values_length = 0
-
+
         batch_size, seq_length, _ = hidden_states.shape
         if self.use_alibi:
             mask = (
@@ -262,7 +253,7 @@ class FalconPipelineForwards:
                 else attention_mask
             )
             alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
-
+
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
@@ -280,7 +271,7 @@ class FalconPipelineForwards:
         # attention_probs has shape batch_size x num_heads x N x N
         # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
+
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

@@ -319,7 +310,7 @@ class FalconPipelineForwards:

             hidden_states = outputs[0]
             if use_cache is True:
-                next_decoder_cache = outputs[1]
+                outputs[1]

             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
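The hunk above contains the only non-whitespace code change in this commit: the binding next_decoder_cache is dropped because the value is never read again in this pipeline forward, leaving a bare outputs[1] expression. That matches an "assigned but never used" autofix (for example ruff's F841), though the exact hook is an assumption here. A minimal before/after sketch of the pattern, with hypothetical names:

# Hypothetical illustration only; not the repository's code.
def step_before(outputs, use_cache):
    hidden_states = outputs[0]
    if use_cache:
        next_decoder_cache = outputs[1]  # assigned but never read
    return hidden_states


def step_after(outputs, use_cache):
    hidden_states = outputs[0]
    if use_cache:
        outputs[1]  # fix keeps the expression, drops the dead binding
    return hidden_states


print(step_before((1, 2), True) == step_after((1, 2), True))  # -> True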
@@ -332,7 +323,7 @@ class FalconPipelineForwards:
             all_hidden_states = all_hidden_states + (hidden_states,)

         if stage_manager.is_last_stage():
-
+
             if not return_dict:
                 return tuple(
                     v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None