mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[Device]Support npu (#6159)
* support npu * support pretrain support pretrain fix * support lora fix fix * support chatglm fix fxi fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix fix * Update train.py * Update train.py * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,15 +1,28 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import numbers
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
|
||||
from ._operation import hook_parameter_in_backward
|
||||
from .utils import SeqParallelUtils
|
||||
|
||||
SUPPORT_NPU = False
|
||||
try:
|
||||
import torch_npu
|
||||
|
||||
SUPPORT_NPU = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
|
||||
|
||||
try:
|
||||
@@ -21,7 +34,6 @@ except ImportError:
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
|
||||
class FusedLayerNormWithHook(ApexFusedLayerNorm):
|
||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||
@@ -32,7 +44,41 @@ try:
|
||||
output = hook_parameter_in_backward(output, self.weight, self.bias)
|
||||
return output
|
||||
|
||||
class FusedRMSNormWithHook(ApexFusedRMSNorm):
|
||||
except ImportError:
|
||||
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
|
||||
|
||||
FusedRMSNormWithHook = None
|
||||
if SUPPORT_NPU:
|
||||
|
||||
class NPUFusedRMSNormWithHook(nn.Module):
|
||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||
super().__init__()
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = torch.Size(normalized_shape)
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = Parameter(torch.empty(*normalized_shape))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.elementwise_affine:
|
||||
init.ones_(self.weight)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps)
|
||||
output = hook_parameter_in_backward(output, self.weight)
|
||||
return output
|
||||
|
||||
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
|
||||
else:
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
|
||||
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
|
||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||
super().__init__(normalized_shape, eps, elementwise_affine)
|
||||
|
||||
@@ -41,8 +87,7 @@ try:
|
||||
output = hook_parameter_in_backward(output, self.weight)
|
||||
return output
|
||||
|
||||
except ImportError:
|
||||
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
|
||||
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
|
||||
|
||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||
1024,
|
||||
|
Reference in New Issue
Block a user