[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-25 17:01:48 +08:00
committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

View File

@@ -0,0 +1 @@
../../extensions

View File

@@ -1,21 +0,0 @@
from abc import ABC, abstractmethod
from typing import Callable
class BaseExtension(ABC):
@abstractmethod
def requires_build(self) -> bool:
pass
@abstractmethod
def build(self) -> None:
pass
@abstractmethod
def load(self) -> Callable:
pass
def fetch(self) -> Callable:
if self.requires_build:
self.build()
return self.load()

View File

@@ -1,4 +0,0 @@
from .arm_extension import ArmCPUAdamExtension
from .x86_extension import X86CPUAdamExtension
__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"]

View File

@@ -1,53 +0,0 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
class ArmCPUAdamExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
self.kernel_builder = ArmCPUAdamBuilder()
self._requires_build = False
@property
def requires_build(self) -> bool:
return self._requires_build
def build(self):
self.kernel_builder.build()
self._requires_build = True
def load(self):
return self.kernel_builder.load()
class ArmCPUAdamBuilder(ExtensionBuilder):
NAME = "arm_cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
ext_type = "cpu"
def __init__(self):
super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam_arm.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes")]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-g",
"-Wno-reorder",
"-fopenmp",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
return []

View File

@@ -1,65 +0,0 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
from ..utils import append_nvcc_threads
class X86CPUAdamExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
self.kernel_builder = X86CPUAdamBuilder()
self._requires_build = False
@property
def requires_build(self) -> bool:
return self._requires_build
def build(self):
self.kernel_builder.build()
self._requires_build = True
def load(self):
return self.kernel_builder.load()
class X86CPUAdamBuilder(ExtensionBuilder):
NAME = "cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"
def __init__(self):
super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-lcudart",
"-lcublas",
"-g",
"-Wno-reorder",
"-fopenmp",
"-march=native",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
]
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@@ -1,243 +0,0 @@
# This code has been adapted from the DeepSpeed library.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import importlib
import os
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
class ExtensionBuilder(ABC):
"""
Builder is the base class to build extensions for PyTorch.
Args:
name (str): the name of the kernel to be built
prebuilt_import_path (str): the path where the extension is installed during pip install
"""
ext_type: str = "cuda"
def __init__(self, name: str, prebuilt_import_path: str):
self.name = name
self.prebuilt_import_path = prebuilt_import_path
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# we store the op as an attribute to avoid repeated building and loading
self.cached_op_module = None
assert prebuilt_import_path.startswith(
"colossalai._C"
), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}"
def relative_to_abs_path(self, code_path: str) -> str:
"""
This function takes in a path relative to the colossalai root directory and return the absolute path.
"""
op_builder_module_path = Path(__file__).parent
# if we install from source
# the current file path will be op_builder/builder.py
# if we install via pip install colossalai
# the current file path will be colossalai/kernel/op_builder/builder.py
# this is because that the op_builder inside colossalai is a symlink
# this symlink will be replaced with actual files if we install via pypi
# thus we cannot tell the colossalai root directory by checking whether the op_builder
# is a symlink, we can only tell whether it is inside or outside colossalai
if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"):
root_path = op_builder_module_path.parent.parent
elif str(op_builder_module_path).endswith("colossalai/kernel/extensions"):
root_path = op_builder_module_path.parent.parent
else:
root_path = op_builder_module_path.parent.joinpath("colossalai")
code_abs_path = root_path.joinpath(code_path)
return str(code_abs_path)
def get_cuda_home_include(self):
"""
return include path inside the cuda home.
"""
from torch.utils.cpp_extension import CUDA_HOME
if CUDA_HOME is None:
raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
cuda_include = os.path.join(CUDA_HOME, "include")
return cuda_include
def csrc_abs_path(self, path):
return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path)
# functions must be overrided begin
@abstractmethod
def sources_files(self) -> List[str]:
"""
This function should return a list of source files for extensions.
"""
raise NotImplementedError
@abstractmethod
def include_dirs(self) -> List[str]:
"""
This function should return a list of include files for extensions.
"""
@abstractmethod
def cxx_flags(self) -> List[str]:
"""
This function should return a list of cxx compilation flags for extensions.
"""
@abstractmethod
def nvcc_flags(self) -> List[str]:
"""
This function should return a list of nvcc compilation flags for extensions.
"""
# functions must be overrided over
def strip_empty_entries(self, args):
"""
Drop any empty strings from the list of compile and link flags
"""
return [x for x in args if len(x) > 0]
def import_op(self):
"""
This function will import the op module by its string name.
"""
return importlib.import_module(self.prebuilt_import_path)
def check_runtime_build_environment(self):
"""
Check whether the system environment is ready for extension compilation.
"""
try:
from torch.utils.cpp_extension import CUDA_HOME
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
CUDA_HOME = None
if not TORCH_AVAILABLE:
raise ModuleNotFoundError(
"PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions"
)
if CUDA_HOME is None:
raise RuntimeError(
"CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions"
)
# make sure CUDA is available for compilation during
cuda_available = check_cuda_availability()
if not cuda_available:
raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.")
# make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not
check_system_pytorch_cuda_match(CUDA_HOME)
def build(self, verbose: Optional[bool] = None):
"""
If the kernel is not built during pip install, it will build the kernel.
If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
kernel is built during pip install, it can be accessed through `colossalai._C`.
Warning: do not load this kernel repeatedly during model execution as it could slow down the training process.
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
if verbose is None:
verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1"
try:
# if the kernel has been pre-built during installation
# we just directly import it
op_module = self.import_op()
if verbose:
print_rank_0(
f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building."
)
except ImportError:
# check environment
if self.ext_type == "cuda":
self.check_runtime_build_environment()
# time the kernel compilation
start_build = time.time()
# construct the build directory
import torch
from torch.utils.cpp_extension import load
torch_version_major = torch.__version__.split(".")[0]
torch_version_minor = torch.__version__.split(".")[1]
torch_cuda_version = torch.version.cuda
home_directory = os.path.expanduser("~")
extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
build_directory = os.path.join(home_directory, extension_directory)
Path(build_directory).mkdir(parents=True, exist_ok=True)
if verbose:
print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now")
# load the kernel
op_module = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_cuda_cflags=self.nvcc_flags(),
extra_ldflags=[],
build_directory=build_directory,
verbose=verbose,
)
build_duration = time.time() - start_build
# log jit compilation time
if verbose:
print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds")
# cache the built/loaded kernel
self.cached_op_module = op_module
def load(self, verbose: Optional[bool] = None):
"""
load the kernel during runtime.
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
# if the kernel has be compiled and cached, we directly use it
assert self.cached_op_module is not None, "Please build the kernel first before loading it."
return self.cached_op_module
def builder(self) -> Union["CUDAExtension", "CppExtension"]:
"""
get a CUDAExtension instance used for setup.py
"""
from torch.utils.cpp_extension import CppExtension, CUDAExtension
if self.ext_type == "cpp":
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
return CUDAExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args={
"cxx": self.strip_empty_entries(self.cxx_flags()),
"nvcc": self.strip_empty_entries(self.nvcc_flags()),
},
)

View File

@@ -1,19 +0,0 @@
from .cuda_flash_attn_2_extension import HAS_FLASH_ATTN, CudaFlashAttnExtension
from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension
from .npu_sdpa_attn_extension import NpuSdpaAttnExtension
from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension
from .utils import AttnMaskType, Repad, SeqLenInfo, Unpad
__all__ = [
"CudaFlashAttnExtension",
"CudaMemoryEfficentAttnExtension",
"NpuSdpaAttnExtension",
"NpuTriangleAttnExtension",
"HAS_FLASH_ATTN",
"HAS_MEM_EFF_ATTN",
"HAS_NPU_TRIANGLE_ATTENTION",
"Unpad",
"AttnMaskType",
"Repad",
"SeqLenInfo",
]

View File

@@ -1,100 +0,0 @@
from typing import Optional
import torch
from ..base_extension import BaseExtension
from ..utils import print_rank_0
from .utils import SeqLenInfo
def is_ampere_or_better_gpu():
# Check Ampere GPUs or newer
if torch.cuda.is_available():
device = torch.device("cuda")
properties = torch.cuda.get_device_properties(device)
if properties.major >= 8: # Ampere GPUs or newer
return True
return False
HAS_FLASH_ATTN = False
ERROR_MSG = None
if is_ampere_or_better_gpu():
try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
HAS_FLASH_ATTN = True
except ImportError:
ERROR_MSG = "ImportError: please install flash_attn from https://github.com/HazyResearch/flash-attention"
else:
ERROR_MSG = "ImportError: FlashAttention only supports Ampere GPUs or newer."
if HAS_FLASH_ATTN:
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: SeqLenInfo,
seq_len_info_kv: SeqLenInfo,
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
"""
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
if padded:
if seq_len_info_kv == None:
seq_len_info_kv = seq_len_info_q
attn_out = flash_attn_varlen_func(
q,
k,
v,
seq_len_info_q.cu_seqlens,
seq_len_info_kv.cu_seqlens,
seq_len_info_q.max_seqlen,
seq_len_info_kv.max_seqlen,
dropout_p,
scale,
causal,
)
else:
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
return attn_out
class CudaFlashAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self):
return False
def build(self):
pass
def is_available(self):
if HAS_FLASH_ATTN == False:
print_rank_0(ERROR_MSG)
return HAS_FLASH_ATTN
def load(self):
return flash_attention

View File

@@ -1,91 +0,0 @@
from typing import Optional
import torch
from ..base_extension import BaseExtension
from ..utils import print_rank_0
from .utils import SeqLenInfo
HAS_MEM_EFF_ATTN = False
try:
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
HAS_MEM_EFF_ATTN = True
except ImportError:
pass
if HAS_MEM_EFF_ATTN:
"""
A general attention module using the flash attention kernels from xformers:
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
"""
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
def mem_eff_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: SeqLenInfo,
seq_len_info_kv: SeqLenInfo,
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
attn_bias = None
if padded: # bert style
if not causal:
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
elif causal: # gpt style
attn_bias = LowerTriangularMask()
if bias is not None: # alibi / relative position embedding
assert allow_alibi, "flash attention with bias is not supported in this system."
assert causal, "attention with bias is only supported for causal attention so far."
attn_bias = attn_bias.add_bias(bias)
if padded:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
# shape: (b*s, n, d)
if padded:
out = out.squeeze(0)
return out
class CudaMemoryEfficentAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self) -> bool:
return False
def build(self):
pass
def is_available(self):
if HAS_MEM_EFF_ATTN == False:
print_rank_0("ImportError: please install xformers from https://github.com/facebookresearch/xformers")
return HAS_MEM_EFF_ATTN
def load(self):
return mem_eff_attention

View File

@@ -1,60 +0,0 @@
import torch
from einops import rearrange
from ..base_extension import BaseExtension
def npu_sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q=None,
seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
scale: float = 1.0,
causal=None,
padded=None,
):
"""
The scaled dot product attention.
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=origin_attn_mask,
dropout_p=dropout_p,
is_causal=origin_attn_mask is None,
scale=scale,
)
output = rearrange(output, "b h s d -> b s (h d)")
return output
class NpuSdpaAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self) -> bool:
return False
def build(self):
pass
def load(self):
return npu_sdpa_attention

View File

@@ -1,141 +0,0 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from ..base_extension import BaseExtension
from ..utils import print_rank_0
HAS_NPU_TRIANGLE_ATTENTION = False
try:
from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
pass
if HAS_NPU_TRIANGLE_ATTENTION:
def npu_triangle_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q=None,
seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
scale: float = 1.0,
causal=None,
padded=None,
block_size=512,
):
"""
The triangle attention reduces the attention calculation of the mask
part by dividing the q, k, and v matrices into blocks
Arguments:
block_size: The size of the inverted triangle block, the default is 512,
the smaller the block_size, the more calculations will be reduced,
but the number of small operators will be increased
masked_softmax_func: mask function to be applied.
dropout_func: dropout function to be applied.
"""
def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
# [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
cur_sim = torch.matmul(q_layer, k_layer)
attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
# attention dropout
if dropout_p > 0:
attention_probs = torch.nn.functional.dropout(
attention_probs, p=dropout_p, training=attention_probs.require_grad
)
# [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
context_layer_tmp = torch.matmul(attention_probs, v_layer)
return context_layer_tmp
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
origin_attn_mask = origin_attn_mask.to(torch.bool)
# input shape: [b, hn, sq, hd]
bsz, head_num, sequence_len, head_dim = k.shape
sparse_groups = sequence_len // block_size
# Determine whether blocks size can be divided by sequence_length
divisible_flag = sequence_len == block_size * sparse_groups
k = k.transpose(2, 3).contiguous()
if divisible_flag:
q_tmp_layers = torch.chunk(q, sparse_groups, 2)
k_tmp_layers = torch.chunk(k, sparse_groups, 3)
v_tmp_layers = torch.chunk(v, sparse_groups, 2)
else:
seq_tmp = block_size * sparse_groups
q_last = q[:, :, seq_tmp:, :].contiguous()
mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
context_list_tmp, k_tmp, v_tmp = [], (), ()
for i in range(sparse_groups):
# compute slice shape of q k v for each loop
q_begin, q_end = i * block_size, (i + 1) * block_size
kv_begin, kv_end = 0, (i + 1) * block_size
q_tmp = q_tmp_layers[i]
# slice k and v
if i == 0:
k_tmp = k_tmp_layers[i].contiguous()
v_tmp = v_tmp_layers[i].contiguous()
else:
k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
context_list_tmp.append(context_layer_tmp)
if not divisible_flag:
# circumstances that cannot be divisible
context_layer_tmp = compute_attn(q_last, k, v, mask_last)
context_list_tmp.append(context_layer_tmp)
context_layer = torch.cat(context_list_tmp, 2)
new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
# =========================
# Context layer. [b, sq, hp]
# =========================
return context_layer
class NpuTriangleAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self) -> bool:
return False
def build(self):
pass
def is_available(self):
if HAS_NPU_TRIANGLE_ATTENTION == False:
print_rank_0(
"ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api."
)
return HAS_NPU_TRIANGLE_ATTENTION
def load(self):
return npu_triangle_attention

View File

@@ -1,91 +0,0 @@
import enum
from dataclasses import dataclass
from typing import Iterable, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.accelerator import get_accelerator
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3

View File

@@ -1,229 +0,0 @@
import os
import re
import subprocess
import warnings
from typing import List
def print_rank_0(message: str) -> None:
"""
Print on only one process to avoid spamming.
"""
try:
import torch.distributed as dist
if not dist.is_initialized():
is_main_rank = True
else:
is_main_rank = dist.get_rank() == 0
except ImportError:
is_main_rank = True
if is_main_rank:
print(message)
def get_cuda_version_in_pytorch() -> List[int]:
"""
This function returns the CUDA version in the PyTorch build.
Returns:
The CUDA version required by PyTorch, in the form of tuple (major, minor).
"""
import torch
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
except:
raise ValueError(
"[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda"
)
return torch_cuda_major, torch_cuda_minor
def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
"""
Get the System CUDA version from nvcc.
Args:
cuda_dir (str): the directory for CUDA Toolkit.
Returns:
The CUDA version required by PyTorch, in the form of tuple (major, minor).
"""
nvcc_path = os.path.join(cuda_dir, "bin/nvcc")
if cuda_dir is None:
raise ValueError(
f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly."
)
# check for nvcc path
if not os.path.exists(nvcc_path):
raise FileNotFoundError(
f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
)
# parse the nvcc -v output to obtain the system cuda version
try:
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
except:
raise ValueError(
f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}"
)
return bare_metal_major, bare_metal_minor
def check_system_pytorch_cuda_match(cuda_dir):
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()
if bare_metal_major != torch_cuda_major:
raise Exception(
f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})."
"Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
)
if bare_metal_minor != torch_cuda_minor:
warnings.warn(
f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
"The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
"If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
)
return True
def get_pytorch_version() -> List[int]:
"""
This functions finds the PyTorch version.
Returns:
A tuple of integers in the form of (major, minor, patch).
"""
import torch
torch_version = torch.__version__.split("+")[0]
TORCH_MAJOR = int(torch_version.split(".")[0])
TORCH_MINOR = int(torch_version.split(".")[1])
TORCH_PATCH = int(torch_version.split(".")[2], 16)
return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH
def check_pytorch_version(min_major_version, min_minor_version) -> bool:
"""
Compare the current PyTorch version with the minium required version.
Args:
min_major_version (int): the minimum major version of PyTorch required
min_minor_version (int): the minimum minor version of PyTorch required
Returns:
A boolean value. The value is True if the current pytorch version is acceptable and False otherwise.
"""
# get pytorch version
torch_major, torch_minor, _ = get_pytorch_version()
# if the
if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
raise RuntimeError(
f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
"The latest stable release can be obtained from https://pytorch.org/get-started/locally/"
)
def check_cuda_availability():
"""
Check if CUDA is available on the system.
Returns:
A boolean value. True if CUDA is available and False otherwise.
"""
import torch
return torch.cuda.is_available()
def set_cuda_arch_list(cuda_dir):
"""
This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'.
"""
cuda_available = check_cuda_availability()
# we only need to set this when CUDA is not available for cross-compilation
if not cuda_available:
warnings.warn(
"\n[extension] PyTorch did not find available GPUs on this system.\n"
"If your intention is to cross-compile, this is not an error.\n"
"By default, Colossal-AI will cross-compile for \n"
"1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"2. Volta (compute capability 7.0)\n"
"3. Turing (compute capability 7.5),\n"
"4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n"
"\nIf you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
)
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]
if int(bare_metal_major) == 11:
if int(bare_metal_minor) == 0:
arch_list.append("8.0")
else:
arch_list.append("8.0")
arch_list.append("8.6")
arch_list_str = ";".join(arch_list)
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
return False
return True
def get_cuda_cc_flag() -> List[str]:
"""
This function produces the cc flags for your GPU arch
Returns:
The CUDA cc flags for compilation.
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
cc_flag = []
max_arch = "".join(str(i) for i in torch.cuda.get_device_capability())
for arch in torch.cuda.get_arch_list():
res = re.search(r"sm_(\d+)", arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch):
cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"])
return cc_flag
def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
"""
This function appends the threads flag to your nvcc args.
Returns:
The nvcc compilation flags including the threads flag.
"""
from torch.utils.cpp_extension import CUDA_HOME
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args