[Device]Support npu (#6159)

* support npu

* support pretrain

support pretrain

fix

* support lora

fix

fix

* support chatglm

fix

fxi

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

fix

* Update train.py

* Update train.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-17 15:42:39 +08:00
committed by GitHub
parent e994c64568
commit aaafb38851
18 changed files with 295 additions and 152 deletions

View File

@@ -1,15 +1,28 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import numbers
import warnings
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.parameter import Parameter
from colossalai.lazy import LazyInitContext
from ._operation import hook_parameter_in_backward
from .utils import SeqParallelUtils
SUPPORT_NPU = False
try:
import torch_npu
SUPPORT_NPU = True
except Exception:
pass
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
try:
@@ -21,7 +34,6 @@ except ImportError:
try:
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
class FusedLayerNormWithHook(ApexFusedLayerNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
@@ -32,7 +44,41 @@ try:
output = hook_parameter_in_backward(output, self.weight, self.bias)
return output
class FusedRMSNormWithHook(ApexFusedRMSNorm):
except ImportError:
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
FusedRMSNormWithHook = None
if SUPPORT_NPU:
class NPUFusedRMSNormWithHook(nn.Module):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.empty(*normalized_shape))
else:
self.register_parameter("weight", None)
self.reset_parameters()
def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)
def forward(self, input):
output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps)
output = hook_parameter_in_backward(output, self.weight)
return output
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
else:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)
@@ -41,8 +87,7 @@ try:
output = hook_parameter_in_backward(output, self.weight)
return output
except ImportError:
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel")
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,