[npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d

* update

* fix autocast
Author: Xuanlei Zhao
Committed: 2023-11-22 19:23:21 +08:00 (committed via GitHub)
Parent: aae496631c
Commit: 3acbf6d496
9 changed files with 61 additions and 40 deletions


@@ -7,7 +7,7 @@ import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable
 from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils import get_current_device
+from colossalai.utils.device import autocast, get_current_device
 def copy_to_device(obj, device):
@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
             inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), torch.cuda.amp.autocast():
+            with torch.enable_grad(), autocast():
                 outputs = ctx.run_function(*detached_inputs)
         else:
             with torch.enable_grad():
@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
            # rerun forward, the inner_pack will store all the activations in storage
            if has_autocast_in_fwd:
-                with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
+                with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
                    inner_pack, inner_unpack
                ):
                    _unused = function(*args)
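
All three hunks apply the same fix: the CUDA-only torch.cuda.amp.autocast() context manager is replaced by the device-agnostic autocast() imported from colossalai.utils.device, so activation checkpointing picks the right mixed-precision context on both CUDA GPUs and Ascend NPUs. A minimal sketch of what such a dispatching helper could look like (an illustrative assumption, not the actual colossalai.utils.device implementation):

import torch

def autocast(*args, **kwargs):
    # Illustrative sketch only: return the autocast context manager of
    # whichever accelerator backend is available. The real
    # colossalai.utils.device.autocast may be implemented differently.
    # torch_npu (Ascend) adds a torch.npu namespace when it is installed.
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.npu.amp.autocast(*args, **kwargs)
    # Default to the CUDA autocast on GPU machines.
    return torch.cuda.amp.autocast(*args, **kwargs)

With a helper like this, the checkpoint code above no longer hard-codes a CUDA dependency; "with torch.enable_grad(), autocast():" selects the right backend at runtime.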