mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[npu] change device to accelerator api (#5239)
* update accelerator * fix timer * fix amp * update * fix * update bug * add error raise * fix autocast * fix set device * remove doc accelerator * update doc * update doc * update doc * use nullcontext * update cpu * update null context * change time limit for example * udpate * update * update * update * [npu] polish accelerator code --------- Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com> Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from collections import OrderedDict
|
||||
from typing import Union
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
from .cpu_accelerator import CpuAccelerator
|
||||
from .cuda_accelerator import CudaAccelerator
|
||||
from .npu_accelerator import NpuAccelerator
|
||||
|
||||
@@ -15,7 +16,7 @@ _ACCELERATOR = None
|
||||
# we use ordered dictionary here to associate the
|
||||
# order with device check priority
|
||||
# i.e. auto_set_accelerator will check cuda first
|
||||
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator)
|
||||
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)
|
||||
|
||||
|
||||
def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
|
||||
@@ -43,19 +44,17 @@ def auto_set_accelerator() -> None:
|
||||
"""
|
||||
global _ACCELERATOR
|
||||
|
||||
for _, accelerator_cls in _ACCELERATOR_MAPPING.items():
|
||||
for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
|
||||
try:
|
||||
accelerator = accelerator_cls()
|
||||
if accelerator.is_available():
|
||||
if accelerator_name == "cpu" or accelerator.is_available():
|
||||
_ACCELERATOR = accelerator
|
||||
break
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
if _ACCELERATOR is None:
|
||||
raise RuntimeError(
|
||||
f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}"
|
||||
)
|
||||
raise RuntimeError("No accelerator is available.")
|
||||
|
||||
|
||||
def get_accelerator() -> BaseAccelerator:
|
||||
|
Reference in New Issue
Block a user