Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 17:46:42 +00:00
[accelerator] init the accelerator module (#5129)
* [accelerator] init the accelerator module
* polish code
* polish code
* polish code
* polish code
colossalai/accelerator/cuda_accelerator.py (new file, 56 lines)
@@ -0,0 +1,56 @@
#!/usr/bin/env python

from typing import Optional, Union

import torch

from .base_accelerator import BaseAccelerator

__all__ = ["CudaAccelerator"]


class CudaAccelerator(BaseAccelerator):
    """
    Accelerator class for Nvidia CUDA devices.
    """

    def __init__(self):
        super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False)

    # =======================
    # device APIs
    # =======================
    def current_device(self) -> int:
        """
        Return the current device index.
        """
        return torch.cuda.current_device()

    def set_device(self, device: Union[torch.device, int]) -> None:
        """
        Bind the current process to a device.
        """
        torch.cuda.set_device(device)

    def get_device_name(self, device: Union[torch.device, int]) -> str:
        """
        Return the name of the device.
        """
        return torch.cuda.get_device_name(device)

    def synchronize(self, device: Optional[Union[torch.device, int]] = None) -> None:
        """
        Synchronize the current process.
        """
        torch.cuda.synchronize(device)

    def is_available(self) -> bool:
        """
        Check if the accelerator is available.
        """
        return torch.cuda.is_available()

    def device_count(self) -> int:
        """
        Return the number of devices on the machine.
        """
        return torch.cuda.device_count()
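The file imports BaseAccelerator from a sibling module that is not part of this diff. Below is a minimal sketch of what that base class could look like, inferred only from the super().__init__ call above; the attribute names and the choice of abstract methods are assumptions for illustration, not the actual ColossalAI implementation.

from abc import ABC, abstractmethod
from typing import Union

import torch


class BaseAccelerator(ABC):
    """Hypothetical base class, reconstructed from the constructor call in the diff."""

    def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
        # Attribute names are assumed; only the three parameters are grounded in the diff.
        self._name = name                                    # e.g. "cuda"
        self._communication_backend = communication_backend  # e.g. "nccl"
        self._is_synchronous = is_synchronous                # whether ops block by default

    @abstractmethod
    def current_device(self) -> int:
        """Return the current device index."""

    @abstractmethod
    def set_device(self, device: Union[torch.device, int]) -> None:
        """Bind the current process to a device."""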
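For completeness, a short usage sketch of the new class; the import path follows the file path in the diff header, and the device index 0 is illustrative. It assumes a CUDA-capable machine with the package installed.

from colossalai.accelerator.cuda_accelerator import CudaAccelerator

accelerator = CudaAccelerator()
if accelerator.is_available():
    accelerator.set_device(0)              # bind this process to GPU 0
    print(accelerator.get_device_name(0))  # human-readable device name
    print(accelerator.current_device())    # -> 0 after set_device(0)
    print(accelerator.device_count())      # number of visible CUDA GPUs
    accelerator.synchronize()              # wait for queued work on the current device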