mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-25 01:40:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			139 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			139 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import importlib
 | |
| import os
 | |
| import time
 | |
| from abc import abstractmethod
 | |
| from pathlib import Path
 | |
| from typing import List
 | |
| 
 | |
| from .base_extension import _Extension
 | |
| 
 | |
# Public API of this module: only the C++ extension base class is exported.
__all__ = ["_CppExtension"]
 | |
| 
 | |
| 
 | |
class _CppExtension(_Extension):
    """
    Base class for extensions built with ``torch.utils.cpp_extension`` as C++ extensions.

    Subclasses provide the source files, include directories and compiler flags
    (``sources_files``, ``include_dirs``, ``cxx_flags``). The extension can then be
    built ahead of time (``build_aot``) or just in time (``build_jit``). ``load``
    prefers a prebuilt module and falls back to JIT compilation.
    """

    def __init__(self, name: str, priority: int = 1):
        super().__init__(name, support_aot=True, support_jit=True, priority=priority)

        # we store the op as an attribute to avoid repeated building and loading
        # (read and written by ``load``)
        self.cached_op = None

        # build-related variables
        # the prebuilt op is expected to live at ``colossalai._C.<name>``
        self.prebuilt_module_path = "colossalai._C"
        self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    def csrc_abs_path(self, path):
        """Return the absolute path of ``path`` inside the ``extensions/csrc`` directory."""
        return os.path.join(self.relative_to_abs_path("csrc"), path)

    def pybind_abs_path(self, path):
        """Return the absolute path of ``path`` inside the ``extensions/pybind`` directory."""
        return os.path.join(self.relative_to_abs_path("pybind"), path)

    def relative_to_abs_path(self, code_path: str) -> str:
        """
        Resolve ``code_path`` against the ``extensions`` directory containing this file.

        Starting from this file, the nearest path component named ``extensions``
        is located, and ``code_path`` is joined onto it.

        Args:
            code_path: a path relative to the ``extensions`` directory.

        Returns:
            The joined path as a string.

        Raises:
            RuntimeError: if neither this file nor any of its ancestors is named
                ``extensions``.
        """
        current_file_path = Path(__file__)
        # Bounded walk over the file and its ancestors. A ``while True`` loop over
        # ``.parent`` would spin forever when no "extensions" directory exists,
        # because the parent of the filesystem root (or of ".") is itself.
        for candidate in (current_file_path, *current_file_path.parents):
            if candidate.name == "extensions":
                return str(candidate.joinpath(code_path))
        raise RuntimeError(f"Could not locate an 'extensions' directory above {current_file_path}")

    # helpers shared by the AOT and JIT build paths
    def strip_empty_entries(self, args):
        """
        Drop any empty strings from a list of sources, include dirs, or compile/link flags.
        """
        return [x for x in args if len(x) > 0]

    def import_op(self):
        """
        Import the prebuilt op module by its dotted name (``self.prebuilt_import_path``).

        Raises:
            ImportError: if the prebuilt module is not available.
        """
        return importlib.import_module(self.prebuilt_import_path)

    def build_aot(self) -> "CppExtension":
        """
        Describe this extension for ahead-of-time building.

        Returns:
            A ``torch.utils.cpp_extension.CppExtension`` named after the prebuilt
            import path, with empty entries removed from sources, include dirs
            and compile flags.
        """
        from torch.utils.cpp_extension import CppExtension

        return CppExtension(
            name=self.prebuilt_import_path,
            sources=self.strip_empty_entries(self.sources_files()),
            include_dirs=self.strip_empty_entries(self.include_dirs()),
            extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
        )

    def build_jit(self) -> "ModuleType":
        """
        Compile (or reload) the extension at runtime with ``torch.utils.cpp_extension.load``.

        The build happens in the folder returned by
        ``_Extension.get_jit_extension_folder_path()``, which is created if missing.

        Returns:
            The module object returned by ``torch.utils.cpp_extension.load``.
        """
        from torch.utils.cpp_extension import load

        build_directory = Path(_Extension.get_jit_extension_folder_path())
        build_directory.mkdir(parents=True, exist_ok=True)

        # An existing "<name>.so" in the build folder is only used to choose
        # between the "loading" and "compiling" log messages below.
        compiled_before = build_directory.joinpath(f"{self.name}.so").exists()

        if compiled_before:
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
        else:
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")

        # perf_counter is monotonic, so the measured duration is unaffected by clock changes
        build_start = time.perf_counter()
        op_kernel = load(
            name=self.name,
            sources=self.strip_empty_entries(self.sources_files()),
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
            # stripped for consistency with build_aot and the other arguments here
            extra_cflags=self.strip_empty_entries(self.cxx_flags()),
            extra_ldflags=[],
            build_directory=str(build_directory),
        )
        build_duration = time.perf_counter() - build_start

        if compiled_before:
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
        else:
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")

        return op_kernel

    # functions that subclasses must override
    @abstractmethod
    def sources_files(self) -> List[str]:
        """
        Return the list of source files to compile for this extension.
        """

    @abstractmethod
    def include_dirs(self) -> List[str]:
        """
        Return the list of include directories for this extension.

        The default implementation returns the ``extensions/csrc`` directory;
        subclasses may extend it through ``super().include_dirs()``.
        """
        return [self.csrc_abs_path("")]

    @abstractmethod
    def cxx_flags(self) -> List[str]:
        """
        Return the list of C++ compilation flags for this extension.
        """

    def load(self):
        """
        Return the op module, importing or building it on first use.

        The prebuilt module is tried first. If it cannot be imported, the kernel
        is JIT-compiled instead. The result is memoized in ``self.cached_op`` so
        later calls neither re-import nor rebuild.
        """
        # getattr keeps this working for instances created without __init__
        cached = getattr(self, "cached_op", None)
        if cached is not None:
            return cached

        try:
            op_kernel = self.import_op()
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            # if import error occurs, it means that the kernel is not pre-built
            # so we build it jit
            op_kernel = self.build_jit()

        self.cached_op = op_kernel
        return op_kernel
 |