Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047)
* [npu] add npu device support
* [npu] support low level zero
* [test] update npu zero plugin test
* [hotfix] fix import
* [test] recover tests
* [npu] gemini support npu (#5052)
* [npu] refactor device utils
* [gemini] support npu
* [example] llama2+gemini support npu
* [kernel] add arm cpu adam kernel (#5065)
* [kernel] add arm cpu adam
* [optim] update adam optimizer
* [kernel] arm cpu adam remove bf16 support
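For orientation only (this code is not part of the commit's diff): Gemini and low-level ZeRO are driven through the Booster plugin interface, and the headline claim of this PR is that the same path now also runs on Ascend NPU devices. Below is a rough sketch assuming the Booster/GeminiPlugin API of this ColossalAI generation; the toy model, loss, and hyperparameters are illustrative placeholders.

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Initialize the distributed environment (expects torchrun-provided env vars).
colossalai.launch_from_torch(config={})

# Toy stand-in for e.g. the llama2 example referenced in the commit message.
model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

plugin = GeminiPlugin()  # defaults only; this plugin path is what the PR ports to NPU
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

The low-level ZeRO path is analogous, with LowLevelZeroPlugin in place of GeminiPlugin.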
@@ -2,16 +2,19 @@
# -*- encoding: utf-8 -*-
import warnings
from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.lazy import LazyInitContext
from ._operation import hook_paramter_in_backward

import torch.nn as nn

from colossalai.lazy import LazyInitContext

from ._operation import hook_paramter_in_backward
from .utils import SeqParallelUtils

__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm

    EnableFastLayerNorm = True
except ImportError:
    EnableFastLayerNorm = False
@@ -19,10 +22,27 @@ except ImportError:
try:
    from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

    class FusedLayerNormWithHook(ApexFusedLayerNorm):
        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
            super().__init__(normalized_shape, eps, elementwise_affine)

        def forward(self, input):
            output = super().forward(input)
            output = hook_paramter_in_backward(output, self.weight, self.bias)
            return output

    class FusedRMSNormWithHook(ApexFusedRMSNorm):
        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
            super().__init__(normalized_shape, eps, elementwise_affine)

        def forward(self, input):
            output = super().forward(input)
            output = hook_paramter_in_backward(output, self.weight)
            return output

except ImportError:
    warnings.warn(
        "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
    )
    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")

FAST_LAYERNORM_SUPPORTED_SIZE = [
    1024,
@@ -52,6 +72,7 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
]

if EnableFastLayerNorm:

    class FastLayerNormWithHook(FastLayerNorm):
        def __init__(self, hidden_size, eps=0.00001):
            super().__init__(hidden_size, eps)
@@ -60,25 +81,7 @@ if EnableFastLayerNorm:
            output = super().forward(input)
            output = hook_paramter_in_backward(output, self.weight, self.bias)
            return output

class FusedLayerNormWithHook(ApexFusedLayerNorm):
    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
        super().__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input):
        output = super().forward(input)
        output = hook_paramter_in_backward(output, self.weight, self.bias)
        return output

class FusedRMSNormWithHook(ApexFusedRMSNorm):
    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
        super().__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input):
        output = super().forward(input)
        output = hook_paramter_in_backward(output, self.weight)
        return output


class BaseLayerNorm(ABC):
    @abstractmethod
@@ -244,12 +247,13 @@ class FusedRMSNorm(BaseLayerNorm):
    """
    This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            "FusedRMSNorm is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
        )


    @staticmethod
    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
@@ -264,7 +268,7 @@ class FusedRMSNorm(BaseLayerNorm):
            nn.Module: FusedRMSNorm module.
        """
        try:
            from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
            pass
        except ImportError:
            raise ImportError(
                "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
@@ -282,7 +286,9 @@ class FusedRMSNorm(BaseLayerNorm):
        eps = module.eps
        elementwise_affine = module.elementwise_affine

        rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        rmsnorm = FusedRMSNormWithHook(
            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
        )

        rmsnorm.weight = module.weight
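For context on the wrapper classes touched above: they are not constructed directly; as the docstring says, conversion goes through from_native_module. A minimal sketch under stated assumptions: the import path is assumed, the NaiveRMSNorm stub is hypothetical and simply exposes the attributes the converter appears to read (normalized_shape, eps, elementwise_affine, weight), and apex must be installed for the fused kernel.

import torch
import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm  # import path assumed


class NaiveRMSNorm(nn.Module):
    # Hypothetical stand-in for a model's own RMSNorm implementation.
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.normalized_shape = hidden_size
        self.eps = eps
        self.elementwise_affine = True
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        # Plain RMSNorm: scale by the reciprocal root-mean-square over the last dim.
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


norm = NaiveRMSNorm(1024)
fused = FusedRMSNorm.from_native_module(norm, sp_partial_derived=False)
# Per the diff, `fused` is a FusedRMSNormWithHook that reuses `norm.weight`, so the
# original parameter receives gradients via the hook_paramter_in_backward mechanism.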