[npu] use extension for op builder (#5172)

* update extension

* update cpu adam

* update is

* add doc for cpu adam

* update kernel

* update commit

* update flash

* update memory efficient

* update flash attn

* update flash attention loader

* update api

* fix

* update doc

* update example time limit

* reverse change

* fix doc

* remove useless kernel

* fix

* not use warning

* update

* update
Xuanlei Zhao 2024-01-08 11:39:16 +08:00 committed by GitHub
parent d6df19bae7
commit dd2c28a323
35 changed files with 1067 additions and 274 deletions

View File

@ -1,7 +1,14 @@
+from .cpu_adam_loader import CPUAdamLoader
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
+from .extensions.flash_attention import AttnMaskType
+from .flash_attention_loader import ColoAttention, FlashAttentionLoader
__all__ = [
"LayerNorm",
"FusedScaleMaskSoftmax",
"MultiHeadAttention",
+"CPUAdamLoader",
+"FlashAttentionLoader",
+"ColoAttention",
+"AttnMaskType",
]

View File

@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from typing import Dict, List
from .extensions.base_extension import BaseExtension
class BaseKernelLoader(ABC):
"""
Usage:
kernel_loader = KernelLoader()
kernel = kernel_loader.load()
"""
def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]):
self._extension_map = extension_map
self._supported_device = supported_device
def run_checks(self):
# run supported device check and other possible checks
pass
@abstractmethod
def fetch_kernel(self):
pass
def load(self):
self.run_checks()
return self.fetch_kernel()
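The split of responsibilities above is worth spelling out: load() runs run_checks(), currently a no-op, and then delegates to fetch_kernel(), so a concrete loader only has to register its extensions and decide which one to instantiate. A minimal hypothetical sketch, assuming the import path follows this PR's layout (IdentityKernelLoader is an illustrative name, not part of the commit):

    from colossalai.kernel.base_kernel_loader import BaseKernelLoader  # path assumed from this PR's layout

    class IdentityKernelLoader(BaseKernelLoader):
        # toy loader whose "kernel" is just an identity function
        def __init__(self):
            super().__init__(extension_map={}, supported_device=["cpu"])

        def fetch_kernel(self):
            # real loaders pick an extension from self._extension_map and call .fetch()
            return lambda x: x

    kernel = IdentityKernelLoader().load()  # run_checks() first, then fetch_kernel()
    assert kernel(42) == 42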

View File

@ -0,0 +1,64 @@
import platform
from collections import OrderedDict
from .base_kernel_loader import BaseKernelLoader
from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension
class CPUAdamLoader(BaseKernelLoader):
"""
CPU Adam Loader
Usage:
# init
cpu_adam = CPUAdamLoader().load()
cpu_adam_op = cpu_adam.CPUAdamOptimizer(
alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
)
...
# optim step
cpu_adam_op.step(
step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
params, grads, exp_avg, exp_avg_sq, loss_scale,
)
Args:
func CPUAdamOptimizer:
alpha (float): learning rate. Default to 1e-3.
beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
weight_decay (float): weight decay (L2 penalty). Default to 0.
adamw_mode (bool): whether to use the adamw. Default to True.
func step:
step (int): current step.
lr (float): learning rate.
beta1 (float): coefficients used for computing running averages of gradient.
beta2 (float): coefficients used for computing running averages of its square.
epsilon (float): term added to the denominator to improve numerical stability.
weight_decay (float): weight decay (L2 penalty).
bias_correction (bool): whether to use bias correction.
params (torch.Tensor): parameter.
grads (torch.Tensor): gradient.
exp_avg (torch.Tensor): exp average.
exp_avg_sq (torch.Tensor): exp average square.
loss_scale (float): loss scale value.
"""
def __init__(self):
super().__init__(
extension_map=OrderedDict(
arm=ArmCPUAdamExtension,
x86=X86CPUAdamExtension,
),
supported_device=["cpu"],
)
def fetch_kernel(self):
if platform.machine() == "x86_64":
kernel = self._extension_map["x86"]().fetch()
elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
kernel = self._extension_map["arm"]().fetch()
else:
raise Exception("not supported")
return kernel

View File

@ -1,5 +1,4 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
-from .mha.mha import ColoAttention
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
@ -8,6 +7,5 @@ __all__ = [
"MultiHeadAttention",
"FusedScaleMaskSoftmax",
"ScaledUpperTriangMaskedSoftmax",
-"ColoAttention",
"AttnMaskType",
]

View File

@ -1,3 +0,0 @@
from .mha import ColoAttention
__all__ = ["ColoAttention"]

View File

@ -1,114 +0,0 @@
import math
from typing import Optional
import torch
from einops import rearrange
from ..scaled_softmax import AttnMaskType
from .flash_attn_2 import HAS_FLASH_ATTN
from .mem_eff_attn import HAS_MEM_EFF_ATTN
from .utils import Repad, SeqLenInfo, Unpad
if HAS_FLASH_ATTN:
from .flash_attn_2 import flash_attention
if HAS_MEM_EFF_ATTN:
from .mem_eff_attn import mem_eff_attention
class ColoAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
super().__init__()
assert (
embed_dim % num_heads == 0
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
raise Exception("flash attention can not support!")
@staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return Unpad.apply(tensor, indices)
@staticmethod
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
return Repad.apply(tensor, indices, batch_size, seq_len)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
origin_attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: Optional[AttnMaskType] = None,
bias: Optional[torch.Tensor] = None,
):
attn = None
if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None:
attn = flash_attention
else:
attn = mem_eff_attention
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
causal = attn_mask_type is not None and attn_mask_type.value > 1
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
# unpad
seq_len_info_q = None
seq_len_info_kv = None
if padded:
# bert style, unpad process
assert (
attn_mask is not None
), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, (
"attention mask is supposed to have shape (batch_size, seq_len), "
+ f"but got {attn_mask.dim()} dimensions."
)
# bert style
if tgt_len == src_len:
seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query, key, value = self.unpad(
torch.stack([query, key, value], dim=2), seq_len_info_q.indices
).unbind(dim=1)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
seq_len_info_kv = seq_len_info_q
else:
seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
dim=1
)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
out = attn(
query,
key,
value,
seq_len_info_q,
seq_len_info_kv,
dropout_p=self.dropout,
scale=self.scale,
causal=causal,
padded=padded,
)
# repad
if padded:
if batch_size > 1:
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
out = rearrange(out, "b s h d -> b s (h d)")
return out

View File

View File

@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Callable
class BaseExtension(ABC):
@abstractmethod
def requires_build(self) -> bool:
pass
@abstractmethod
def build(self) -> None:
pass
@abstractmethod
def load(self) -> Callable:
pass
def fetch(self) -> Callable:
if self.requires_build:
self.build()
return self.load()
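fetch() is the only call sites use: it builds once when requires_build is truthy and then returns whatever load() produces. The concrete extensions in this commit expose requires_build as a property, and fetch() truth-tests the attribute rather than calling it. A minimal sketch of the contract, assuming the import path from this PR (NoOpExtension is an illustrative name, not part of the commit):

    from colossalai.kernel.extensions.base_extension import BaseExtension  # path assumed from this PR's layout

    class NoOpExtension(BaseExtension):
        # nothing to compile, so report that no build is required
        @property
        def requires_build(self) -> bool:
            return False

        def build(self) -> None:
            pass

        def load(self):
            # return the callable that acts as the kernel
            return lambda x: x

    kernel = NoOpExtension().fetch()  # skips build(), returns the callable from load()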

View File

@ -0,0 +1,4 @@
from .arm_extension import ArmCPUAdamExtension
from .x86_extension import X86CPUAdamExtension
__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"]

View File

@ -0,0 +1,53 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
class ArmCPUAdamExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
self.kernel_builder = ArmCPUAdamBuilder()
self._requires_build = False
@property
def requires_build(self) -> bool:
return self._requires_build
def build(self):
self.kernel_builder.build()
self._requires_build = True
def load(self):
return self.kernel_builder.load()
class ArmCPUAdamBuilder(ExtensionBuilder):
NAME = "arm_cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
ext_type = "cpu"
def __init__(self):
super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam_arm.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes")]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-g",
"-Wno-reorder",
"-fopenmp",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
return []

View File

@ -0,0 +1,65 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
from ..utils import append_nvcc_threads
class X86CPUAdamExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
self.kernel_builder = X86CPUAdamBuilder()
self._requires_build = False
@property
def requires_build(self) -> bool:
return self._requires_build
def build(self):
self.kernel_builder.build()
self._requires_build = True
def load(self):
return self.kernel_builder.load()
class X86CPUAdamBuilder(ExtensionBuilder):
NAME = "cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"
def __init__(self):
super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-lcudart",
"-lcublas",
"-g",
"-Wno-reorder",
"-fopenmp",
"-march=native",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
]
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@ -0,0 +1,243 @@
# This code has been adapted from the DeepSpeed library.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import importlib
import os
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
class ExtensionBuilder(ABC):
"""
Builder is the base class to build extensions for PyTorch.
Args:
name (str): the name of the kernel to be built
prebuilt_import_path (str): the path where the extension is installed during pip install
"""
ext_type: str = "cuda"
def __init__(self, name: str, prebuilt_import_path: str):
self.name = name
self.prebuilt_import_path = prebuilt_import_path
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# we store the op as an attribute to avoid repeated building and loading
self.cached_op_module = None
assert prebuilt_import_path.startswith(
"colossalai._C"
), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}"
def relative_to_abs_path(self, code_path: str) -> str:
"""
This function takes in a path relative to the colossalai root directory and return the absolute path.
"""
op_builder_module_path = Path(__file__).parent
# if we install from source
# the current file path will be op_builder/builder.py
# if we install via pip install colossalai
# the current file path will be colossalai/kernel/op_builder/builder.py
# this is because that the op_builder inside colossalai is a symlink
# this symlink will be replaced with actual files if we install via pypi
# thus we cannot tell the colossalai root directory by checking whether the op_builder
# is a symlink, we can only tell whether it is inside or outside colossalai
if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"):
root_path = op_builder_module_path.parent.parent
elif str(op_builder_module_path).endswith("colossalai/kernel/extensions"):
root_path = op_builder_module_path.parent.parent
else:
root_path = op_builder_module_path.parent.joinpath("colossalai")
code_abs_path = root_path.joinpath(code_path)
return str(code_abs_path)
def get_cuda_home_include(self):
"""
return include path inside the cuda home.
"""
from torch.utils.cpp_extension import CUDA_HOME
if CUDA_HOME is None:
raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
cuda_include = os.path.join(CUDA_HOME, "include")
return cuda_include
def csrc_abs_path(self, path):
return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path)
# functions must be overrided begin
@abstractmethod
def sources_files(self) -> List[str]:
"""
This function should return a list of source files for extensions.
"""
raise NotImplementedError
@abstractmethod
def include_dirs(self) -> List[str]:
"""
This function should return a list of include files for extensions.
"""
@abstractmethod
def cxx_flags(self) -> List[str]:
"""
This function should return a list of cxx compilation flags for extensions.
"""
@abstractmethod
def nvcc_flags(self) -> List[str]:
"""
This function should return a list of nvcc compilation flags for extensions.
"""
# functions must be overrided over
def strip_empty_entries(self, args):
"""
Drop any empty strings from the list of compile and link flags
"""
return [x for x in args if len(x) > 0]
def import_op(self):
"""
This function will import the op module by its string name.
"""
return importlib.import_module(self.prebuilt_import_path)
def check_runtime_build_environment(self):
"""
Check whether the system environment is ready for extension compilation.
"""
try:
from torch.utils.cpp_extension import CUDA_HOME
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
CUDA_HOME = None
if not TORCH_AVAILABLE:
raise ModuleNotFoundError(
"PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions"
)
if CUDA_HOME is None:
raise RuntimeError(
"CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions"
)
# make sure CUDA is available for compilation during
cuda_available = check_cuda_availability()
if not cuda_available:
raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.")
# make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not
check_system_pytorch_cuda_match(CUDA_HOME)
def build(self, verbose: Optional[bool] = None):
"""
If the kernel is not built during pip install, it will build the kernel.
If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
kernel is built during pip install, it can be accessed through `colossalai._C`.
Warning: do not load this kernel repeatedly during model execution as it could slow down the training process.
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
if verbose is None:
verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1"
try:
# if the kernel has been pre-built during installation
# we just directly import it
op_module = self.import_op()
if verbose:
print_rank_0(
f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building."
)
except ImportError:
# check environment
if self.ext_type == "cuda":
self.check_runtime_build_environment()
# time the kernel compilation
start_build = time.time()
# construct the build directory
import torch
from torch.utils.cpp_extension import load
torch_version_major = torch.__version__.split(".")[0]
torch_version_minor = torch.__version__.split(".")[1]
torch_cuda_version = torch.version.cuda
home_directory = os.path.expanduser("~")
extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
build_directory = os.path.join(home_directory, extension_directory)
Path(build_directory).mkdir(parents=True, exist_ok=True)
if verbose:
print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now")
# load the kernel
op_module = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_cuda_cflags=self.nvcc_flags(),
extra_ldflags=[],
build_directory=build_directory,
verbose=verbose,
)
build_duration = time.time() - start_build
# log jit compilation time
if verbose:
print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds")
# cache the built/loaded kernel
self.cached_op_module = op_module
def load(self, verbose: Optional[bool] = None):
"""
load the kernel during runtime.
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
# if the kernel has be compiled and cached, we directly use it
assert self.cached_op_module is not None, "Please build the kernel first before loading it."
return self.cached_op_module
def builder(self) -> Union["CUDAExtension", "CppExtension"]:
"""
get a CUDAExtension instance used for setup.py
"""
from torch.utils.cpp_extension import CppExtension, CUDAExtension
if self.ext_type == "cpp":
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
return CUDAExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args={
"cxx": self.strip_empty_entries(self.cxx_flags()),
"nvcc": self.strip_empty_entries(self.nvcc_flags()),
},
)
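Two paths lead out of ExtensionBuilder: builder() hands a CppExtension/CUDAExtension to setup.py for ahead-of-time builds, while build() plus load() cover the JIT path at runtime, caching the compiled module under ~/.cache/colossalai/torch_extensions/. A runtime sketch using the X86CPUAdamBuilder defined earlier in this diff (the import path is assumed from this PR's layout; the optimizer arguments follow the CPUAdamLoader docstring):

    from colossalai.kernel.extensions.cpu_adam.x86_extension import X86CPUAdamBuilder  # path assumed

    builder = X86CPUAdamBuilder()
    builder.build(verbose=True)  # imports the prebuilt colossalai._C op if available, else JIT-compiles
    cpu_adam = builder.load()    # returns the cached op module
    optimizer = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.99, 1e-8, 0.0, True)  # alpha, beta1, beta2, epsilon, weight_decay, adamw_mode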

View File

@ -0,0 +1,19 @@
from .cuda_flash_attn_2_extension import HAS_FLASH_ATTN, CudaFlashAttnExtension
from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension
from .npu_sdpa_attn_extension import NpuSdpaAttnExtension
from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension
from .utils import AttnMaskType, Repad, SeqLenInfo, Unpad
__all__ = [
"CudaFlashAttnExtension",
"CudaMemoryEfficentAttnExtension",
"NpuSdpaAttnExtension",
"NpuTriangleAttnExtension",
"HAS_FLASH_ATTN",
"HAS_MEM_EFF_ATTN",
"HAS_NPU_TRIANGLE_ATTENTION",
"Unpad",
"AttnMaskType",
"Repad",
"SeqLenInfo",
]

View File

@ -1,10 +1,14 @@
-import warnings
from typing import Optional
import torch
+from ..base_extension import BaseExtension
+from ..utils import print_rank_0
+from .utils import SeqLenInfo
def is_ampere_or_better_gpu():
+# Check Ampere GPUs or newer
if torch.cuda.is_available():
device = torch.device("cuda")
properties = torch.cuda.get_device_properties(device)
@ -13,31 +17,28 @@ def is_ampere_or_better_gpu():
return False
-# "Check Ampere GPUs or newer"
HAS_FLASH_ATTN = False
+ERROR_MSG = None
if is_ampere_or_better_gpu():
-HAS_FLASH_ATTN = True
-else:
-warnings.warn("FlashAttention only supports Ampere GPUs or newer.")
-HAS_FLASH_ATTN = False
try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
HAS_FLASH_ATTN = True
except ImportError:
-warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention")
+ERROR_MSG = "ImportError: please install flash_attn from https://github.com/HazyResearch/flash-attention"
-HAS_FLASH_ATTN = False
+else:
+ERROR_MSG = "ImportError: FlashAttention only supports Ampere GPUs or newer."
if HAS_FLASH_ATTN:
-from .utils import SeqLenInfo
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: SeqLenInfo,
seq_len_info_kv: SeqLenInfo,
+origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
@ -77,3 +78,23 @@ if HAS_FLASH_ATTN:
else:
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
return attn_out
+class CudaFlashAttnExtension(BaseExtension):
+def __init__(self) -> None:
+super().__init__()
+@property
+def requires_build(self):
+return False
+def build(self):
+pass
+def is_available(self):
+if HAS_FLASH_ATTN == False:
+print_rank_0(ERROR_MSG)
+return HAS_FLASH_ATTN
+def load(self):
+return flash_attention

View File

@ -1,4 +1,10 @@
-import warnings
+from typing import Optional
+import torch
+from ..base_extension import BaseExtension
+from ..utils import print_rank_0
+from .utils import SeqLenInfo
HAS_MEM_EFF_ATTN = False
try:
@ -12,19 +18,13 @@ try:
HAS_MEM_EFF_ATTN = True
except ImportError:
-warnings.warn("please install xformers from https://github.com/facebookresearch/xformers")
+pass
-HAS_MEM_EFF_ATTN = False
if HAS_MEM_EFF_ATTN:
"""
A general attention module using the flash attention kernels from xformers:
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
"""
-from typing import Optional
-import torch
-from .utils import SeqLenInfo
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
@ -36,6 +36,7 @@ if HAS_MEM_EFF_ATTN:
v: torch.Tensor,
seq_len_info_q: SeqLenInfo,
seq_len_info_kv: SeqLenInfo,
+origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
@ -68,3 +69,23 @@ if HAS_MEM_EFF_ATTN:
out = out.squeeze(0)
return out
+class CudaMemoryEfficentAttnExtension(BaseExtension):
+def __init__(self) -> None:
+super().__init__()
+@property
+def requires_build(self) -> bool:
+return False
+def build(self):
+pass
+def is_available(self):
+if HAS_MEM_EFF_ATTN == False:
+print_rank_0("ImportError: please install xformers from https://github.com/facebookresearch/xformers")
+return HAS_MEM_EFF_ATTN
+def load(self):
+return mem_eff_attention

View File

@ -1,16 +1,20 @@
import torch
from einops import rearrange
+from ..base_extension import BaseExtension
def npu_sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
-attn_mask: torch.Tensor = None,
+seq_len_info_q=None,
+seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
-scale: float = 1.0,
dropout_p: float = 0.0,
-is_causal: bool = True,
+scale: float = 1.0,
+causal=None,
+padded=None,
):
"""
The scaled dot product attention.
@ -39,3 +43,18 @@ def npu_sdpa_attention(
)
output = rearrange(output, "b h s d -> b s (h d)")
return output
+class NpuSdpaAttnExtension(BaseExtension):
+def __init__(self) -> None:
+super().__init__()
+@property
+def requires_build(self) -> bool:
+return False
+def build(self):
+pass
+def load(self):
+return npu_sdpa_attention

View File

@ -13,18 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import torch
from einops import rearrange
+from ..base_extension import BaseExtension
+from ..utils import print_rank_0
HAS_NPU_TRIANGLE_ATTENTION = False
try:
from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
-logging.warning("Import torch_npu Error.")
+pass
if HAS_NPU_TRIANGLE_ATTENTION:
@ -33,11 +35,13 @@ if HAS_NPU_TRIANGLE_ATTENTION:
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
-attn_mask: torch.Tensor = None,
+seq_len_info_q=None,
+seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
-scale: float = 1.0,
dropout_p: float = 0.0,
-is_causal: bool = True,
+scale: float = 1.0,
+causal=None,
+padded=None,
block_size=512,
):
"""
@ -113,3 +117,25 @@ if HAS_NPU_TRIANGLE_ATTENTION:
# Context layer. [b, sq, hp]
# =========================
return context_layer
+class NpuTriangleAttnExtension(BaseExtension):
+def __init__(self) -> None:
+super().__init__()
+@property
+def requires_build(self) -> bool:
+return False
+def build(self):
+pass
+def is_available(self):
+if HAS_NPU_TRIANGLE_ATTENTION == False:
+print_rank_0(
+"ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api."
+)
+return HAS_NPU_TRIANGLE_ATTENTION
+def load(self):
+return npu_triangle_attention

View File

@ -1,3 +1,4 @@
+import enum
from dataclasses import dataclass
from typing import Iterable, Tuple
@ -80,3 +81,9 @@ class SeqLenInfo:
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
+class AttnMaskType(enum.Enum):
+padding = 1
+causal = 2
+paddedcausal = 3
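The enum values are chosen so downstream code can recover the two boolean properties arithmetically: an odd value means the inputs are padded, and a value greater than one means a causal mask is required, which is exactly how ColoAttention.forward decodes them later in this diff. A small illustration:

    # padding = 1, causal = 2, paddedcausal = 3, as defined above
    for mask_type in (AttnMaskType.padding, AttnMaskType.causal, AttnMaskType.paddedcausal):
        padded = mask_type.value % 2 == 1  # True for padding and paddedcausal
        causal = mask_type.value > 1       # True for causal and paddedcausal
        print(mask_type.name, padded, causal)
    # -> padding True False, causal False True, paddedcausal True True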

View File

@ -0,0 +1,229 @@
import os
import re
import subprocess
import warnings
from typing import List
def print_rank_0(message: str) -> None:
"""
Print on only one process to avoid spamming.
"""
try:
import torch.distributed as dist
if not dist.is_initialized():
is_main_rank = True
else:
is_main_rank = dist.get_rank() == 0
except ImportError:
is_main_rank = True
if is_main_rank:
print(message)
def get_cuda_version_in_pytorch() -> List[int]:
"""
This function returns the CUDA version in the PyTorch build.
Returns:
The CUDA version required by PyTorch, in the form of tuple (major, minor).
"""
import torch
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
except:
raise ValueError(
"[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda"
)
return torch_cuda_major, torch_cuda_minor
def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
"""
Get the System CUDA version from nvcc.
Args:
cuda_dir (str): the directory for CUDA Toolkit.
Returns:
The CUDA version required by PyTorch, in the form of tuple (major, minor).
"""
nvcc_path = os.path.join(cuda_dir, "bin/nvcc")
if cuda_dir is None:
raise ValueError(
f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly."
)
# check for nvcc path
if not os.path.exists(nvcc_path):
raise FileNotFoundError(
f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
)
# parse the nvcc -v output to obtain the system cuda version
try:
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
except:
raise ValueError(
f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}"
)
return bare_metal_major, bare_metal_minor
def check_system_pytorch_cuda_match(cuda_dir):
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()
if bare_metal_major != torch_cuda_major:
raise Exception(
f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})."
"Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
)
if bare_metal_minor != torch_cuda_minor:
warnings.warn(
f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
"The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
"If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
)
return True
def get_pytorch_version() -> List[int]:
"""
This functions finds the PyTorch version.
Returns:
A tuple of integers in the form of (major, minor, patch).
"""
import torch
torch_version = torch.__version__.split("+")[0]
TORCH_MAJOR = int(torch_version.split(".")[0])
TORCH_MINOR = int(torch_version.split(".")[1])
TORCH_PATCH = int(torch_version.split(".")[2], 16)
return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH
def check_pytorch_version(min_major_version, min_minor_version) -> bool:
"""
Compare the current PyTorch version with the minium required version.
Args:
min_major_version (int): the minimum major version of PyTorch required
min_minor_version (int): the minimum minor version of PyTorch required
Returns:
A boolean value. The value is True if the current pytorch version is acceptable and False otherwise.
"""
# get pytorch version
torch_major, torch_minor, _ = get_pytorch_version()
# if the
if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
raise RuntimeError(
f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
"The latest stable release can be obtained from https://pytorch.org/get-started/locally/"
)
def check_cuda_availability():
"""
Check if CUDA is available on the system.
Returns:
A boolean value. True if CUDA is available and False otherwise.
"""
import torch
return torch.cuda.is_available()
def set_cuda_arch_list(cuda_dir):
"""
This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'.
"""
cuda_available = check_cuda_availability()
# we only need to set this when CUDA is not available for cross-compilation
if not cuda_available:
warnings.warn(
"\n[extension] PyTorch did not find available GPUs on this system.\n"
"If your intention is to cross-compile, this is not an error.\n"
"By default, Colossal-AI will cross-compile for \n"
"1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"2. Volta (compute capability 7.0)\n"
"3. Turing (compute capability 7.5),\n"
"4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n"
"\nIf you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
)
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]
if int(bare_metal_major) == 11:
if int(bare_metal_minor) == 0:
arch_list.append("8.0")
else:
arch_list.append("8.0")
arch_list.append("8.6")
arch_list_str = ";".join(arch_list)
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
return False
return True
def get_cuda_cc_flag() -> List[str]:
"""
This function produces the cc flags for your GPU arch
Returns:
The CUDA cc flags for compilation.
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
cc_flag = []
max_arch = "".join(str(i) for i in torch.cuda.get_device_capability())
for arch in torch.cuda.get_arch_list():
res = re.search(r"sm_(\d+)", arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch):
cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"])
return cc_flag
def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
"""
This function appends the threads flag to your nvcc args.
Returns:
The nvcc compilation flags including the threads flag.
"""
from torch.utils.cpp_extension import CUDA_HOME
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
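These helpers are meant to be composed by builders: X86CPUAdamBuilder above already routes its nvcc flags through append_nvcc_threads, and get_cuda_cc_flag can supply per-architecture gencode flags. A hedged sketch of how a CUDA builder's nvcc_flags() might combine them (illustrative only, not a builder from this commit):

    def nvcc_flags():
        base = ["-O3", "--use_fast_math"]
        base += get_cuda_cc_flag()        # -gencode arch=compute_XX,code=sm_XX for each eligible arch reported by torch
        return append_nvcc_threads(base)  # appends --threads 4 when the CUDA version check above passes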

View File

@ -0,0 +1,185 @@
import math
from collections import OrderedDict
from typing import Optional
import torch
from einops import rearrange
from colossalai.accelerator import get_accelerator
from .base_kernel_loader import BaseKernelLoader
from .extensions.flash_attention import (
AttnMaskType,
CudaFlashAttnExtension,
CudaMemoryEfficentAttnExtension,
NpuSdpaAttnExtension,
NpuTriangleAttnExtension,
Repad,
SeqLenInfo,
Unpad,
)
from .extensions.utils import print_rank_0
class FlashAttentionLoader(BaseKernelLoader):
"""
FlashAttention Loader
options: cuda flashh attention, cuda memory effcient attention, npu sdpa attention, npu triangle attention
Args:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
def __init__(self):
super().__init__(
# extension name must start with the accelerator name. E.g. npu_xxx, cuda_xxx
extension_map=OrderedDict(
cuda_flash_attn=CudaFlashAttnExtension,
cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension,
npu_sdpa_attn=NpuSdpaAttnExtension,
npu_triangle_attn=NpuTriangleAttnExtension,
),
supported_device=["cuda", "npu"],
)
def fetch_kernel(self, backend: str = None):
if backend is not None:
if not self._extension_map[backend]().is_available():
raise Exception(f"{backend} is not available for flash attention.")
return self._extension_map[backend]().fetch()
kernel = None
accelerator_name = get_accelerator().name
assert accelerator_name in self._supported_device, f"{accelerator_name} is not supported for flash attention."
for extension_name, extension in self._extension_map.items():
if extension_name.startswith(accelerator_name):
if extension().is_available():
kernel = extension().fetch()
break
if kernel is None:
raise Exception("No extension for flash attention is supported")
return kernel
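fetch_kernel() either honors an explicit backend name or walks the ordered extension map and picks the first available extension whose name starts with the current accelerator's name. A short usage sketch (the backend strings are the extension_map keys defined above):

    loader = FlashAttentionLoader()

    # let the loader pick the first available backend for the current accelerator
    attn_fn = loader.fetch_kernel()

    # or force a specific backend, e.g. the xformers memory-efficient kernel
    attn_fn = loader.fetch_kernel(backend="cuda_memory_efficent_attn")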
class ColoAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
super().__init__()
assert (
embed_dim % num_heads == 0
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
self.attn = FlashAttentionLoader().fetch_kernel()
@staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return Unpad.apply(tensor, indices)
@staticmethod
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
return Repad.apply(tensor, indices, batch_size, seq_len)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
origin_attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: Optional[AttnMaskType] = None,
bias: Optional[torch.Tensor] = None,
):
"""
ColoAttention
Args:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
bias: will not be used
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
# if flash attention is not applicable, switch to memory effcient attention
if self.attn.__name__ == "flash_attention" and (
query.dtype not in [torch.float16, torch.bfloat16] or bias != None
):
print_rank_0("flash attention is not applicable, switch to memory effcient attention")
self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn")
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
causal = attn_mask_type is not None and attn_mask_type.value > 1
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
# unpad
seq_len_info_q = None
seq_len_info_kv = None
if padded:
# bert style, unpad process
assert (
attn_mask is not None
), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, (
"attention mask is supposed to have shape (batch_size, seq_len), "
+ f"but got {attn_mask.dim()} dimensions."
)
# bert style
if tgt_len == src_len:
seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query, key, value = self.unpad(
torch.stack([query, key, value], dim=2), seq_len_info_q.indices
).unbind(dim=1)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
seq_len_info_kv = seq_len_info_q
else:
seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
dim=1
)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
out = self.attn(
query,
key,
value,
seq_len_info_q=seq_len_info_q,
seq_len_info_kv=seq_len_info_kv,
origin_attn_mask=origin_attn_mask,
dropout_p=self.dropout,
scale=self.scale,
causal=causal,
padded=padded,
)
# repad
if padded:
if batch_size > 1:
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
if len(out.shape) == 4:
out = rearrange(out, "b s h d -> b s (h d)")
return out
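A minimal end-to-end sketch of the new ColoAttention wrapper, assuming a CUDA device with one of the flash-attention backends available; shapes follow the docstring (batch, seqlen, nheads, headdim) and the concrete sizes are arbitrary:

    import torch
    from colossalai.kernel import AttnMaskType, ColoAttention

    batch, seqlen, nheads, headdim = 2, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    attn = ColoAttention(embed_dim=nheads * headdim, num_heads=nheads, dropout=0.0)
    out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)  # -> (batch, seqlen, nheads * headdim)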

View File

@ -1,3 +0,0 @@
from .mha import NPUColoAttention
__all__ = ["NPUColoAttention"]

View File

@ -1,3 +0,0 @@
from .mha import NPUColoAttention
__all__ = ["NPUColoAttention"]

View File

@ -1,80 +0,0 @@
import math
from typing import Optional
import torch
from .sdpa_attn import npu_sdpa_attention
from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION
class NPUColoAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
super().__init__()
try:
import torch_npu # noqa
except ImportError:
raise Exception("torch_npu is not installed.")
assert (
embed_dim % num_heads == 0
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
origin_attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: int = None,
bias: Optional[torch.Tensor] = None,
):
"""
Implement the scaled dot product attention with softmax.
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
assert (
len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
assert (
query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
assert bias is None, "bias is not supported in npu colo attention"
causal = attn_mask_type is not None and attn_mask_type.value > 1
if HAS_NPU_TRIANGLE_ATTENTION:
from .triangle_attn import npu_triangle_attention
attn_fn = npu_triangle_attention
else:
attn_fn = npu_sdpa_attention
out = attn_fn(
query,
key,
value,
attn_mask=attn_mask,
origin_attn_mask=origin_attn_mask,
dropout_p=self.dropout,
scale=self.scale,
is_causal=causal,
)
return out

View File

@ -1,10 +1,9 @@
import math
-import platform
from typing import Optional
import torch
-from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
+from colossalai.kernel import CPUAdamLoader
from .nvme_optimizer import NVMeOptimizer
@ -78,7 +77,7 @@ class CPUAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
-cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
+cpu_adam = CPUAdamLoader().load()
# if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

View File

@ -6,7 +6,8 @@ import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size
-from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
+from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state
class SeqParallelUtils:
@ -280,21 +281,3 @@ def create_randomizer_with_offset(
Randomizer.increment_index()
return Randomizer(seed=base_seed)
-def get_attention_kernel():
-"""
-Get the attention kernel based on the device type.
-"""
-from colossalai.kernel.cuda_native import AttnMaskType
-if torch.cuda.is_available():
-from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
-else:
-try:
-torch.npu.is_available()
-from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
-except:
-raise Exception("No available device for attention kernel!")
-return AttnMaskType, AttentionKernel

View File

@ -62,7 +62,7 @@ def forward_fn():
def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
-from colossalai.kernel.cuda_native import ColoAttention
+from colossalai.kernel import ColoAttention
def forward(
self: Blip2Attention,

View File

@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
def get_flash_core_attention_forward():
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention

View File

@ -719,7 +719,7 @@ class GPT2PipelineForwards:
def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size):
"""

View File

@ -1,5 +1,5 @@
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@ -12,14 +12,15 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer.utils import get_attention_kernel
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
LATEST_VERSION = True
except ImportError:
LATEST_VERSION = False
class LlamaPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
@ -405,7 +406,7 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-AttnMaskType, ColoAttention = get_attention_kernel()
+from colossalai.kernel import AttnMaskType, ColoAttention
llama_version = 2
try:
@ -469,7 +470,12 @@ def get_llama_flash_attention_forward():
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
-query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
+query_states,
+key_states,
+value_states,
+attn_mask=flash_attention_mask,
+attn_mask_type=attn_mask_type,
+origin_attn_mask=attention_mask,
)
attn_output = self.o_proj(attn_output)

View File

@ -514,7 +514,7 @@ class OPTPipelineForwards:
def get_opt_flash_attention_forward():
from transformers.models.opt.modeling_opt import OPTAttention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
def forward(
self: OPTAttention,

View File

@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention
-from colossalai.kernel.cuda_native import ColoAttention
+from colossalai.kernel import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)

View File

@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()

View File

@ -35,7 +35,7 @@ from transformers.utils import (
replace_return_docstrings,
)
-from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe.layers import SparseMLP
from colossalai.moe.manager import MOE_MANAGER

View File

@ -90,9 +90,9 @@ class FusedAdamKernel(AdamKernel):
class CPUAdamKernel(AdamKernel):
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
-from colossalai.kernel.op_builder import CPUAdamBuilder
+from colossalai.kernel import CPUAdamLoader
-cpu_optim = CPUAdamBuilder().load()
+cpu_optim = CPUAdamLoader().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)

View File

@ -4,13 +4,11 @@ import pytest
import torch
from einops import rearrange
-from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
-from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-from colossalai.kernel.cuda_native import ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
DTYPE = [torch.float16, torch.bfloat16, torch.float32]