[npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d

* update

* fix autocast
This commit is contained in:
Xuanlei Zhao
2023-11-22 19:23:21 +08:00
committed by GitHub
parent aae496631c
commit 3acbf6d496
9 changed files with 61 additions and 40 deletions

View File

@@ -1,7 +1,8 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.cuda.amp as torch_amp
from colossalai.utils.device import autocast
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
@@ -70,7 +71,7 @@ class TorchAMPModel(nn.Module):
super().__init__()
self.model = model
@torch_amp.autocast()
@autocast()
def forward(self, *args, **kwargs):
"""
Execute forward under the torch amp context
@@ -89,7 +90,7 @@ class TorchAMPLoss(nn.Module):
super().__init__()
self.loss = loss
@torch_amp.autocast()
@autocast()
def forward(self, *args, **kwargs):
"""
Execute forward under the torch amp context