mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 15:11:20 +00:00
[accelerator] init the accelerator module (#5129)
* [accelerator] init the accelerator module * polish code * polish code * polish code * polish code
This commit is contained in:
63
colossalai/accelerator/npu_accelerator.py
Normal file
63
colossalai/accelerator/npu_accelerator.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
try:
|
||||
import torch_npu # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["NpuAccelerator"]
|
||||
|
||||
|
||||
class NpuAccelerator(BaseAccelerator):
|
||||
"""
|
||||
Accelerator class for Huawei NPU devices.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name="npu", communication_backend="hccl", is_synchronous=False)
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
return torch.npu.current_device()
|
||||
|
||||
def set_device(self, device: Union[torch.device, int]) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
torch.npu.set_device(device)
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
return torch.npu.get_device_name(device)
|
||||
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
torch.npu.synchronize(device)
|
||||
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
return torch.npu.is_available()
|
||||
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
return torch.npu.device_count()
|
Reference in New Issue
Block a user