mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-15 14:47:16 +00:00
* update extension * update cpu adam * update is * add doc for cpu adam * update kernel * update commit * update flash * update memory efficient * update flash attn * update flash attention loader * update api * fix * update doc * update example time limit * reverse change * fix doc * remove useless kernel * fix * not use warning * update * update
29 lines
691 B
Python
29 lines
691 B
Python
from abc import ABC, abstractmethod
|
|
from typing import Dict, List
|
|
|
|
from .extensions.base_extension import BaseExtension
|
|
|
|
|
|
class BaseKernelLoader(ABC):
    """
    Abstract base class for kernel loaders.

    A concrete subclass must implement ``fetch_kernel``. Callers obtain the
    kernel through ``load``, which first calls ``run_checks`` and then returns
    whatever ``fetch_kernel`` returns.

    Usage:
        # ``KernelLoader`` stands for a concrete subclass; it must be given the
        # constructor arguments documented on ``__init__``.
        kernel_loader = KernelLoader(extension_map, supported_device)
        kernel = kernel_loader.load()
    """

    def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]):
        """
        Store the loader's configuration.

        Args:
            extension_map: mapping from string keys to ``BaseExtension`` objects
                (presumably keyed by extension name -- confirm against callers).
            supported_device: list of strings (presumably device identifiers --
                confirm against callers).
        """
        # The caller's objects are kept by reference; no copy is made.
        self._extension_map, self._supported_device = extension_map, supported_device

    def run_checks(self):
        """
        Pre-load validation hook, invoked by ``load`` before ``fetch_kernel``.

        The base implementation performs no checks and returns ``None``. The
        intended use is a supported-device check plus any other checks a loader
        needs.
        """
        # NOTE(review): ``self._supported_device`` is stored by ``__init__`` but
        # is not consulted here yet.
        return None

    @abstractmethod
    def fetch_kernel(self):
        """
        Produce the kernel that ``load`` returns; subclasses must override this.

        Because this method is abstract, ``BaseKernelLoader`` itself cannot be
        instantiated.
        """

    def load(self):
        """Run ``run_checks`` and then return the result of ``fetch_kernel``."""
        self.run_checks()
        kernel = self.fetch_kernel()
        return kernel