mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* udpate
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
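For context, a minimal sketch of the device-lookup pattern this commit migrates to; it assumes a ColossalAI build that ships the colossalai.accelerator module introduced here, and the variable names are illustrative:

    # Before: device handling was split across colossalai.utils helpers.
    # from colossalai.utils import get_current_device
    # from colossalai.utils.device import IS_NPU_AVAILABLE

    # After: a single accelerator abstraction resolves the active backend.
    import torch
    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()            # CUDA, NPU, or CPU backend picked at runtime
    device = accelerator.get_current_device()  # torch.device for the local rank
    buf = torch.zeros(1024, device=device)     # tensors land on whichever backend is active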
@@ -6,8 +6,7 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
+from colossalai.accelerator import get_accelerator
 
 
 class TensorState(Enum):
@@ -107,7 +106,7 @@ class Chunk:
         self.valid_end = self.shard_size
 
         self.dtype = dtype
-        device = init_device or get_current_device()
+        device = init_device or get_accelerator().get_current_device()
 
         # chunk_temp is a global chunk, which only exists during building the chunks.
         self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)  # keep all zero
@@ -125,7 +124,7 @@ class Chunk:
         # configure the init device of the shard
         # no-offload default: fp16, fp32 -> CUDA
         # offload default: fp16, fp32 -> CPU
-        self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
+        self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device()
 
         self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
         self.shard_mem = self.chunk_mem // self.pg_size
@@ -192,10 +191,7 @@ class Chunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.is_gathered or self.cuda_shard is not None:
-                return "npu" if IS_NPU_AVAILABLE else "cuda"
-            else:
-                return "cpu"
+            return get_accelerator().name
 
     @property
     def payload(self) -> torch.Tensor:
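A hedged sketch of what the simplified device-type branch above amounts to: instead of testing IS_NPU_AVAILABLE and hard-coding "npu"/"cuda", the property reports the backend name from the accelerator singleton (the printed strings are illustrative):

    from colossalai.accelerator import get_accelerator

    # Non-CPU shards now report whichever backend the accelerator exposes,
    # e.g. "cuda" on GPU nodes or "npu" on Ascend nodes.
    backend = get_accelerator().name
    print(f"non-CPU payload lives on: {backend}")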
@@ -297,7 +293,7 @@ class Chunk:
             self.valid_end = self.utilized_size - self.shard_begin
 
         if self.chunk_temp.device.type == "cpu":
-            self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
+            self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device())
             self.__update_tensors_ptr()
         else:
             self.cuda_global_chunk = self.chunk_temp
@@ -334,12 +330,12 @@ class Chunk:
             return
 
         if device.type == "cuda" or device.type == "npu":
-            assert device == get_current_device(), "can't move chunk to another device"
+            assert device == get_accelerator().get_current_device(), "can't move chunk to another device"
 
             if self.cuda_shard:
                 return
 
-            self.cuda_shard = self.cpu_shard.to(get_current_device())
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
 
             if not self.pin_memory:
                 self.cpu_shard = None
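A minimal standalone sketch of the move-to-device step in this hunk, under the assumption that the accelerator API behaves as shown in the diff; the helper name and return shape are illustrative, not the Chunk API:

    import torch
    from colossalai.accelerator import get_accelerator

    def move_shard_to_accelerator(cpu_shard: torch.Tensor, pin_memory: bool):
        # Copy the CPU shard onto the active accelerator device.
        cuda_shard = cpu_shard.to(get_accelerator().get_current_device())
        # Without pinned memory there is no reason to keep the CPU copy around.
        return cuda_shard, (cpu_shard if pin_memory else None)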
@@ -394,7 +390,9 @@ class Chunk:
             if self.extra_dp_group is not None:
                 dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
-            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
+            self.cuda_shard = torch.empty(
+                self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
+            )
 
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
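A standalone sketch of the allocate-then-reduce-scatter pattern in this hunk; the function name is hypothetical and the process-group handling is simplified:

    import torch
    import torch.distributed as dist
    from colossalai.accelerator import get_accelerator

    def reduce_scatter_chunk(cuda_global_chunk: torch.Tensor, shard_size: int, pg) -> torch.Tensor:
        # Allocate the local shard directly on the active accelerator device.
        cuda_shard = torch.empty(
            shard_size, dtype=cuda_global_chunk.dtype, device=get_accelerator().get_current_device()
        )
        # Split the global chunk into one view per rank and reduce-scatter into the shard.
        input_list = list(torch.chunk(cuda_global_chunk, chunks=dist.get_world_size(pg), dim=0))
        dist.reduce_scatter(cuda_shard, input_list, group=pg)
        return cuda_shard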
@@ -533,7 +531,7 @@ class Chunk:
             # only be called when optimizer state is in CPU memory
             # the grad and param should be in the same device
             assert self.cuda_shard is None
-            temp = optim_chunk.cpu_shard.to(get_current_device())
+            temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
             # avoid to transform FP32 in CPU
             self.cuda_shard = temp.to(self.dtype)
 
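A short sketch of the move-then-cast pattern in this hunk: the FP32 optimizer shard is first copied to the accelerator device and only then cast to the working dtype, so the conversion does not run on the CPU (the helper name is illustrative):

    import torch
    from colossalai.accelerator import get_accelerator

    def fetch_optim_shard(cpu_shard: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
        # Move the FP32 shard to the accelerator first...
        temp = cpu_shard.to(get_accelerator().get_current_device())
        # ...then cast on device, avoiding an FP32 -> FP16 conversion on the CPU.
        return temp.to(target_dtype)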
@@ -631,7 +629,7 @@ class Chunk:
             grad_chunk.valid_end = self.valid_end
 
         if grad_chunk.chunk_temp.device.type == "cpu":
-            grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
+            grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device())
         else:
             grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
         grad_chunk.chunk_temp = None