mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-02 05:35:29 +00:00
* [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests
42 lines
1.0 KiB
Python
42 lines
1.0 KiB
Python
import platform
|
|
|
|
from ..cpp_extension import _CppExtension
|
|
|
|
|
|
class CpuAdamArmExtension(_CppExtension):
    """Build description for the ``cpu_adam_arm`` C++ extension.

    The extension is only considered usable when ``platform.machine()``
    reports ``"aarch64"``.
    """

    def __init__(self):
        super().__init__(name="cpu_adam_arm")

    def is_available(self) -> bool:
        """Return ``True`` only when the host machine reports ``aarch64``."""
        # only arm allowed
        return platform.machine() == "aarch64"

    def assert_compatible(self) -> None:
        """Fail loudly when the host CPU architecture is not ``aarch64``.

        Raises:
            AssertionError: if ``platform.machine()`` is not ``"aarch64"``.
        """
        arch = platform.machine()
        # Raise explicitly instead of using an ``assert`` statement so the
        # check is not stripped when Python runs with ``-O``; the exception
        # type and message are the same as before.
        if arch != "aarch64":
            raise AssertionError(
                f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}"
            )

    # The four hooks below describe the build: sources, include dirs,
    # C++ compiler flags and nvcc flags.
    def sources_files(self):
        """Return the list of source files for this extension.

        The single entry is whatever ``self.csrc_abs_path`` yields for
        ``"arm/cpu_adam_arm.cpp"`` (presumably an absolute path into the
        project's csrc directory -- confirm against the base class).
        """
        ret = [
            self.csrc_abs_path("arm/cpu_adam_arm.cpp"),
        ]
        return ret

    def include_dirs(self):
        """Return extra include directories; none are specified here."""
        return []

    def cxx_flags(self):
        """Return the C++ compiler flags for building the kernel.

        The result is ``-O3``, followed by ``self.version_dependent_macros``,
        followed by the extension-specific flags below.
        """
        # Only one ``-std`` flag is passed. The previous list contained both
        # ``-std=c++14`` and ``-std=c++17``; when several ``-std`` options
        # are given the last one takes effect, so C++17 was already the
        # effective standard and the C++14 entry was dead.
        extra_cxx_flags = [
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        """Return nvcc flags; this extension specifies none."""
        return []
|