[npu] use extension for op builder (#5172)
* update extension
* update cpu adam
* update is
* add doc for cpu adam
* update kernel
* update commit
* update flash
* update memory efficient
* update flash attn
* update flash attention loader
* update api
* fix
* update doc
* update example time limit
* reverse change
* fix doc
* remove useless kernel
* fix
* not use warning
* update
* update
@@ -90,9 +90,9 @@ class FusedAdamKernel(AdamKernel):
 class CPUAdamKernel(AdamKernel):
     def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
         super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
-        from colossalai.kernel.op_builder import CPUAdamBuilder
+        from colossalai.kernel import CPUAdamLoader

-        cpu_optim = CPUAdamBuilder().load()
+        cpu_optim = CPUAdamLoader().load()

         self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)

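For context, a minimal standalone sketch of the loader-based path this hunk switches to. Only the CPUAdamLoader import, the .load() call, and the CPUAdamOptimizer constructor signature come from the diff; the concrete hyperparameter values below are illustrative placeholders, not part of the commit.

    # Sketch only: exercises the new CPUAdamLoader path shown in the hunk above.
    from colossalai.kernel import CPUAdamLoader

    # load() returns the extension module that exposes CPUAdamOptimizer (as used above).
    cpu_optim = CPUAdamLoader().load()

    # Arguments mirror AdamKernel.__init__: lr, beta1, beta2, eps, weight_decay, use_adamw.
    # The values here are placeholders chosen for illustration.
    cpu_adam_op = cpu_optim.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)
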
@@ -4,13 +4,11 @@ import pytest
 import torch
 from einops import rearrange

-from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
-from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
 from colossalai.testing import clear_cache_before_run, parameterize

 if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-    from colossalai.kernel.cuda_native import ColoAttention
-    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+    from colossalai.kernel import AttnMaskType, ColoAttention

 DTYPE = [torch.float16, torch.bfloat16, torch.float32]

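For reference, a standalone sketch of the post-change import surface used by this test, guarded the same way the test guards it. The import paths come from the diff; the print statements and the member-listing check are illustrative additions, not part of the test file.

    # Sketch only: mirrors the new import paths from the hunk above.
    from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN

    if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
        from colossalai.kernel import AttnMaskType, ColoAttention

        # Listing the AttnMaskType enum members serves as a simple smoke check.
        print("ColoAttention available, mask types:", [m.name for m in AttnMaskType])
    else:
        # Neither flash attention nor the memory-efficient kernel is installed.
        print("fused attention kernels unavailable; the test would be skipped")
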