[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-25 17:01:48 +08:00
committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

View File

@@ -0,0 +1,21 @@
from .base_extension import _Extension
__all__ = ["_TritonExtension"]
class _TritonExtension(_Extension):
def __init__(self, name: str, priority: int = 1):
super().__init__(name, support_aot=False, support_jit=True, priority=priority)
def is_hardware_compatible(self) -> bool:
# cuda extension can only be built if cuda is availabe
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def load(self):
return self.build_jit()