mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 13:41:43 +00:00
[format] applied code formatting on changed files in pull request 4908 (#4918)
Co-authored-by: github-actions <github-actions@github.com>
This commit is contained in:
parent
4f68b3f10c
commit
a41cf88e9b
@ -6,25 +6,20 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
||||||
|
from flash_attn.ops.rms_norm import rms_norm
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaRMSNorm,
|
|
||||||
LlamaAttention,
|
LlamaAttention,
|
||||||
LlamaModel,
|
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
|
LlamaModel,
|
||||||
|
LlamaRMSNorm,
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
|
||||||
from flash_attn.flash_attn_interface import (
|
|
||||||
flash_attn_func,
|
|
||||||
flash_attn_varlen_kvpacked_func,
|
|
||||||
)
|
|
||||||
from flash_attn.ops.rms_norm import rms_norm
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|
||||||
@ -65,7 +60,7 @@ def attention_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""
|
"""
|
||||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||||
|
Loading…
Reference in New Issue
Block a user