mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[feat] refactored extension module (#5298)
* [feat] refactored extension module * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish
This commit is contained in:
134
extensions/cpp_extension.py
Normal file
134
extensions/cpp_extension.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import importlib
|
||||
import os
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from .base_extension import _Extension
|
||||
|
||||
__all__ = ["_CppExtension"]
|
||||
|
||||
|
||||
class _CppExtension(_Extension):
|
||||
def __init__(self, name: str, priority: int = 1):
|
||||
super().__init__(name, support_aot=True, support_jit=True, priority=priority)
|
||||
|
||||
# we store the op as an attribute to avoid repeated building and loading
|
||||
self.cached_op = None
|
||||
|
||||
# build-related variables
|
||||
self.prebuilt_module_path = "colossalai._C"
|
||||
self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
|
||||
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
||||
|
||||
def csrc_abs_path(self, path):
|
||||
return os.path.join(self.relative_to_abs_path("csrc"), path)
|
||||
|
||||
def relative_to_abs_path(self, code_path: str) -> str:
|
||||
"""
|
||||
This function takes in a path relative to the colossalai root directory and return the absolute path.
|
||||
"""
|
||||
|
||||
# get the current file path
|
||||
# iteratively check the parent directory
|
||||
# if the parent directory is "extensions", then the current file path is the root directory
|
||||
# otherwise, the current file path is inside the root directory
|
||||
current_file_path = Path(__file__)
|
||||
while True:
|
||||
if current_file_path.name == "extensions":
|
||||
break
|
||||
else:
|
||||
current_file_path = current_file_path.parent
|
||||
extension_module_path = current_file_path
|
||||
code_abs_path = extension_module_path.joinpath(code_path)
|
||||
return str(code_abs_path)
|
||||
|
||||
# functions must be overrided over
|
||||
def strip_empty_entries(self, args):
|
||||
"""
|
||||
Drop any empty strings from the list of compile and link flags
|
||||
"""
|
||||
return [x for x in args if len(x) > 0]
|
||||
|
||||
def import_op(self):
|
||||
"""
|
||||
This function will import the op module by its string name.
|
||||
"""
|
||||
return importlib.import_module(self.prebuilt_import_path)
|
||||
|
||||
def build_aot(self) -> "CppExtension":
|
||||
from torch.utils.cpp_extension import CppExtension
|
||||
|
||||
return CppExtension(
|
||||
name=self.prebuilt_import_path,
|
||||
sources=self.strip_empty_entries(self.sources_files()),
|
||||
include_dirs=self.strip_empty_entries(self.include_dirs()),
|
||||
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
|
||||
)
|
||||
|
||||
def build_jit(self) -> None:
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
build_directory = _Extension.get_jit_extension_folder_path()
|
||||
build_directory = Path(build_directory)
|
||||
build_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# check if the kernel has been built
|
||||
compiled_before = False
|
||||
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
|
||||
if kernel_file_path.exists():
|
||||
compiled_before = True
|
||||
|
||||
# load the kernel
|
||||
if compiled_before:
|
||||
print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
|
||||
else:
|
||||
print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
|
||||
|
||||
build_start = time.time()
|
||||
op_kernel = load(
|
||||
name=self.name,
|
||||
sources=self.strip_empty_entries(self.sources_files()),
|
||||
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
|
||||
extra_cflags=self.cxx_flags(),
|
||||
extra_ldflags=[],
|
||||
build_directory=str(build_directory),
|
||||
)
|
||||
build_duration = time.time() - build_start
|
||||
|
||||
if compiled_before:
|
||||
print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
|
||||
else:
|
||||
print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
|
||||
|
||||
return op_kernel
|
||||
|
||||
# functions must be overrided begin
|
||||
@abstractmethod
|
||||
def sources_files(self) -> List[str]:
|
||||
"""
|
||||
This function should return a list of source files for extensions.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def include_dirs(self) -> List[str]:
|
||||
"""
|
||||
This function should return a list of include files for extensions.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def cxx_flags(self) -> List[str]:
|
||||
"""
|
||||
This function should return a list of cxx compilation flags for extensions.
|
||||
"""
|
||||
|
||||
def load(self):
|
||||
try:
|
||||
op_kernel = self.import_op()
|
||||
except ImportError:
|
||||
# if import error occurs, it means that the kernel is not pre-built
|
||||
# so we build it jit
|
||||
op_kernel = self.build_jit()
|
||||
|
||||
return op_kernel
|
Reference in New Issue
Block a user