mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-12-23 20:43:19 +00:00
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047) * [npu] add npu device support * [npu] support low level zero * [test] update npu zero plugin test * [hotfix] fix import * [test] recover tests * [npu] gemini support npu (#5052) * [npu] refactor device utils * [gemini] support npu * [example] llama2+gemini support npu * [kernel] add arm cpu adam kernel (#5065) * [kernel] add arm cpu adam * [optim] update adam optimizer * [kernel] arm cpu adam remove bf16 support
This commit is contained in:
@@ -7,7 +7,7 @@ import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
|
||||
|
||||
@@ -21,6 +21,8 @@ class Builder(ABC):
|
||||
prebuilt_import_path (str): the path where the extension is installed during pip install
|
||||
"""
|
||||
|
||||
ext_type: str = "cuda"
|
||||
|
||||
def __init__(self, name: str, prebuilt_import_path: str):
|
||||
self.name = name
|
||||
self.prebuilt_import_path = prebuilt_import_path
|
||||
@@ -165,7 +167,8 @@ class Builder(ABC):
|
||||
)
|
||||
except ImportError:
|
||||
# check environment
|
||||
self.check_runtime_build_environment()
|
||||
if self.ext_type == "cuda":
|
||||
self.check_runtime_build_environment()
|
||||
|
||||
# time the kernel compilation
|
||||
start_build = time.time()
|
||||
@@ -208,11 +211,19 @@ class Builder(ABC):
|
||||
|
||||
return op_module
|
||||
|
||||
def builder(self) -> "CUDAExtension":
|
||||
def builder(self) -> Union["CUDAExtension", "CppExtension"]:
|
||||
"""
|
||||
get a CUDAExtension instance used for setup.py
|
||||
"""
|
||||
from torch.utils.cpp_extension import CUDAExtension
|
||||
from torch.utils.cpp_extension import CppExtension, CUDAExtension
|
||||
|
||||
if self.ext_type == "cpp":
|
||||
return CppExtension(
|
||||
name=self.prebuilt_import_path,
|
||||
sources=self.strip_empty_entries(self.sources_files()),
|
||||
include_dirs=self.strip_empty_entries(self.include_dirs()),
|
||||
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
|
||||
)
|
||||
|
||||
return CUDAExtension(
|
||||
name=self.prebuilt_import_path,
|
||||
|
||||
Reference in New Issue
Block a user