mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 22:10:37 +00:00
[npu] change device to accelerator api (#5239)
* update accelerator * fix timer * fix amp * update * fix * update bug * add error raise * fix autocast * fix set device * remove doc accelerator * update doc * update doc * update doc * use nullcontext * update cpu * update null context * change time limit for example * udpate * update * update * update * [npu] polish accelerator code --------- Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com> Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
This commit is contained in:
@@ -7,12 +7,12 @@ from typing import Callable
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import all_reduce
|
||||
from colossalai.legacy.context import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.registry import HOOKS
|
||||
from colossalai.legacy.utils import is_no_pp_or_last_stage
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ._base_hook import BaseHook
|
||||
from ._commons_ import _format_number
|
||||
@@ -82,8 +82,8 @@ class LossMetric(Metric):
|
||||
|
||||
def __init__(self, epoch_only):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_loss = torch.zeros(1, device=get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.count = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -164,10 +164,10 @@ class AccuracyMetric(Metric):
|
||||
def __init__(self, epoch_only: bool, accuracy_func: Callable):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.acc = accuracy_func
|
||||
self.last_step_sum = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_step_sum.zero_()
|
||||
@@ -320,10 +320,10 @@ class ThroughputMetric(Metric):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.ignored_steps = ignored_steps
|
||||
self.cur_steps = 0
|
||||
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_used_time = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self._tflop_per_step = tflop_per_step
|
||||
self._use_local = use_local
|
||||
|
||||
|
Reference in New Issue
Block a user