mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047) * [npu] add npu device support * [npu] support low level zero * [test] update npu zero plugin test * [hotfix] fix import * [test] recover tests * [npu] gemini support npu (#5052) * [npu] refactor device utils * [gemini] support npu * [example] llama2+gemini support npu * [kernel] add arm cpu adam kernel (#5065) * [kernel] add arm cpu adam * [optim] update adam optimizer * [kernel] arm cpu adam remove bf16 support
This commit is contained in:
@@ -84,9 +84,10 @@ class HybridAdam(CPUAdam):
|
||||
nvme_offload_fraction,
|
||||
nvme_offload_dir,
|
||||
)
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
self.gpu_adam_op = fused_optim.multi_tensor_adam
|
||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
if torch.cuda.is_available():
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
self.gpu_adam_op = fused_optim.multi_tensor_adam
|
||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None, div_scale: float = -1):
|
||||
@@ -118,11 +119,11 @@ class HybridAdam(CPUAdam):
|
||||
group_step = state["step"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
if target_device.type == "cpu":
|
||||
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
|
||||
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
|
||||
if target_device.type == "cpu" or target_device.type == "npu":
|
||||
assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
|
||||
assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
|
||||
self._pre_update(p, "exp_avg", "exp_avg_sq")
|
||||
if p.grad.dtype is torch.bfloat16:
|
||||
if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
|
||||
# cpu adam kernel does not support bf16 now
|
||||
bias_correction1 = 1 - beta1 ** state["step"]
|
||||
bias_correction2 = 1 - beta2 ** state["step"]
|
||||
|
Reference in New Issue
Block a user