[npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* fix bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext (see the sketch below)

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
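
The "fix autocast" and "use nullcontext" items above concern guarding mixed-precision regions on backends where autocast is disabled or unsupported. A minimal sketch of that pattern, assuming a hypothetical helper name rather than the actual code from this commit:

from contextlib import nullcontext

import torch

def autocast_context(device_type: str, enabled: bool, dtype: torch.dtype = torch.float16):
    # Hypothetical helper: fall back to a no-op context when mixed precision
    # is disabled; otherwise defer to torch.autocast for the given device type.
    if not enabled:
        return nullcontext()
    return torch.autocast(device_type=device_type, dtype=dtype)

# Usage: the body runs unchanged whether or not autocast is active.
with autocast_context(device_type="cuda", enabled=False):
    out = torch.ones(2, 2) @ torch.ones(2, 2)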
Hongxin Liu
2024-01-09 10:20:05 +08:00
committed by GitHub
parent dd2c28a323
commit d202cc28c0
128 changed files with 1773 additions and 868 deletions


@@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup
 from torch.nn import Parameter
 from torch.optim import Optimizer
 
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import OptimizerWrapper
@@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.utils import disposable, get_current_device, is_ddp_ignored
+from colossalai.utils import disposable, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager
 from .gemini_ddp import GeminiDDP
@@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper):
             grad_chunk.l2_norm = None  # clear l2 norm
 
-        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
         for group, part_norm in group_to_norm.items():
             comm_buffer.fill_(part_norm)
             dist.all_reduce(comm_buffer, group=group)
@@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper):
                         continue
 
                     if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
-                        self.chunk_manager.move_chunk(chunk32, get_current_device())
+                        self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
                         # stores grad now
-                        self.chunk_manager.move_chunk(chunk16, get_current_device())
-                        self.module.set_chunk_grad_device(chunk16, get_current_device())
+                        self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
+                        self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
                         fp32_params_used_cuda_margin_mem += chunk32.payload_mem
 
             for group in self.param_groups:
@@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper):
                     state = self.optim.state[fake_param]
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
-                            state[k] = v.to(get_current_device())
+                            state[k] = v.to(get_accelerator().get_current_device())
 
     def _register_states_(self):
         for group in self.optim.param_groups:
@@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self,
         param_id: int,
         state_names: list,
-        device: torch.device = get_current_device(),
+        device: torch.device = get_accelerator().get_current_device(),
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """