Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-24 19:17:30 +00:00
[npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
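Every hunk in this commit applies the same mechanical substitution: call sites that used colossalai.utils.get_current_device (and the IS_NPU_AVAILABLE flag) now go through the accelerator abstraction in colossalai.accelerator. Below is a minimal before/after sketch of that pattern; it uses only the calls that appear in the hunks, while the surrounding tensor code is illustrative rather than taken from the repository.

import torch

# Before: device handling was scattered across colossalai.utils
# from colossalai.utils import get_current_device
# from colossalai.utils.device import IS_NPU_AVAILABLE
# buf = torch.zeros(1024, device=get_current_device())
# backend = "npu" if IS_NPU_AVAILABLE else "cuda"

# After: one entry point reports the active accelerator backend
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()
buf = torch.zeros(1024, device=accelerator.get_current_device())  # device of the active accelerator
backend = accelerator.name  # e.g. "cuda" or "npu", replacing the IS_NPU_AVAILABLE branch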
@@ -6,8 +6,7 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
+from colossalai.accelerator import get_accelerator


 class TensorState(Enum):
@@ -107,7 +106,7 @@ class Chunk:
         self.valid_end = self.shard_size

         self.dtype = dtype
-        device = init_device or get_current_device()
+        device = init_device or get_accelerator().get_current_device()

         # chunk_temp is a global chunk, which only exists during building the chunks.
         self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)  # keep all zero
@@ -125,7 +124,7 @@ class Chunk:
         # configure the init device of the shard
         # no-offload default: fp16, fp32 -> CUDA
         # offload default: fp16, fp32 -> CPU
-        self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
+        self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device()

         self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
         self.shard_mem = self.chunk_mem // self.pg_size
@@ -192,10 +191,7 @@ class Chunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.is_gathered or self.cuda_shard is not None:
-                return "npu" if IS_NPU_AVAILABLE else "cuda"
-            else:
-                return "cpu"
+            return get_accelerator().name

     @property
     def payload(self) -> torch.Tensor:
@@ -297,7 +293,7 @@ class Chunk:
         self.valid_end = self.utilized_size - self.shard_begin

         if self.chunk_temp.device.type == "cpu":
-            self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
+            self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device())
             self.__update_tensors_ptr()
         else:
             self.cuda_global_chunk = self.chunk_temp
@@ -334,12 +330,12 @@ class Chunk:
             return

         if device.type == "cuda" or device.type == "npu":
-            assert device == get_current_device(), "can't move chunk to another device"
+            assert device == get_accelerator().get_current_device(), "can't move chunk to another device"

             if self.cuda_shard:
                 return

-            self.cuda_shard = self.cpu_shard.to(get_current_device())
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())

             if not self.pin_memory:
                 self.cpu_shard = None
@@ -394,7 +390,9 @@ class Chunk:
             if self.extra_dp_group is not None:
                 dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
-            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
+            self.cuda_shard = torch.empty(
+                self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
+            )

             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
@@ -533,7 +531,7 @@ class Chunk:
             # only be called when optimizer state is in CPU memory
             # the grad and param should be in the same device
             assert self.cuda_shard is None
-            temp = optim_chunk.cpu_shard.to(get_current_device())
+            temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
             # avoid to transform FP32 in CPU
             self.cuda_shard = temp.to(self.dtype)

@@ -631,7 +629,7 @@ class Chunk:
         grad_chunk.valid_end = self.valid_end

         if grad_chunk.chunk_temp.device.type == "cpu":
-            grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
+            grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device())
         else:
             grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
         grad_chunk.chunk_temp = None
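The device_type hunk above is the one place where the change is more than a one-for-one call swap: the IS_NPU_AVAILABLE ternary is replaced by get_accelerator().name. A short sketch of how that property might be read downstream; the chunk argument is hypothetical, only the attribute and call names come from the hunks above.

from colossalai.accelerator import get_accelerator

def describe_chunk(chunk) -> str:
    # device_type reports the temp buffer's device while the chunk is being built,
    # and the active accelerator's name ("cuda", "npu", ...) afterwards.
    return f"chunk on {chunk.device_type}, active backend {get_accelerator().name}"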
@@ -5,7 +5,8 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

-from colossalai.utils import free_storage, get_current_device
+from colossalai.accelerator import get_accelerator
+from colossalai.utils import free_storage

 from .chunk import Chunk, ChunkFullError, TensorState

@@ -20,7 +21,7 @@ class ChunkManager:
     """

     def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
-        self.device = init_device or get_current_device()
+        self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
@@ -107,7 +108,7 @@ class ChunkManager:
             return
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
-            chunk.shard_move(get_current_device())
+            chunk.shard_move(get_accelerator().get_current_device())
         self.__add_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)

@@ -276,7 +277,10 @@ class ChunkManager:
                 accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
             else:
                 accumulated_grad = (
-                    chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
+                    chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device())
+                    .clone()
+                    .detach()
+                    .mul_(chunk.pg_size)
                 )
             accumulated_grad_gathered = False

@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group

+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
@@ -27,7 +28,7 @@ from colossalai.tensor.d_tensor import (
     is_distributed_tensor,
 )
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, is_ddp_ignored

 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper):

             # move ignored parameters to CUDA
             if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
+                p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision)
                 continue

             # create a fp16 parameter
@@ -815,7 +816,7 @@ class GeminiDDP(ModelWrapper):
         for buffer in self.module.buffers():
             if isinstance(buffer, LazyTensor):
                 buffer.materialize()
-            buffer.data = buffer.to(get_current_device())
+            buffer.data = buffer.to(get_accelerator().get_current_device())
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.to(self.mixed_precision)

@@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup
 from torch.nn import Parameter
 from torch.optim import Optimizer

+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import OptimizerWrapper
@@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.utils import disposable, get_current_device, is_ddp_ignored
+from colossalai.utils import disposable, is_ddp_ignored

 from .chunk import Chunk, ChunkManager
 from .gemini_ddp import GeminiDDP
@@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper):

             grad_chunk.l2_norm = None  # clear l2 norm

-        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
         for group, part_norm in group_to_norm.items():
             comm_buffer.fill_(part_norm)
             dist.all_reduce(comm_buffer, group=group)
@@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper):
                     continue

                 if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
-                    self.chunk_manager.move_chunk(chunk32, get_current_device())
+                    self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
                     # stores grad now
-                    self.chunk_manager.move_chunk(chunk16, get_current_device())
-                    self.module.set_chunk_grad_device(chunk16, get_current_device())
+                    self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
+                    self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
                     fp32_params_used_cuda_margin_mem += chunk32.payload_mem

         for group in self.param_groups:
@@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 state = self.optim.state[fake_param]
                 for k, v in state.items():
                     if isinstance(v, torch.Tensor):
-                        state[k] = v.to(get_current_device())
+                        state[k] = v.to(get_accelerator().get_current_device())

     def _register_states_(self):
         for group in self.optim.param_groups:
@@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self,
         param_id: int,
         state_names: list,
-        device: torch.device = get_current_device(),
+        device: torch.device = get_accelerator().get_current_device(),
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
@@ -1,6 +1,6 @@
 from typing import Optional

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 from colossalai.zero.gemini.chunk import ChunkManager

 from .memory_stats import MemStats
@@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector):
     def cuda_margin_mem(self) -> float:
         from colossalai.legacy.utils.memory import colo_device_memory_capacity

-        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
+        return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda
@@ -5,7 +5,7 @@ from time import sleep, time

 import torch

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator


 class MemoryMonitor:
@@ -77,7 +77,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
         super().__init__()
         self.keep_measuring = False

-        current_device = get_current_device()
+        current_device = get_accelerator().get_current_device()

         def _set_cuda_device():
             torch.cuda.set_device(current_device)
@@ -116,7 +116,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                colo_device_memory_used(get_current_device()),
+                colo_device_memory_used(get_accelerator().get_current_device()),
             )
             sleep(self.interval)
         return max_usage
@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type

 import torch

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.chunk import Chunk

 from .chunk import Chunk, ChunkManager
@@ -85,7 +85,7 @@ class StaticPlacementPolicy(PlacementPolicy):
            # init offload optim settings
            # keep gathered chunks are in CUDA
            if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
-                device = get_current_device()
+                device = get_accelerator().get_current_device()
            else:
                device = torch.device("cpu")
                # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
@@ -140,7 +140,7 @@ class AutoPlacementPolicy(PlacementPolicy):
            int: the volume of memory that is evicted
        """
        start = time()
-        cuda_capacity = colo_device_memory_capacity(get_current_device())
+        cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
        used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
        if warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
@@ -194,7 +194,7 @@ class AutoPlacementPolicy(PlacementPolicy):
            # init offload optim settings
            # keep gathered chunks are in CUDA
            if chunk.keep_gathered:
-                grads_device_map[p] = get_current_device()
+                grads_device_map[p] = get_accelerator().get_current_device()
            else:
                grads_device_map[p] = torch.device("cpu")

@@ -6,7 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator

 from .chunk import Chunk

@@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
     if chunk.cuda_shard is not None:
         shard_temp = chunk.cuda_shard
     else:
-        shard_temp = chunk.cpu_shard.to(get_current_device())
+        shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device())

     shard_temp = shard_temp.to(dtype)

-    total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device())
+    total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device())
     gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
     dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)

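Taken together, the call sites above follow one device-agnostic idiom: allocate or move tensors onto whatever device the active accelerator reports instead of hard-coding CUDA. A standalone sketch of that idiom; the helper name and example tensor are hypothetical, only get_accelerator().get_current_device() comes from the hunks above.

import torch

from colossalai.accelerator import get_accelerator


def to_accelerator(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: place a tensor on the device of the active accelerator (CUDA, NPU, ...).
    return t.to(get_accelerator().get_current_device())


cpu_shard = torch.randn(16)  # stand-in for a CPU-offloaded shard
device_shard = to_accelerator(cpu_shard)  # ends up on e.g. cuda:0 or npu:0, depending on the backend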