[npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* udpate

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-01-09 10:20:05 +08:00
committed by GitHub
parent dd2c28a323
commit d202cc28c0
128 changed files with 1773 additions and 868 deletions

View File

@@ -3,6 +3,7 @@ from collections import OrderedDict
from typing import Union
from .base_accelerator import BaseAccelerator
from .cpu_accelerator import CpuAccelerator
from .cuda_accelerator import CudaAccelerator
from .npu_accelerator import NpuAccelerator
@@ -15,7 +16,7 @@ _ACCELERATOR = None
# we use ordered dictionary here to associate the
# order with device check priority
# i.e. auto_set_accelerator will check cuda first
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator)
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)
def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
@@ -43,19 +44,17 @@ def auto_set_accelerator() -> None:
"""
global _ACCELERATOR
for _, accelerator_cls in _ACCELERATOR_MAPPING.items():
for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
try:
accelerator = accelerator_cls()
if accelerator.is_available():
if accelerator_name == "cpu" or accelerator.is_available():
_ACCELERATOR = accelerator
break
break
except:
pass
if _ACCELERATOR is None:
raise RuntimeError(
f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}"
)
raise RuntimeError("No accelerator is available.")
def get_accelerator() -> BaseAccelerator: