[accelerator] init the accelerator module (#5129)

* [accelerator] init the accelerator module

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-11-30 13:25:17 +08:00
committed by GitHub
parent 68fcaa2225
commit f4e72c9992
7 changed files with 306 additions and 0 deletions

View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python
from typing import Union
import torch
from .base_accelerator import BaseAccelerator
__all__ = ["CudaAccelerator"]
class CudaAccelerator(BaseAccelerator):
"""
Accelerator class for Nvidia CUDA devices.
"""
def __init__(self):
    """Configure the base accelerator for NVIDIA GPUs.

    Registers the device name ``"cuda"``, NCCL as the collective
    communication backend, and asynchronous (non-blocking) execution.
    """
    super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False)
# =======================
# device APIs
# =======================
def current_device(self) -> int:
    """Return the index of the CUDA device bound to the calling process."""
    device_index = torch.cuda.current_device()
    return device_index
def set_device(self, device: Union[torch.device, int]) -> None:
    """Make ``device`` the active CUDA device for this process.

    Args:
        device: Device index or ``torch.device`` to bind to.
    """
    torch.cuda.set_device(device)
def get_device_name(self, device: Union[torch.device, int]) -> str:
    """Return the human-readable name reported by CUDA for ``device``.

    Args:
        device: Device index or ``torch.device`` to query.
    """
    device_name = torch.cuda.get_device_name(device)
    return device_name
def synchronize(self, device: Union[torch.device, int, None] = None) -> None:
    """Block the calling process until all queued CUDA work has completed.

    Args:
        device: Device index or ``torch.device`` to wait on. ``None``
            (the default) synchronizes the current device, matching
            ``torch.cuda.synchronize``.
    """
    # Fix: the original annotation was Union[torch.device, int] even though
    # the default is None; widened to include None and added -> None so the
    # signature is consistent with set_device().
    torch.cuda.synchronize(device)
def is_available(self) -> bool:
    """
    Check if the accelerator is available.

    Returns:
        bool: ``True`` if torch detects a usable CUDA runtime, else ``False``.
    """
    return torch.cuda.is_available()
def device_count(self) -> int:
    """
    Return the number of devices on the machine.

    Returns:
        int: Count of CUDA devices visible to this process.
    """
    return torch.cuda.device_count()