[npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d

* update

* fix autocast
This commit is contained in:
Xuanlei Zhao
2023-11-22 19:23:21 +08:00
committed by GitHub
parent aae496631c
commit 3acbf6d496
9 changed files with 61 additions and 40 deletions

View File

@@ -6,6 +6,7 @@ from torch import Tensor
from torch.optim import Optimizer
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils.device import autocast
from .mixed_precision_base import MixedPrecision
@@ -88,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
super().__init__(module)
def forward(self, *args, **kwargs):
with torch.cuda.amp.autocast():
with autocast():
return self.module(*args, **kwargs)