Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 05:49:55 +00:00)

[npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d
* update
* fix autocast

This commit is contained in:
@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-import torch.cuda.amp as torch_amp
+from colossalai.utils.device import autocast
+
 import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
@@ -70,7 +71,7 @@ class TorchAMPModel(nn.Module):
         super().__init__()
         self.model = model
 
-    @torch_amp.autocast()
+    @autocast()
     def forward(self, *args, **kwargs):
         """
         Execute forward under the torch amp context
@@ -89,7 +90,7 @@ class TorchAMPLoss(nn.Module):
         super().__init__()
         self.loss = loss
 
-    @torch_amp.autocast()
+    @autocast()
    def forward(self, *args, **kwargs):
         """
         Execute forward under the torch amp context
Reference in New Issue
Block a user