[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
Author: Hongxin Liu
Date: 2023-11-20 16:12:41 +08:00 (committed by GitHub)
Parent: 8d56c9c389
Commit: e5ce4c8ea6
46 changed files with 994 additions and 233 deletions
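The plugins this commit extends are driven through the Booster API. Below is a minimal sketch of how Gemini (or low level zero) training might be set up after this change; the toy model, the data, and the assumption that the refactored device utils resolve to the NPU automatically are illustrative and not taken from the diff.

```python
# Minimal sketch (not from this diff): Booster-driven training with the Gemini plugin.
# Swapping GeminiPlugin for LowLevelZeroPlugin is assumed to work the same way.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

colossalai.launch_from_torch(config={})  # config-dict signature assumed for this release

# Toy model and batch stand in for a real workload such as the llama2 example.
model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 10))
optimizer = HybridAdam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

booster = Booster(plugin=GeminiPlugin())
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

device = get_current_device()  # assumed to pick the NPU when one is available
inputs = torch.randn(8, 32, device=device)
labels = torch.randint(0, 10, (8,), device=device)

loss = criterion(model(inputs), labels)
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
```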


@@ -2,16 +2,19 @@
# -*- encoding: utf-8 -*-
import warnings
from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.lazy import LazyInitContext
from ._operation import hook_paramter_in_backward
from .utils import SeqParallelUtils
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm

    EnableFastLayerNorm = True
except ImportError:
    EnableFastLayerNorm = False
@@ -19,10 +22,27 @@ except ImportError:
try:
    from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

    class FusedLayerNormWithHook(ApexFusedLayerNorm):
        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
            super().__init__(normalized_shape, eps, elementwise_affine)

        def forward(self, input):
            output = super().forward(input)
            output = hook_paramter_in_backward(output, self.weight, self.bias)
            return output

    class FusedRMSNormWithHook(ApexFusedRMSNorm):
        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
            super().__init__(normalized_shape, eps, elementwise_affine)

        def forward(self, input):
            output = super().forward(input)
            output = hook_paramter_in_backward(output, self.weight)
            return output
except ImportError:
    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
FAST_LAYERNORM_SUPPORTED_SIZE = [
    1024,
@@ -52,6 +72,7 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
]
if EnableFastLayerNorm:
    class FastLayerNormWithHook(FastLayerNorm):
        def __init__(self, hidden_size, eps=0.00001):
            super().__init__(hidden_size, eps)
@@ -60,25 +81,7 @@ if EnableFastLayerNorm:
            output = super().forward(input)
            output = hook_paramter_in_backward(output, self.weight, self.bias)
            return output
class BaseLayerNorm(ABC):
    @abstractmethod
@@ -244,12 +247,13 @@ class FusedRMSNorm(BaseLayerNorm):
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"FusedRMSNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
)
@staticmethod
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
@@ -264,7 +268,7 @@ class FusedRMSNorm(BaseLayerNorm):
            nn.Module: FusedRMSNorm module.
        """
        try:
            from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
            pass
        except ImportError:
            raise ImportError(
                "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
@@ -282,7 +286,9 @@ class FusedRMSNorm(BaseLayerNorm):
        eps = module.eps
        elementwise_affine = module.elementwise_affine
        rmsnorm = FusedRMSNormWithHook(
            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
        )
        rmsnorm.weight = module.weight
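Since FusedRMSNorm is used only through from_native_module, a short conversion sketch may help. The NaiveRMSNorm below is a hypothetical stand-in whose attributes (normalized_shape, eps, elementwise_affine, weight) mirror what the converter reads off the module; the public import path and the reading of sp_partial_derived as sequence-parallel gradient handling (per the SeqParallelUtils import) are assumptions, not part of the diff.

```python
# Hypothetical example: convert a hand-rolled RMSNorm into the apex-backed fused module.
import torch
import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm  # assumed public import path

class NaiveRMSNorm(nn.Module):
    """Toy RMSNorm exposing the attributes from_native_module expects."""

    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = True
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

native = NaiveRMSNorm(1024)
fused = FusedRMSNorm.from_native_module(native, sp_partial_derived=False)  # requires apex
```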