mirror of https://github.com/hpcaitech/ColossalAI.git
[npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
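The core of the change is visible in the diff below: direct calls to the CUDA-centric helper colossalai.utils.get_current_device() are replaced by the backend-neutral entry point colossalai.accelerator.get_accelerator(), so the same code path resolves to a CUDA, NPU, or CPU device at runtime. A minimal before/after sketch of the two call styles, using only the two functions that actually appear in the diff:

    import torch

    # Before this commit: a CUDA-centric helper.
    from colossalai.utils import get_current_device
    device = get_current_device()  # e.g. torch.device("cuda:0")

    # After this commit: the backend is resolved through the accelerator API.
    from colossalai.accelerator import get_accelerator
    device = get_accelerator().get_current_device()  # CUDA, NPU, or CPU device

    x = torch.empty(4, device=device)  # lands on whichever backend is active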
@@ -5,7 +5,8 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
-from colossalai.utils import free_storage, get_current_device
+from colossalai.accelerator import get_accelerator
+from colossalai.utils import free_storage
 
 from .chunk import Chunk, ChunkFullError, TensorState
 
@@ -20,7 +21,7 @@ class ChunkManager:
     """
 
     def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
-        self.device = init_device or get_current_device()
+        self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
@@ -107,7 +108,7 @@ class ChunkManager:
             return
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
-            chunk.shard_move(get_current_device())
+            chunk.shard_move(get_accelerator().get_current_device())
         self.__add_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
 
@@ -276,7 +277,10 @@ class ChunkManager:
                 accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
             else:
                 accumulated_grad = (
-                    chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
+                    chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device())
+                    .clone()
+                    .detach()
+                    .mul_(chunk.pg_size)
                 )
             accumulated_grad_gathered = False
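Besides the API swap, the last hunk reflows the chained tensor call onto one method per line; the behavior is unchanged. As an illustration of that pattern outside the class, with cpu_shard and pg_size as hypothetical stand-ins for chunk.grad_chunk.cpu_shard and chunk.pg_size:

    import torch

    from colossalai.accelerator import get_accelerator

    cpu_shard = torch.randn(1024)  # stand-in for chunk.grad_chunk.cpu_shard
    pg_size = 8                    # stand-in for chunk.pg_size

    # Move the CPU shard onto the current accelerator device, take a copy
    # detached from the autograd graph, then scale it in place by the
    # process-group size (the same chain the diff rewrites).
    accumulated_grad = (
        cpu_shard.to(get_accelerator().get_current_device())
        .clone()
        .detach()
        .mul_(pg_size)
    )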