mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['Accelerator']
|
||||
__all__ = ["Accelerator"]
|
||||
|
||||
_supported_devices = [
|
||||
'cpu',
|
||||
'cuda',
|
||||
|
||||
"cpu",
|
||||
"cuda",
|
||||
# To be supported
|
||||
# 'xpu',
|
||||
# 'npu',
|
||||
@@ -25,21 +24,22 @@ class Accelerator:
|
||||
def __init__(self, device: str):
|
||||
self.device = device
|
||||
|
||||
assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
|
||||
assert (
|
||||
self.device in _supported_devices
|
||||
), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
|
||||
|
||||
def bind(self):
|
||||
"""
|
||||
Set the default device for the current process.
|
||||
"""
|
||||
if self.device == 'cpu':
|
||||
if self.device == "cpu":
|
||||
pass
|
||||
elif self.device == 'cuda':
|
||||
elif self.device == "cuda":
|
||||
# TODO(FrankLeeeee): use global environment to check if it is a dist job
|
||||
# if is_distributed:
|
||||
# local_rank = EnvTable().get_local_rank()
|
||||
# torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
|
||||
torch.cuda.set_device(torch.device('cuda'))
|
||||
pass
|
||||
torch.cuda.set_device(torch.device("cuda"))
|
||||
else:
|
||||
raise ValueError(f"Device {self.device} is not supported yet")
|
||||
|
||||
|
Reference in New Issue
Block a user