Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 21:40:02 +00:00)
[feat] refactored extension module (#5298)
* [feat] refactored extension module
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
extensions/__init__.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
from .flash_attention import (
    FlashAttentionDaoCudaExtension,
    FlashAttentionNpuExtension,
    FlashAttentionXformersCudaExtension,
)
from .layernorm import LayerNormCudaExtension
from .moe import MoeCudaExtension
from .optimizer import FusedOptimizerCudaExtension
from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension

ALL_EXTENSIONS = [
    CpuAdamArmExtension,
    CpuAdamX86Extension,
    LayerNormCudaExtension,
    MoeCudaExtension,
    FusedOptimizerCudaExtension,
    ScaledMaskedSoftmaxCudaExtension,
    ScaledUpperTriangleMaskedSoftmaxCudaExtension,
    FlashAttentionDaoCudaExtension,
    FlashAttentionXformersCudaExtension,
    FlashAttentionNpuExtension,
]

__all__ = [
    "CpuAdamArmExtension",
    "CpuAdamX86Extension",
    "LayerNormCudaExtension",
    "MoeCudaExtension",
    "FusedOptimizerCudaExtension",
    "ScaledMaskedSoftmaxCudaExtension",
    "ScaledUpperTriangleMaskedSoftmaxCudaExtension",
    "FlashAttentionDaoCudaExtension",
    "FlashAttentionXformersCudaExtension",
    "FlashAttentionNpuExtension",
]
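As a quick illustration of how this registry is meant to be consumed (a sketch, not code from this PR), each entry can be instantiated and filtered through the hardware checks defined on the _Extension interface introduced in base_extension.py below; the helper name here is an assumption:

# hypothetical helper: pick usable kernels from the registry, highest priority first
from extensions import ALL_EXTENSIONS

def usable_extensions():
    # instantiate every registered extension class
    candidates = [ext_cls() for ext_cls in ALL_EXTENSIONS]
    # keep only those whose required hardware is actually present
    available = [ext for ext in candidates if ext.is_hardware_available()]
    # prefer higher-priority builds
    return sorted(available, key=lambda ext: ext.priority, reverse=True)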
extensions/base_extension.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import hashlib
import os
from abc import ABC, abstractmethod
from typing import Union

__all__ = ["_Extension"]


class _Extension(ABC):
    def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1):
        self._name = name
        self._support_aot = support_aot
        self._support_jit = support_jit
        self.priority = priority

    @property
    def name(self):
        return self._name

    @property
    def support_aot(self):
        return self._support_aot

    @property
    def support_jit(self):
        return self._support_jit

    @staticmethod
    def get_jit_extension_folder_path():
        """
        Kernels which are compiled during runtime will be stored in the same cache folder for reuse.
        The folder is in the path ~/.cache/colossalai/torch_extensions/<cache-folder>.
        The name of the <cache-folder> follows a common format:
            torch<torch_version_major>.<torch_version_minor>_<device_name>-<device_version>-<hash>

        The <hash> suffix is the hash value of the path of the `colossalai` file.
        """
        import torch

        import colossalai
        from colossalai.accelerator import get_accelerator

        # get torch version
        torch_version_major = torch.__version__.split(".")[0]
        torch_version_minor = torch.__version__.split(".")[1]

        # get device version
        device_name = get_accelerator().name
        device_version = get_accelerator().get_version()

        # use colossalai's file path as hash
        hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest()

        # concat
        home_directory = os.path.expanduser("~")
        extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}"
        cache_directory = os.path.join(home_directory, extension_directory)
        return cache_directory

    @abstractmethod
    def is_hardware_available(self) -> bool:
        """
        Check if the hardware required by the kernel is available.
        """

    @abstractmethod
    def assert_hardware_compatible(self) -> bool:
        """
        Check if the hardware required by the kernel is compatible.
        """

    @abstractmethod
    def build_aot(self) -> Union["CppExtension", "CUDAExtension"]:
        pass

    @abstractmethod
    def build_jit(self) -> None:
        pass

    @abstractmethod
    def load(self):
        pass
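To make the contract concrete, a hypothetical subclass (illustrative only, not part of this PR) has to provide the five abstract hooks; every name below other than the _Extension API itself is made up:

from extensions.base_extension import _Extension

class MyNoopExtension(_Extension):
    def __init__(self):
        # JIT-only extension: no ahead-of-time build is offered
        super().__init__(name="my_noop", support_aot=False, support_jit=True)

    def is_hardware_available(self) -> bool:
        # no special hardware is needed for this toy example
        return True

    def assert_hardware_compatible(self) -> None:
        pass

    def build_aot(self):
        raise NotImplementedError("this extension is JIT-only")

    def build_jit(self):
        # a real extension would compile and return its kernel module here
        return object()

    def load(self):
        return self.build_jit()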
extensions/cpp_extension.py (new file, 134 lines)
@@ -0,0 +1,134 @@
import importlib
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List

from .base_extension import _Extension

__all__ = ["_CppExtension"]


class _CppExtension(_Extension):
    def __init__(self, name: str, priority: int = 1):
        super().__init__(name, support_aot=True, support_jit=True, priority=priority)

        # we store the op as an attribute to avoid repeated building and loading
        self.cached_op = None

        # build-related variables
        self.prebuilt_module_path = "colossalai._C"
        self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    def csrc_abs_path(self, path):
        return os.path.join(self.relative_to_abs_path("csrc"), path)

    def relative_to_abs_path(self, code_path: str) -> str:
        """
        This function takes in a path relative to the colossalai root directory and returns the absolute path.
        """

        # get the current file path
        # iteratively check the parent directory
        # if the parent directory is "extensions", then the current file path is the root directory
        # otherwise, the current file path is inside the root directory
        current_file_path = Path(__file__)
        while True:
            if current_file_path.name == "extensions":
                break
            else:
                current_file_path = current_file_path.parent
        extension_module_path = current_file_path
        code_abs_path = extension_module_path.joinpath(code_path)
        return str(code_abs_path)

    # functions must be overridden over
    def strip_empty_entries(self, args):
        """
        Drop any empty strings from the list of compile and link flags
        """
        return [x for x in args if len(x) > 0]

    def import_op(self):
        """
        This function will import the op module by its string name.
        """
        return importlib.import_module(self.prebuilt_import_path)

    def build_aot(self) -> "CppExtension":
        from torch.utils.cpp_extension import CppExtension

        return CppExtension(
            name=self.prebuilt_import_path,
            sources=self.strip_empty_entries(self.sources_files()),
            include_dirs=self.strip_empty_entries(self.include_dirs()),
            extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
        )

    def build_jit(self) -> None:
        from torch.utils.cpp_extension import load

        build_directory = _Extension.get_jit_extension_folder_path()
        build_directory = Path(build_directory)
        build_directory.mkdir(parents=True, exist_ok=True)

        # check if the kernel has been built
        compiled_before = False
        kernel_file_path = build_directory.joinpath(f"{self.name}.o")
        if kernel_file_path.exists():
            compiled_before = True

        # load the kernel
        if compiled_before:
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
        else:
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")

        build_start = time.time()
        op_kernel = load(
            name=self.name,
            sources=self.strip_empty_entries(self.sources_files()),
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
            extra_cflags=self.cxx_flags(),
            extra_ldflags=[],
            build_directory=str(build_directory),
        )
        build_duration = time.time() - build_start

        if compiled_before:
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
        else:
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")

        return op_kernel

    # functions must be overridden begin
    @abstractmethod
    def sources_files(self) -> List[str]:
        """
        This function should return a list of source files for extensions.
        """

    @abstractmethod
    def include_dirs(self) -> List[str]:
        """
        This function should return a list of include directories for extensions.
        """

    @abstractmethod
    def cxx_flags(self) -> List[str]:
        """
        This function should return a list of cxx compilation flags for extensions.
        """

    def load(self):
        try:
            op_kernel = self.import_op()
        except ImportError:
            # if an import error occurs, the kernel is not pre-built,
            # so we build it just-in-time
            op_kernel = self.build_jit()

        return op_kernel
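The load() path above first tries the pre-built colossalai._C.<name> module and only falls back to a runtime (JIT) build in the cache folder from base_extension.py. A hedged usage sketch, using the ARM CPU Adam extension defined later in this diff; everything outside those classes is an assumption:

from extensions.cpu_adam import CpuAdamArmExtension

ext = CpuAdamArmExtension()
if ext.is_hardware_available():
    # returns the already-compiled op if colossalai._C.cpu_adam_arm exists,
    # otherwise compiles it with torch.utils.cpp_extension.load()
    op = ext.load()
    # the pybind class exported by the kernel source below
    optimizer = op.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)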
extensions/cpu_adam/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .cpu_adam_arm import CpuAdamArmExtension
from .cpu_adam_x86 import CpuAdamX86Extension

__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension']
extensions/cpu_adam/cpu_adam_arm.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import platform

from ..cpp_extension import _CppExtension


class CpuAdamArmExtension(_CppExtension):
    def __init__(self):
        super().__init__(name="cpu_adam_arm")

    def is_hardware_available(self) -> bool:
        # only arm allowed
        return platform.machine() == "aarch64"

    def assert_hardware_compatible(self) -> None:
        arch = platform.machine()
        assert (
            arch == "aarch64"
        ), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}"

    # necessary 4 functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path("arm/cpu_adam_arm.cpp"),
        ]
        return ret

    def include_dirs(self):
        return []

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++14",
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        return []
extensions/cpu_adam/cpu_adam_x86.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import platform

from ..cuda_extension import _CudaExtension
from ..utils import append_nvcc_threads


class CpuAdamX86Extension(_CudaExtension):
    def __init__(self):
        super().__init__(name="cpu_adam_x86")

    def is_hardware_available(self) -> bool:
        return platform.machine() == "x86_64" and super().is_hardware_available()

    def assert_hardware_compatible(self) -> None:
        arch = platform.machine()
        assert (
            arch == "x86_64"
        ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}"
        super().assert_hardware_compatible()

    # necessary 4 functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path("cuda/cpu_adam.cpp"),
        ]
        return ret

    def include_dirs(self):
        return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++14",
            "-std=c++17",
            "-lcudart",
            "-lcublas",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
            "-march=native",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        extra_cuda_flags = [
            "-std=c++14",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]
        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)
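append_nvcc_threads is imported from ..utils, which is not included in this commit. A plausible sketch of what such a helper could do, offered purely as an assumption and not as the project's actual implementation, is to append NVCC's multi-threaded compilation flag:

import os

def append_nvcc_threads(nvcc_flags):
    # hypothetical helper: ask NVCC to compile with several threads
    num_threads = os.environ.get("NVCC_THREADS", "4")
    return nvcc_flags + ["--threads", num_threads]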
extensions/csrc/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax

__all__ = [
    "LayerNorm",
    "MultiHeadAttention",
    "FusedScaleMaskSoftmax",
    "ScaledUpperTriangMaskedSoftmax",
    "AttnMaskType",
]
extensions/csrc/arm/cpu_adam_arm.cpp (new file, 304 lines)
@@ -0,0 +1,304 @@
#include "cpu_adam_arm.h"

void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
                           void *_exp_avg_sq, size_t _param_size,
                           at::ScalarType param_dtype,
                           at::ScalarType grad_dtype,
                           at::ScalarType exp_avg_dtype,
                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
  size_t rounded_size = 0;
#if defined(__aarch64__)
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
#endif

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;

#if defined(__aarch64__)
  float32x4_t betta1_4 = simd_set(_betta1);
  float32x4_t betta2_4 = simd_set(_betta2);
  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
  float32x4_t eps_4 = simd_set(_eps);
  float32x4_t step_size_4 = simd_set(step_size);
  float32x4_t weight_decay_4;
  if (_weight_decay > 0) {
    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
  }
  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH) {
      float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
      if (loss_scale > 0) {
        float32x4_t loss_scale_vec = simd_set(loss_scale);
        grad_4 = vdivq_f32(grad_4, loss_scale_vec);
      }
      float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
      float32x4_t variance_4 =
          simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
      float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
      if (_weight_decay > 0 && !_adamw_mode) {
        grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
      }
      momentum_4 = vmulq_f32(momentum_4, betta1_4);
      momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
      variance_4 = vmulq_f32(variance_4, betta2_4);
      grad_4 = vmulq_f32(grad_4, grad_4);
      variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
      grad_4 = vsqrtq_f32(variance_4);
      grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
      grad_4 = vdivq_f32(momentum_4, grad_4);
      if (_weight_decay > 0 && _adamw_mode) {
        param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
      }
      param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
      simd_store_offset(_params, param_dtype, param_4, i);
      simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
      simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
    }
  }
#endif
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
      if ((t + TILE) > _param_size) copy_size = _param_size - t;
      size_t offset = copy_size + t;

#pragma omp parallel for
      for (size_t k = t; k < offset; k++) {
        float grad = scalar_load_offset(grads, grad_dtype, k);
        if (loss_scale > 0) {
          grad /= loss_scale;
        }
        float param = scalar_load_offset(_params, param_dtype, k);
        float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
        float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
        if (_weight_decay > 0 && !_adamw_mode) {
          grad = param * _weight_decay + grad;
        }
        momentum = momentum * _betta1;
        momentum = grad * betta1_minus1 + momentum;

        variance = variance * _betta2;
        grad = grad * grad;
        variance = grad * betta2_minus1 + variance;

        grad = sqrt(variance);
        grad = grad * _bias_correction2 + _eps;
        grad = momentum / grad;
        if (_weight_decay > 0 && _adamw_mode) {
          param += w_decay * param;
        }
        param = grad * step_size + param;

        scalar_store_offset(_params, param_dtype, param, k);
        scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
        scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
      }
    }
  }
}

void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
                           void *_exp_avg_sq, size_t _param_size,
                           at::ScalarType param_dtype,
                           at::ScalarType grad_dtype,
                           at::ScalarType exp_avg_dtype,
                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
  size_t rounded_size = 0;
#if defined(__aarch64__)
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
#endif

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;

#if defined(__aarch64__)
  float32x4_t betta1_4 = simd_set(_betta1);
  float32x4_t betta2_4 = simd_set(_betta2);
  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
  float32x4_t eps_4 = simd_set(_eps);
  float32x4_t step_size_4 = simd_set(step_size);
  float32x4_t weight_decay_4;
  if (_weight_decay > 0) {
    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
  }

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
      float32x4_t grad_4[4];
      float32x4_t momentum_4[4];
      float32x4_t variance_4[4];
      float32x4_t param_4[4];
#pragma unroll 4
      for (int j = 0; j < 4; j++) {
        grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
        if (loss_scale > 0) {
          float32x4_t loss_scale_vec = simd_set(loss_scale);
          grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
        }
        momentum_4[j] =
            simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
        variance_4[j] =
            simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
        param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
        }
        momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
        momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
        variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
        grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
        variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
        grad_4[j] = vsqrtq_f32(variance_4[j]);
        grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
        grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
        }
        param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
        simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
        simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
                          i + SIMD_WIDTH * j);
        simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
                          i + SIMD_WIDTH * j);
      }
    }
  }
#endif
  if (_param_size > rounded_size) {
    Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
           scalar_seek_offset(grads, grad_dtype, rounded_size),
           scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
           scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
           (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
           exp_avg_sq_dtype, loss_scale);
  }
}

void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
                           void *_exp_avg_sq, size_t _param_size,
                           at::ScalarType param_dtype,
                           at::ScalarType grad_dtype,
                           at::ScalarType exp_avg_dtype,
                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
  size_t rounded_size = 0;
#if defined(__aarch64__)
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
#endif

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
  float32x4_t betta1_4 = simd_set(_betta1);
  float32x4_t betta2_4 = simd_set(_betta2);
  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
  float32x4_t eps_4 = simd_set(_eps);
  float32x4_t step_size_4 = simd_set(step_size);
  float32x4_t weight_decay_4;
  if (_weight_decay > 0) {
    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
  }

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
      float32x4_t grad_4[8];
      float32x4_t momentum_4[8];
      float32x4_t variance_4[8];
      float32x4_t param_4[8];
#pragma unroll 4
      for (int j = 0; j < 8; j++) {
        grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
        if (loss_scale > 0) {
          float32x4_t loss_scale_vec = simd_set(loss_scale);
          grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
        }
        momentum_4[j] =
            simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
        variance_4[j] =
            simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
        param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
        }
        momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
        momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
        variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
        grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
        variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
        grad_4[j] = vsqrtq_f32(variance_4[j]);
        grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
        grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
        }
        param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
        simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
        simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
                          i + SIMD_WIDTH * j);
        simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
                          i + SIMD_WIDTH * j);
      }
    }
  }
#endif
  if (_param_size > rounded_size) {
    Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
           scalar_seek_offset(grads, grad_dtype, rounded_size),
           scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
           scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
           (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
           exp_avg_sq_dtype, loss_scale);
  }
}

void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
                         float epsilon, float weight_decay,
                         bool bias_correction, torch::Tensor &params,
                         torch::Tensor &grads, torch::Tensor &exp_avg,
                         torch::Tensor &exp_avg_sq, float loss_scale) {
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();
  auto exp_avg_sq_c = exp_avg_sq.contiguous();

  this->IncrementStep(step, beta1, beta2);
  this->update_state(lr, epsilon, weight_decay, bias_correction);
  this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
               exp_avg_sq_c.data_ptr(), params_c.numel(),
               params_c.scalar_type(), grads_c.scalar_type(),
               exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
}

namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
      .def(py::init<float, float, float, float, float, bool>())
      .def("step", &AdamOptimizer::step);
}
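For readers who do not want to trace the NEON intrinsics, the per-element math of the kernel above is the standard Adam/AdamW update with optional loss scaling and bias correction. A plain-Python reference of one scalar step, written here only as a clarifying sketch that mirrors the scalar tail loop of Step_1 (not code from this PR):

import math

def adam_element_step(param, grad, m, v, *, alpha, betta1, betta2, eps,
                      weight_decay, adamw_mode, bias_correction1,
                      bias_correction2, loss_scale=-1.0):
    if loss_scale > 0:
        grad /= loss_scale
    if weight_decay > 0 and not adamw_mode:
        grad = param * weight_decay + grad          # L2-style decay folded into the gradient
    m = m * betta1 + grad * (1 - betta1)            # first moment
    v = v * betta2 + (grad * grad) * (1 - betta2)   # second moment
    denom = math.sqrt(v) * bias_correction2 + eps   # bias_correction2 = 1 / sqrt(1 - betta2**t)
    update = m / denom
    if weight_decay > 0 and adamw_mode:
        param += (-alpha * weight_decay) * param    # decoupled (AdamW) weight decay
    param += (-alpha / bias_correction1) * update   # bias_correction1 = 1 - betta1**t
    return param, m, v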
extensions/csrc/arm/cpu_adam_arm.h (new file, 201 lines)
@@ -0,0 +1,201 @@
#pragma once
#include <ATen/ATen.h>
#include <torch/extension.h>

#include <cmath>

#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)

#if defined(__aarch64__)
#include <arm_neon.h>
#define SIMD_WIDTH 4

inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
                                    size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float: {
      auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
      return vld1q_f32(ptr_f + offset);
    }
    case at::ScalarType::Half: {
      auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
      return vcvt_f32_f16(vld1_f16(ptr_h + offset));
    }
    // case at::ScalarType::BFloat16: {
    //   auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
    //   return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
    // }
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
  return simd_load_offset(ptr, dtype, 0);
}

inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
                              size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float: {
      auto ptr_f = reinterpret_cast<float32_t *>(ptr);
      vst1q_f32(ptr_f + offset, data);
      break;
    }
    case at::ScalarType::Half: {
      auto ptr_h = reinterpret_cast<float16_t *>(ptr);
      vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
      break;
    }
    // case at::ScalarType::BFloat16: {
    //   auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
    //   vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
    //   break;
    // }
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}

inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
  return simd_store_offset(ptr, dtype, data, 0);
}

inline float32x4_t simd_set(float value) {
  auto val = static_cast<float32_t>(value);
  return vdupq_n_f32(val);
}

#endif

inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
                                size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float:
      return *(reinterpret_cast<const float *>(ptr) + offset);
    case at::ScalarType::Half:
      return static_cast<float>(
          *(reinterpret_cast<const at::Half *>(ptr) + offset));
    // case at::ScalarType::BFloat16:
    //   return static_cast<float>(
    //       *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}

inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
                                size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float:
      *(reinterpret_cast<float *>(ptr) + offset) = data;
      break;
    case at::ScalarType::Half:
      *(reinterpret_cast<at::Half *>(ptr) + offset) = data;
      break;
    // case at::ScalarType::BFloat16:
    //   *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
      break;
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}

inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
                                size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float:
      return reinterpret_cast<float *>(ptr) + offset;
    case at::ScalarType::Half:
      return reinterpret_cast<at::Half *>(ptr) + offset;
    // case at::ScalarType::BFloat16:
    //   return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}
#define STEP(SPAN)                                                         \
  void Step_##SPAN(void *_params, void *grads, void *_exp_avg,             \
                   void *_exp_avg_sq, size_t _param_size,                  \
                   at::ScalarType param_dtype, at::ScalarType grad_dtype,  \
                   at::ScalarType exp_avg_dtype,                           \
                   at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);

class AdamOptimizer {
 private:
  float _alpha;
  float _betta1;
  float _betta2;
  float _eps;
  float _weight_decay;

  float _betta1_t;
  float _betta2_t;
  size_t _step;

  float _bias_correction1;
  float _bias_correction2;

  bool _adamw_mode;

 public:
  AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                float eps = 1e-8, float weight_decay = 0,
                bool adamw_mode = true)
      : _alpha(alpha),
        _betta1(betta1),
        _betta2(betta2),
        _eps(eps),
        _weight_decay(weight_decay),
        _betta1_t(1.0),
        _betta2_t(1.0),
        _step(0),
        _adamw_mode(adamw_mode) {}
  ~AdamOptimizer() {}

  STEP(1)
  STEP(4)
  STEP(8)
  inline void IncrementStep(size_t step, float beta1, float beta2) {
    if (beta1 != _betta1 || beta2 != _betta2) {
      _step = step;
      _betta1 = beta1;
      _betta2 = beta2;
      _betta1_t = std::pow(_betta1, step);
      _betta2_t = std::pow(_betta2, step);
    } else {
      _step++;
      if (_step != step) {
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
        _step = step;
      } else {
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
      }
    }
  }
  inline void update_state(float lr, float epsilon, float weight_decay,
                           bool bias_correction) {
    _alpha = lr;
    _eps = epsilon;
    _weight_decay = weight_decay;

    _bias_correction1 = 1.0f;
    _bias_correction2 = 1.0f;
    if (bias_correction == 1) {
      _bias_correction1 = 1 - _betta1_t;
      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
    }
  }

  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,
            torch::Tensor &grads, torch::Tensor &exp_avg,
            torch::Tensor &exp_avg_sq, float loss_scale);
};
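The ROUND_DOWN/TILE pattern above drives how Step_8, Step_4 and Step_1 split the flat parameter array: the widest SIMD stride handles the front, and the remainder is forwarded to the next-narrower variant until a scalar tail loop finishes the job. A small Python illustration of that partitioning, included only for explanation:

def round_down(size, step):
    # same bit trick as the ROUND_DOWN macro, valid for power-of-two steps
    return size & ~(step - 1)

def partition(param_size, simd_width=4):
    # Step_8 consumes the widest chunks, Step_4 the next, Step_1 the rest,
    # and whatever remains is handled element by element.
    step8 = round_down(param_size, simd_width * 8)
    rest = param_size - step8
    step4 = round_down(rest, simd_width * 4)
    rest -= step4
    step1 = round_down(rest, simd_width)
    return step8, step4, step1, rest - step1

# partition(103) == (96, 0, 4, 3)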
extensions/csrc/cuda/colossal_C_frontend.cpp (new file, 49 lines)
@@ -0,0 +1,49 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <torch/extension.h>

void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
                             float scale);

void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists,
                           float wd, float momentum, float dampening, float lr,
                           bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int mode,
                            const int bias_correction, const float weight_decay,
                            const float div_scale);

void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int bias_correction,
                            const float weight_decay, const int grad_averaging,
                            const int mode, at::Tensor global_grad_norm,
                            const float max_grad_norm,
                            at::optional<bool> use_nvlamb_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Computes and apply update for LAMB optimizer");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
}
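Once this frontend is compiled (the FusedOptimizerCudaExtension entry in the registry), the bindings are called from Python with a chunk size, an int32 overflow "noop" flag and nested lists of tensors. A hedged usage sketch follows; the module handle, the chunk size, the tensor-list ordering (grads, params, exp_avg, exp_avg_sq, following the apex convention) and flag values such as mode are assumptions rather than code from this PR:

import torch

# fused_optim = FusedOptimizerCudaExtension().load()  # assumed handle
def fused_adam_step(fused_optim, params, grads, exp_avgs, exp_avg_sqs,
                    lr, beta1, beta2, eps, step, weight_decay):
    noop_flag = torch.zeros(1, dtype=torch.int, device=params[0].device)
    tensor_lists = [grads, params, exp_avgs, exp_avg_sqs]
    # positional arguments mirror multi_tensor_adam_cuda declared above:
    # (chunk_size, noop_flag, tensor_lists, lr, beta1, beta2, epsilon,
    #  step, mode, bias_correction, weight_decay, div_scale)
    fused_optim.multi_tensor_adam(2048 * 32, noop_flag, tensor_lists,
                                  lr, beta1, beta2, eps, step,
                                  1, 1, weight_decay, -1.0)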
extensions/csrc/cuda/compat.h (new file, 10 lines)
@@ -0,0 +1,10 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
extensions/csrc/cuda/cpu_adam.cpp (new file, 446 lines)
@@ -0,0 +1,446 @@
/*
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"

#include <math.h>
#include <omp.h>
#include <string.h>

#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>

// C++ interface

void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;

  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH) {
      AVX_Data grad_4;
      this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
      if (loss_scale > 0) {
        AVX_Data loss_scale_vec;
        loss_scale_vec.data = SIMD_SET(loss_scale);
        grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
      }
      AVX_Data momentum_4;
      this->simd_load(momentum_half_precision, _exp_avg + i,
                      momentum_cast_h + i, momentum_4);

      AVX_Data variance_4;
      this->simd_load(variance_half_precision, _exp_avg_sq + i,
                      variance_cast_h + i, variance_4);

      AVX_Data param_4;
      this->simd_load(param_half_precision, _params + i, params_cast_h + i,
                      param_4);

      if (_weight_decay > 0 && !_adamw_mode) {
        grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
      }
      momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
      momentum_4.data =
          SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
      variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
      grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
      variance_4.data =
          SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
      grad_4.data = SIMD_SQRT(variance_4.data);
      grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
      grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);

      if (_weight_decay > 0 && _adamw_mode) {
        param_4.data =
            SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
      }
      param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);

      this->simd_store(param_half_precision, _params + i, params_cast_h + i,
                       param_4);
      this->simd_store(momentum_half_precision, _exp_avg + i,
                       momentum_cast_h + i, momentum_4);
      this->simd_store(variance_half_precision, _exp_avg_sq + i,
                       variance_cast_h + i, variance_4);
    }
  }
#endif
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
      if ((t + TILE) > _param_size) copy_size = _param_size - t;
      size_t offset = copy_size + t;

#pragma omp parallel for
      for (size_t k = t; k < offset; k++) {
        float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
        if (loss_scale > 0) {
          grad /= loss_scale;
        }
        float param =
            param_half_precision ? (float)params_cast_h[k] : _params[k];
        float momentum =
            momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
        float variance = variance_half_precision ? (float)variance_cast_h[k]
                                                 : _exp_avg_sq[k];
        if (_weight_decay > 0 && !_adamw_mode) {
          grad = param * _weight_decay + grad;
        }
        momentum = momentum * _betta1;
        momentum = grad * betta1_minus1 + momentum;

        variance = variance * _betta2;
        grad = grad * grad;
        variance = grad * betta2_minus1 + variance;

        grad = sqrt(variance);
        grad = grad * _bias_correction2 + _eps;
        grad = momentum / grad;
        if (_weight_decay > 0 && _adamw_mode) {
          param += w_decay * param;
        }
        param = grad * step_size + param;

        if (param_half_precision)
          params_cast_h[k] = (__half)param;
        else
          _params[k] = param;
        if (momentum_half_precision)
          momentum_cast_h[k] = (__half)(momentum);
        else
          _exp_avg[k] = momentum;
        if (variance_half_precision)
          variance_cast_h[k] = (__half)(variance);
        else
          _exp_avg_sq[k] = variance;
      }
    }
  }
}

void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);

  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
      AVX_Data grad_4[4];
      AVX_Data momentum_4[4];
      AVX_Data variance_4[4];
      AVX_Data param_4[4];
#pragma unroll 4
      for (int j = 0; j < 4; j++) {
        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);

        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_load(variance_half_precision,
                        _exp_avg_sq + i + SIMD_WIDTH * j,
                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);

        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);

        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_store(variance_half_precision,
                         _exp_avg_sq + i + SIMD_WIDTH * j,
                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
      }
    }
  }
#endif
  if (_param_size > rounded_size)
    Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
                                    : _exp_avg + rounded_size),
           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
                                    : _exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, momentum_half_precision,
           variance_half_precision, loss_scale);
}

void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
      AVX_Data grad_4[8];
      AVX_Data momentum_4[8];
      AVX_Data variance_4[8];
      AVX_Data param_4[8];
#pragma unroll 8
      for (int j = 0; j < 8; j++) {
        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);

        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_load(variance_half_precision,
                        _exp_avg_sq + i + SIMD_WIDTH * j,
                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);

        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);

        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_store(variance_half_precision,
                         _exp_avg_sq + i + SIMD_WIDTH * j,
                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
      }
    }
  }
#endif
  if (_param_size > rounded_size)
    Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
                                    : _exp_avg + rounded_size),
           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
                                    : _exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, momentum_half_precision,
           variance_half_precision, loss_scale);
}

void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
                          float epsilon, float weight_decay,
                          bool bias_correction, torch::Tensor &params,
                          torch::Tensor &grads, torch::Tensor &exp_avg,
                          torch::Tensor &exp_avg_sq, float loss_scale) {
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();
  auto exp_avg_sq_c = exp_avg_sq.contiguous();

  float *params_ptr = (float *)params_c.data_ptr();
  float *grads_ptr = (float *)grads_c.data_ptr();
  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();

  this->IncrementStep(step, beta1, beta2);
  this->update_state(lr, epsilon, weight_decay, bias_correction);
  this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
               params_c.numel(), (params.options().dtype() == at::kHalf),
               (grads.options().dtype() == at::kHalf),
               (exp_avg.options().dtype() == at::kHalf),
               (exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
}

namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
      .def(py::init<float, float, float, float, float, bool>())
      .def("step", &Adam_Optimizer::step);
}
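Both the ARM and the x86 kernel sources export the same CPUAdamOptimizer pybind class, so the Python side can stay backend-agnostic. A minimal usage sketch, assuming the extension handle has been obtained via CpuAdamX86Extension().load() (or the ARM equivalent); everything else follows the constructor and step() signatures shown above:

import torch

# cpu_adam = CpuAdamX86Extension().load()  # assumed handle
def run_one_step(cpu_adam, param, grad, exp_avg, exp_avg_sq, step):
    # constructor args: (lr, beta1, beta2, eps, weight_decay, adamw_mode)
    opt = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)
    # step args mirror Adam_Optimizer::step(...)
    opt.step(step, 1e-3, 0.9, 0.999, 1e-8, 0.0, True,
             param, grad, exp_avg, exp_avg_sq, -1.0)

p = torch.rand(1024)
g = torch.rand(1024)
m = torch.zeros(1024)
v = torch.zeros(1024)
# run_one_step(cpu_adam, p, g, m, v, step=1)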
extensions/csrc/cuda/cpu_adam.h (new file, 185 lines)
@@ -0,0 +1,185 @@
/*
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#pragma once

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <torch/extension.h>
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif

#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)

#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d)                                         \
  _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
                                     d, _MM_FROUND_TO_NEAREST_INT)))

#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d)                                   \
  _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
                                  d, _MM_FROUND_TO_NEAREST_INT)))

#endif

union AVX_Data {
#if defined(__AVX512__)
  __m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
  __m256 data;
#endif
  // float data_f[16];
};

#endif

#define STEP(SPAN)                                                            \
  void Step_##SPAN(                                                           \
      float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq,      \
      size_t _param_size, bool param_half_precision = false,                  \
      bool grad_half_precision = false, bool momentum_half_precision = false, \
      bool variance_half_precision = false, float loss_scale = -1);

class Adam_Optimizer {
 public:
  Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                 float eps = 1e-8, float weight_decay = 0,
                 bool adamw_mode = true)
      : _alpha(alpha),
        _betta1(betta1),
        _betta2(betta2),
        _eps(eps),
        _weight_decay(weight_decay),
        _betta1_t(1.0),
        _betta2_t(1.0),
        _step(0),
        _adamw_mode(adamw_mode) {}
  ~Adam_Optimizer() {}

  STEP(1)
  STEP(4)
  STEP(8)
  inline void IncrementStep(size_t step, float beta1, float beta2) {
    if (beta1 != _betta1 || beta2 != _betta2) {
      _step = step;
      _betta1 = beta1;
      _betta2 = beta2;
      _betta1_t = std::pow(_betta1, step);
      _betta2_t = std::pow(_betta2, step);
    } else {
      _step++;
      if (_step != step) {
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
        _step = step;
      } else {
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
      }
    }
  }
  inline void update_state(float lr, float epsilon, float weight_decay,
                           bool bias_correction) {
    _alpha = lr;
    _eps = epsilon;
    _weight_decay = weight_decay;

    _bias_correction1 = 1.0f;
    _bias_correction2 = 1.0f;
    if (bias_correction == 1) {
      _bias_correction1 = 1 - _betta1_t;
      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
    }
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
                        AVX_Data &data) {
    if (is_half) {
      data.data = SIMD_LOAD_HALF(h_ptr);
    } else {
      data.data = SIMD_LOAD(ptr);
    }
  }

  inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
                         AVX_Data &data) {
    if (is_half) {
      SIMD_STORE_HALF(h_ptr, data.data);
    } else {
      SIMD_STORE(ptr, data.data);
    }
  }
#endif

  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,
            torch::Tensor &grads, torch::Tensor &exp_avg,
            torch::Tensor &exp_avg_sq, float loss_scale);
|
||||
|
||||
private:
|
||||
float _alpha;
|
||||
float _betta1;
|
||||
float _betta2;
|
||||
float _eps;
|
||||
float _weight_decay;
|
||||
|
||||
float _betta1_t;
|
||||
float _betta2_t;
|
||||
size_t _step;
|
||||
|
||||
float _bias_correction1;
|
||||
float _bias_correction2;
|
||||
|
||||
bool _adamw_mode;
|
||||
};
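IncrementStep keeps _betta1_t and _betta2_t equal to beta1^t and beta2^t without calling pow on the common path, and update_state turns them into the bias corrections 1 - beta1^t and 1 / sqrt(1 - beta2^t). Below is a small, illustrative host-side usage of these two public methods; the concrete numbers are only a sanity check of that invariant, not part of the extension.

// Illustrative only: after five consecutive steps with beta1 = 0.9,
// _betta1_t == 0.9^5, so bias_correction1 == 1 - 0.9^5 ~= 0.40951.
static void increment_step_example() {
  Adam_Optimizer opt(/*alpha=*/1e-3f, /*betta1=*/0.9f, /*betta2=*/0.999f,
                     /*eps=*/1e-8f, /*weight_decay=*/0.f, /*adamw_mode=*/true);
  for (size_t t = 1; t <= 5; ++t) opt.IncrementStep(t, 0.9f, 0.999f);
  opt.update_state(/*lr=*/1e-3f, /*epsilon=*/1e-8f, /*weight_decay=*/0.f,
                   /*bias_correction=*/true);
}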
|
312
extensions/csrc/cuda/include/block_reduce.h
Normal file
@@ -0,0 +1,312 @@
|
||||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Tencent/TurboTransformers
|
||||
This block_reduce_n is adapted from Tencent/TurboTransformers
|
||||
*/
|
||||
#pragma once
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
enum class ReduceType { kMax = 0, kSum };
|
||||
const unsigned int WARP_REDUCE_MASK = 0xffffffff;
|
||||
const float REDUCE_FLOAT_INF_NEG = -100000000.f;
|
||||
const float REDUCE_FLOAT_INF_POS = 100000000.f;
|
||||
const unsigned int WARP_REDUCE_SIZE = 32;
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T warpReduceSum(T val) {
|
||||
for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T blockReduceSum(T val) {
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
if (lane == 0) shared[wid] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
|
||||
val = warpReduceSum<T>(val);
|
||||
return val;
|
||||
}
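// Illustrative usage sketch (not part of this header): one block sums one
// row of length `cols`; after blockReduceSum the threads of warp 0 hold the
// block-wide total and thread 0 writes it out. Assumes blockDim.x is a
// multiple of 32 and at most 1024.
template <typename T>
__global__ void row_sum_kernel(const T *in, T *out, int cols) {
  const T *row = in + (size_t)blockIdx.x * cols;
  T partial = (T)0;
  for (int i = threadIdx.x; i < cols; i += blockDim.x) partial += row[i];
  partial = blockReduceSum<T>(partial);
  if (threadIdx.x == 0) out[blockIdx.x] = partial;
}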
|
||||
|
||||
template <ReduceType Rtype, int Num>
|
||||
__inline__ __device__ void blockReduce(float *pval);
|
||||
|
||||
// use template to make code more concise
|
||||
template <ReduceType Rtype, int Num>
|
||||
__inline__ __device__ void warpReduce(float *pval);
|
||||
|
||||
// static
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32));
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32));
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32));
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32));
|
||||
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32));
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
|
||||
float val0_tmp, val1_tmp;
|
||||
#define WarpReduceMaxOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
*(pval) = max(val0_tmp, *(pval)); \
|
||||
*(pval + 1) = max(val1_tmp, *(pval + 1));
|
||||
|
||||
WarpReduceMaxOneStep(16, 32);
|
||||
WarpReduceMaxOneStep(8, 32);
|
||||
WarpReduceMaxOneStep(4, 32);
|
||||
WarpReduceMaxOneStep(2, 32);
|
||||
WarpReduceMaxOneStep(1, 32);
|
||||
#undef WarpReduceMaxOneStep
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32);
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32);
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32);
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32);
|
||||
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32);
|
||||
}
|
||||
|
||||
/*
|
||||
* Unroll the for loop of warpReduce to
|
||||
* improve instruction issue efficiency
|
||||
* ElemX means there are X numbers to be summed
|
||||
*/
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
|
||||
float val0_tmp, val1_tmp;
|
||||
#define WarpReduceSumOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
*(pval + 0) += val0_tmp; \
|
||||
*(pval + 1) += val1_tmp
|
||||
|
||||
WarpReduceSumOneStep(16, 32);
|
||||
WarpReduceSumOneStep(8, 32);
|
||||
WarpReduceSumOneStep(4, 32);
|
||||
WarpReduceSumOneStep(2, 32);
|
||||
WarpReduceSumOneStep(1, 32);
|
||||
|
||||
#undef WarpReduceSumOneStep
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
|
||||
float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
|
||||
#define WarpReduceSumOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
|
||||
val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
|
||||
*(pval + 0) += val0_tmp; \
|
||||
*(pval + 1) += val1_tmp; \
|
||||
*(pval + 2) += val2_tmp; \
|
||||
*(pval + 3) += val3_tmp
|
||||
|
||||
WarpReduceSumOneStep(16, 32);
|
||||
WarpReduceSumOneStep(8, 32);
|
||||
WarpReduceSumOneStep(4, 32);
|
||||
WarpReduceSumOneStep(2, 32);
|
||||
WarpReduceSumOneStep(1, 32);
|
||||
#undef WarpReduceSumOneStep
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kSum, 1>(float *pval) {
|
||||
const int num = 1;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = 0.f;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kSum, 2>(float *pval) {
|
||||
const int num = 2;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = 0.f;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kSum, 4>(float *pval) {
|
||||
const int num = 4;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = 0.f;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kSum, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kMax, 1>(float *pval) {
|
||||
const int num = 1;
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = REDUCE_FLOAT_INF_NEG;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kMax, 2>(float *pval) {
|
||||
const int num = 2;  // use the Num = 2 warpReduce specialization so both values are reduced
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = REDUCE_FLOAT_INF_NEG;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
|
||||
const int num = 1;  // note: only pval[0] is reduced block-wide here, since no warpReduce<kMax, 4> specialization is defined in this header
|
||||
static __shared__ float shared[num][32];
|
||||
int lane_id = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
|
||||
if (lane_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
shared[i][wid] = *(pval + i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < (blockDim.x >> 5)) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = shared[i][lane_id];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num; ++i) {
|
||||
*(pval + i) = REDUCE_FLOAT_INF_NEG;
|
||||
}
|
||||
}
|
||||
warpReduce<ReduceType::kMax, num>(pval);
|
||||
}
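The blockReduce<kSum, N> specializations above expect each thread to pass a local array of N partial values; after the call the threads of the first warp hold the block-wide sums (the MoE kernels later in this commit use exactly this pattern for their weight gradients). A minimal illustrative kernel for the N = 2 variant, assuming a single-block launch with blockDim.x a multiple of 32; the kernel name and shapes are ours.

__global__ void two_dot_products(const float *a, const float *b,
                                 const float *c, float *out, int n) {
  float acc[2] = {0.f, 0.f};
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    acc[0] += a[i] * b[i];  // partial <a, b>
    acc[1] += a[i] * c[i];  // partial <a, c>
  }
  blockReduce<ReduceType::kSum, 2>(acc);
  if (threadIdx.x == 0) {
    out[0] = acc[0];
    out[1] = acc[1];
  }
}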
|
141
extensions/csrc/cuda/layer_norm_cuda.cpp
Normal file
@@ -0,0 +1,141 @@
|
||||
/*This code from NVIDIA apex:
|
||||
* https://github.com/NVIDIA/apex
|
||||
* with minor changes. */
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "compat.h"
|
||||
|
||||
namespace {
|
||||
|
||||
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
|
||||
int &n2) {
|
||||
int idiff = input.ndimension() - normalized_shape.size();
|
||||
n2 = 1;
|
||||
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
|
||||
assert(input.sizes()[i + idiff] == normalized_shape[i]);
|
||||
n2 *= normalized_shape[i];
|
||||
}
|
||||
n1 = 1;
|
||||
for (int i = 0; i < idiff; ++i) {
|
||||
n1 *= input.sizes()[i];
|
||||
}
|
||||
}
|
||||
|
||||
void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
|
||||
at::Tensor beta) {
|
||||
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
|
||||
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
|
||||
}
|
||||
|
||||
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
|
||||
int &n2) {
|
||||
int64_t normalized_ndim = normalized_shape.size();
|
||||
|
||||
if (normalized_ndim < 1) {
|
||||
std::stringstream ss;
|
||||
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
|
||||
<< "containing at least one element, but got normalized_shape="
|
||||
<< normalized_shape;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
auto input_shape = input.sizes();
|
||||
auto input_ndim = input.dim();
|
||||
|
||||
if (input_ndim < normalized_ndim ||
|
||||
!input_shape.slice(input_ndim - normalized_ndim)
|
||||
.equals(normalized_shape)) {
|
||||
std::stringstream ss;
|
||||
ss << "Given normalized_shape=" << normalized_shape
|
||||
<< ", expected input with shape [*";
|
||||
for (auto size : normalized_shape) {
|
||||
ss << ", " << size;
|
||||
}
|
||||
ss << "], but got input of size" << input_shape;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
compute_n1_n2(input, normalized_shape, n1, n2);
|
||||
}
|
||||
|
||||
void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
|
||||
at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
|
||||
check_args(input, normalized_shape, n1, n2);
|
||||
check_args(normalized_shape, gamma, beta);
|
||||
}
|
||||
} // namespace
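// Illustrative only (not part of the extension): what compute_n1_n2 yields
// for a typical transformer activation, assuming input shape
// (B, S, H) = (8, 1024, 768) and normalized_shape = (768): idiff = 3 - 1 = 2,
// n2 = 768 (product over normalized_shape) and n1 = 8 * 1024 = 8192 (product
// over the leading dims), so layer_norm_affine returns mean and invvar of
// shape (8192,).
static void example_n1_n2() {
  const int sizes[3] = {8, 1024, 768};  // input.sizes()
  const int normalized[1] = {768};      // normalized_shape
  const int idiff = 3 - 1;
  int n2 = 1;
  for (int i = 0; i < 1; ++i) n2 *= normalized[i];  // n2 = 768
  int n1 = 1;
  for (int i = 0; i < idiff; ++i) n1 *= sizes[i];   // n1 = 8192
  (void)n1;
  (void)n2;
}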
|
||||
|
||||
void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
|
||||
at::Tensor *input, int n1, int n2,
|
||||
at::IntArrayRef normalized_shape, at::Tensor *gamma,
|
||||
at::Tensor *beta, double epsilon);
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
|
||||
at::IntArrayRef normalized_shape,
|
||||
at::Tensor gamma, at::Tensor beta,
|
||||
double epsilon) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(gamma);
|
||||
CHECK_INPUT(beta);
|
||||
int n1, n2;
|
||||
check_args(input, normalized_shape, gamma, beta, n1, n2);
|
||||
|
||||
at::Tensor output =
|
||||
at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
|
||||
at::Tensor mean =
|
||||
at::empty({n1}, input.options().dtype(at::ScalarType::Float));
|
||||
at::Tensor invvar = at::empty_like(mean);
|
||||
|
||||
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
|
||||
&gamma, &beta, epsilon);
|
||||
|
||||
return {output, mean, invvar};
|
||||
}
|
||||
|
||||
void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
|
||||
at::Tensor *invvar, at::Tensor *input, int n1,
|
||||
int n2, at::IntArrayRef normalized_shape,
|
||||
at::Tensor *gamma, at::Tensor *beta,
|
||||
double epsilon, at::Tensor *grad_input,
|
||||
at::Tensor *grad_gamma, at::Tensor *grad_beta);
|
||||
|
||||
std::vector<at::Tensor> layer_norm_gradient_affine(
|
||||
at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
|
||||
at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
|
||||
double epsilon) {
|
||||
CHECK_INPUT(dout);
|
||||
CHECK_INPUT(mean);
|
||||
CHECK_INPUT(invvar);
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(gamma);
|
||||
CHECK_INPUT(beta);
|
||||
int n1, n2;
|
||||
check_args(input, normalized_shape, gamma, beta, n1, n2);
|
||||
|
||||
at::Tensor grad_input = at::empty_like(input);
|
||||
at::Tensor grad_gamma = at::empty_like(gamma);
|
||||
at::Tensor grad_beta = at::empty_like(beta);
|
||||
|
||||
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
|
||||
normalized_shape, &gamma, &beta, epsilon,
|
||||
&grad_input, &grad_gamma, &grad_beta);
|
||||
|
||||
return {grad_input, grad_gamma, grad_beta};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
|
||||
m.def("backward_affine", &layer_norm_gradient_affine,
|
||||
"LayerNorm backward (CUDA)");
|
||||
}
|
683
extensions/csrc/cuda/layer_norm_cuda_kernel.cu
Normal file
@@ -0,0 +1,683 @@
|
||||
/*This code from NVIDIA apex:
|
||||
* https://github.com/NVIDIA/apex
|
||||
* with minor changes. */
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "ATen/ATen.h"
|
||||
#include "ATen/AccumulateType.h"
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include "ATen/cuda/DeviceUtils.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
template <typename U>
|
||||
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
|
||||
count = count + U(1);
|
||||
U delta = curr - mu;
|
||||
U lmean = mu + delta / count;
|
||||
mu = lmean;
|
||||
U delta2 = curr - lmean;
|
||||
sigma2 = sigma2 + delta * delta2;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
|
||||
U& mu, U& sigma2, U& count) {
|
||||
U delta = muB - mu;
|
||||
U nA = count;
|
||||
U nB = countB;
|
||||
count = count + countB;
|
||||
U nX = count;
|
||||
if (nX > U(0)) {
|
||||
nA = nA / nX;
|
||||
nB = nB / nX;
|
||||
mu = nA * mu + nB * muB;
|
||||
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
|
||||
} else {
|
||||
mu = U(0);
|
||||
sigma2 = U(0);
|
||||
}
|
||||
}
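// Illustrative host-side check (not part of the kernel) of the Welford
// update performed by cuWelfordOnlineSum above: running it over {1, 2, 3, 4}
// gives mu = 2.5 and sigma2 / count = 1.25, the biased variance that
// cuWelfordMuSigma2 later obtains by dividing the accumulated sigma2 by n2.
#include <cassert>
#include <cmath>
static void welford_example() {
  float mu = 0.f, sigma2 = 0.f, count = 0.f;
  const float xs[4] = {1.f, 2.f, 3.f, 4.f};
  for (float x : xs) {
    count += 1.f;
    float delta = x - mu;
    mu += delta / count;         // running mean
    sigma2 += delta * (x - mu);  // running sum of squared deviations
  }
  assert(std::fabs(mu - 2.5f) < 1e-6f);
  assert(std::fabs(sigma2 / count - 1.25f) < 1e-6f);
}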
|
||||
|
||||
template <typename T, typename U>
|
||||
__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1,
|
||||
const int n2, const int i1, U& mu, U& sigma2,
|
||||
U* buf) {
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensor is contiguous
|
||||
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
|
||||
//
|
||||
// compute variance and mean over n2
|
||||
U count = U(0);
|
||||
mu = U(0);
|
||||
sigma2 = U(0);
|
||||
if (i1 < n1) {
|
||||
// one warp normalizes one n1 index,
|
||||
// synchronization is implicit
|
||||
// initialize with standard Welford algorithm
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const T* lvals = vals + i1 * n2;
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
U curr = static_cast<U>(lvals[l + k]);
|
||||
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
U curr = static_cast<U>(lvals[l]);
|
||||
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int l = 0; l <= 4; ++l) {
|
||||
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
|
||||
U muB = WARP_SHFL(mu, srcLaneB);
|
||||
U countB = WARP_SHFL(count, srcLaneB);
|
||||
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
|
||||
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
|
||||
}
|
||||
// threadIdx.x == 0 has correct values for each warp
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
U* ubuf = (U*)buf;
|
||||
U* ibuf = (U*)(ubuf + blockDim.y);
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.x == 0 && threadIdx.y >= offset &&
|
||||
threadIdx.y < 2 * offset) {
|
||||
const int wrt_y = threadIdx.y - offset;
|
||||
ubuf[2 * wrt_y] = mu;
|
||||
ubuf[2 * wrt_y + 1] = sigma2;
|
||||
ibuf[wrt_y] = count;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.x == 0 && threadIdx.y < offset) {
|
||||
U muB = ubuf[2 * threadIdx.y];
|
||||
U sigma2B = ubuf[2 * threadIdx.y + 1];
|
||||
U countB = ibuf[threadIdx.y];
|
||||
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has the correct values
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
ubuf[0] = mu;
|
||||
ubuf[1] = sigma2;
|
||||
}
|
||||
__syncthreads();
|
||||
mu = ubuf[0];
|
||||
sigma2 = ubuf[1] / U(n2);
|
||||
// don't care about final value of count, we know count == n2
|
||||
} else {
|
||||
mu = WARP_SHFL(mu, 0);
|
||||
sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals,
|
||||
const int n1, const int n2, const int i1,
|
||||
float& mu, float& sigma2, float* buf) {
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensor is contiguous
|
||||
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
|
||||
//
|
||||
// compute variance and mean over n2
|
||||
float count = 0.0f;
|
||||
mu = float(0);
|
||||
sigma2 = float(0);
|
||||
if (i1 < n1) {
|
||||
// one warp normalizes one n1 index,
|
||||
// synchronization is implicit
|
||||
// initialize with standard Welford algorithm
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
const at::Half* lvals = vals + i1 * n2;
|
||||
int l = 8 * thrx;
|
||||
if ((((size_t)lvals) & 3) != 0) {
|
||||
// 16 bit alignment
|
||||
// first thread consumes first point
|
||||
if (thrx == 0) {
|
||||
float curr = static_cast<float>(lvals[0]);
|
||||
cuWelfordOnlineSum(curr, mu, sigma2, count);
|
||||
}
|
||||
++l;
|
||||
}
|
||||
// at this point, lvals[l] are 32 bit aligned for all threads.
|
||||
for (; l + 7 < n2; l += 8 * numx) {
|
||||
for (int k = 0; k < 8; k += 2) {
|
||||
float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
|
||||
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
|
||||
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
float curr = static_cast<float>(lvals[l]);
|
||||
cuWelfordOnlineSum(curr, mu, sigma2, count);
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int l = 0; l <= 4; ++l) {
|
||||
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
|
||||
float muB = WARP_SHFL(mu, srcLaneB);
|
||||
float countB = WARP_SHFL(count, srcLaneB);
|
||||
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
|
||||
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
|
||||
}
|
||||
// threadIdx.x == 0 has correct values for each warp
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
float* ubuf = (float*)buf;
|
||||
float* ibuf = (float*)(ubuf + blockDim.y);
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.x == 0 && threadIdx.y >= offset &&
|
||||
threadIdx.y < 2 * offset) {
|
||||
const int wrt_y = threadIdx.y - offset;
|
||||
ubuf[2 * wrt_y] = mu;
|
||||
ubuf[2 * wrt_y + 1] = sigma2;
|
||||
ibuf[wrt_y] = count;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.x == 0 && threadIdx.y < offset) {
|
||||
float muB = ubuf[2 * threadIdx.y];
|
||||
float sigma2B = ubuf[2 * threadIdx.y + 1];
|
||||
float countB = ibuf[threadIdx.y];
|
||||
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has the correct values
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
ubuf[0] = mu;
|
||||
ubuf[1] = sigma2;
|
||||
}
|
||||
__syncthreads();
|
||||
mu = ubuf[0];
|
||||
sigma2 = ubuf[1] / float(n2);
|
||||
// don't care about final value of count, we know count == n2
|
||||
} else {
|
||||
mu = WARP_SHFL(mu, 0);
|
||||
sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
U rsqrt(U v) {
|
||||
return U(1) / sqrt(v);
|
||||
}
|
||||
template <>
|
||||
float rsqrt(float v) {
|
||||
return rsqrtf(v);
|
||||
}
|
||||
template <>
|
||||
double rsqrt(double v) {
|
||||
return rsqrt(v);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// This is the un-specialized struct. Note that we prevent instantiation of
|
||||
// this struct by putting an undefined symbol in the function body so it won't
|
||||
// compile.
|
||||
// template <typename T>
|
||||
// struct SharedMemory
|
||||
// {
|
||||
// // Ensure that we won't compile any un-specialized types
|
||||
// __device__ T *getPointer()
|
||||
// {
|
||||
// extern __device__ void error(void);
|
||||
// error();
|
||||
// return NULL;
|
||||
// }
|
||||
// };
|
||||
// https://github.com/NVIDIA/apex/issues/246
|
||||
template <typename T>
|
||||
struct SharedMemory;
|
||||
|
||||
template <>
|
||||
struct SharedMemory<float> {
|
||||
__device__ float* getPointer() {
|
||||
extern __shared__ float s_float[];
|
||||
return s_float;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__global__ void cuApplyLayerNorm(V* __restrict__ output_vals,
|
||||
U* __restrict__ mean, U* __restrict__ invvar,
|
||||
const T* __restrict__ vals, const int n1,
|
||||
const int n2, const U epsilon,
|
||||
const V* __restrict__ gamma,
|
||||
const V* __restrict__ beta) {
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == warpSize
|
||||
// 2) Tensors are contiguous
|
||||
//
|
||||
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
U mu, sigma2;
|
||||
cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
|
||||
const T* lvals = vals + i1 * n2;
|
||||
V* ovals = output_vals + i1 * n2;
|
||||
U c_invvar = rsqrt(sigma2 + epsilon);
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
if (gamma != NULL && beta != NULL) {
|
||||
for (int i = thrx; i < n2; i += numx) {
|
||||
U curr = static_cast<U>(lvals[i]);
|
||||
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = thrx; i < n2; i += numx) {
|
||||
U curr = static_cast<U>(lvals[i]);
|
||||
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
|
||||
}
|
||||
}
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
mean[i1] = mu;
|
||||
invvar[i1] = c_invvar;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__device__ void cuLoadWriteStridedInputs(
|
||||
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
|
||||
const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,
|
||||
const T* input, const V* dout, const int i1_end, const int n2,
|
||||
const U* __restrict__ mean, const U* __restrict__ invvar) {
|
||||
int i1 = i1_block + thr_load_row_off;
|
||||
if (i1 < i1_end) {
|
||||
U curr_mean = mean[i1];
|
||||
U curr_invvar = invvar[i1];
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int i2 = i2_off + k;
|
||||
int load_idx = i1 * n2 + i2;
|
||||
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
|
||||
if (i2 < n2) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
U curr_dout = static_cast<U>(dout[load_idx]);
|
||||
warp_buf1[write_idx] = curr_dout;
|
||||
warp_buf2[write_idx] =
|
||||
curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
} else {
|
||||
warp_buf1[write_idx] = U(0);
|
||||
warp_buf2[write_idx] = U(0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
|
||||
warp_buf1[write_idx] = U(0);
|
||||
warp_buf2[write_idx] = U(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__device__ void cuLoadAddStridedInputs(
|
||||
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
|
||||
const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,
|
||||
const T* input, const V* dout, const int i1_end, const int n2,
|
||||
const U* __restrict__ mean, const U* __restrict__ invvar) {
|
||||
int i1 = i1_block + thr_load_row_off;
|
||||
if (i1 < i1_end) {
|
||||
U curr_mean = mean[i1];
|
||||
U curr_invvar = invvar[i1];
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int i2 = i2_off + k;
|
||||
int load_idx = i1 * n2 + i2;
|
||||
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
|
||||
if (i2 < n2) {
|
||||
U curr_input = static_cast<U>(input[load_idx]);
|
||||
U curr_dout = static_cast<U>(dout[load_idx]);
|
||||
warp_buf1[write_idx] += curr_dout;
|
||||
warp_buf2[write_idx] +=
|
||||
curr_dout * (curr_input - curr_mean) * curr_invvar;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__global__ void cuComputePartGradGammaBeta(
|
||||
const V* __restrict__ dout, const T* __restrict__ input, const int n1,
|
||||
const int n2, const U* __restrict__ mean, const U* __restrict__ invvar,
|
||||
U epsilon, U* part_grad_gamma, U* part_grad_beta) {
|
||||
const int numsegs_n1 =
|
||||
(n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
|
||||
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
|
||||
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
|
||||
const int i1_beg_plus_one =
|
||||
(blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
|
||||
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
|
||||
const int row_stride = blockDim.x + 1;
|
||||
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
|
||||
const int thr_load_row_off =
|
||||
(threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
|
||||
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y *
|
||||
// blockDim.y + (blockDim.y -
|
||||
// 1)*(blockDim.x/blockDim.y) elements
|
||||
U* warp_buf1 = (U*)buf;
|
||||
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
|
||||
// compute partial sums from strided inputs
|
||||
// do this to increase number of loads in flight
|
||||
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
|
||||
row_stride, warp_buf1, warp_buf2, input, dout,
|
||||
i1_end, n2, mean, invvar);
|
||||
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
|
||||
i1_block += blockDim.y * blockDim.y) {
|
||||
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
|
||||
row_stride, warp_buf1, warp_buf2, input, dout,
|
||||
i1_end, n2, mean, invvar);
|
||||
}
|
||||
__syncthreads();
|
||||
// inter-warp reductions
|
||||
// sum within each warp
|
||||
U acc1 = U(0);
|
||||
U acc2 = U(0);
|
||||
for (int k = 0; k < blockDim.y; ++k) {
|
||||
int row1 = threadIdx.y + k * blockDim.y;
|
||||
int idx1 = row1 * row_stride + threadIdx.x;
|
||||
acc1 += warp_buf1[idx1];
|
||||
acc2 += warp_buf2[idx1];
|
||||
}
|
||||
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
|
||||
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
|
||||
__syncthreads();
|
||||
// sum all warps
|
||||
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
|
||||
if (threadIdx.y < offset) {
|
||||
int row1 = threadIdx.y;
|
||||
int row2 = threadIdx.y + offset;
|
||||
int idx1 = row1 * row_stride + threadIdx.x;
|
||||
int idx2 = row2 * row_stride + threadIdx.x;
|
||||
warp_buf1[idx1] += warp_buf1[idx2];
|
||||
warp_buf2[idx1] += warp_buf2[idx2];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (threadIdx.y == 0 && i2 < n2) {
|
||||
int row1 = threadIdx.y;
|
||||
int row2 = threadIdx.y + 1;
|
||||
int idx1 = row1 * row_stride + threadIdx.x;
|
||||
int idx2 = row2 * row_stride + threadIdx.x;
|
||||
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
|
||||
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V>
|
||||
__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
|
||||
const U* part_grad_beta,
|
||||
const int part_size, const int n1,
|
||||
const int n2, V* grad_gamma,
|
||||
V* grad_beta) {
|
||||
// sum partial gradients for gamma and beta
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i2 < n2) {
|
||||
// each warp does sequential reductions until reduced part_size is num_warps
|
||||
int num_warp_reductions = part_size / blockDim.y;
|
||||
U sum_gamma = U(0);
|
||||
U sum_beta = U(0);
|
||||
const U* part_grad_gamma_ptr =
|
||||
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
|
||||
const U* part_grad_beta_ptr =
|
||||
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
|
||||
for (int warp_offset = 0; warp_offset < num_warp_reductions;
|
||||
++warp_offset) {
|
||||
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
|
||||
sum_beta += part_grad_beta_ptr[warp_offset * n2];
|
||||
}
|
||||
// inter-warp reductions
|
||||
const int nbsize3 = blockDim.x * blockDim.y / 2;
|
||||
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
|
||||
// top half write to shared memory
|
||||
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
|
||||
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
|
||||
buf[write_idx] = sum_gamma;
|
||||
buf[write_idx + nbsize3] = sum_beta;
|
||||
}
|
||||
__syncthreads();
|
||||
// bottom half sums
|
||||
if (threadIdx.y < offset) {
|
||||
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
sum_gamma += buf[read_idx];
|
||||
sum_beta += buf[read_idx + nbsize3];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// write out fully summed gradients
|
||||
if (threadIdx.y == 0) {
|
||||
grad_gamma[i2] = sum_gamma;
|
||||
grad_beta[i2] = sum_beta;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__global__ void cuComputeGradInput(const V* __restrict__ dout,
|
||||
const T* __restrict__ input, const int n1,
|
||||
const int n2, const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar, U epsilon,
|
||||
const V* gamma, T* grad_input) {
|
||||
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
U sum_loss1 = U(0);
|
||||
U sum_loss2 = U(0);
|
||||
const U c_mean = mean[i1];
|
||||
const U c_invvar = invvar[i1];
|
||||
const T* k_input = input + i1 * n2;
|
||||
const V* k_dout = dout + i1 * n2;
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
if (gamma != NULL) {
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l + k]);
|
||||
sum_loss1 += c_loss * gamma[l + k];
|
||||
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
sum_loss1 += c_loss * gamma[l];
|
||||
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
} else {
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l + k]);
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
|
||||
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
|
||||
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
|
||||
}
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
|
||||
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
|
||||
buf[2 * wrt_i] = sum_loss1;
|
||||
buf[2 * wrt_i + 1] = sum_loss2;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.y < offset) {
|
||||
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
sum_loss1 += buf[2 * read_i];
|
||||
sum_loss2 += buf[2 * read_i + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (threadIdx.y == 0) {
|
||||
buf[2 * threadIdx.x] = sum_loss1;
|
||||
buf[2 * threadIdx.x + 1] = sum_loss2;
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.y != 0) {
|
||||
sum_loss1 = buf[2 * threadIdx.x];
|
||||
sum_loss2 = buf[2 * threadIdx.x + 1];
|
||||
}
|
||||
}
|
||||
// all threads now have the two sums over l
|
||||
U fH = (U)n2;
|
||||
U term1 = (U(1) / fH) * c_invvar;
|
||||
T* k_grad_input = grad_input + i1 * n2;
|
||||
if (gamma != NULL) {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss * gamma[l];
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
} else {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss;
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1,
|
||||
int n2, double epsilon, const V* gamma, const V* beta) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const dim3 threads(32, 4, 1);
|
||||
const uint64_t maxGridY =
|
||||
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
|
||||
int nshared =
|
||||
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
|
||||
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
|
||||
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
|
||||
}
|
||||
|
||||
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
|
||||
at::Tensor* input, int n1, int n2,
|
||||
#ifdef VERSION_GE_1_1
|
||||
at::IntArrayRef normalized_shape,
|
||||
#else
|
||||
at::IntList normalized_shape,
|
||||
#endif
|
||||
at::Tensor* gamma, at::Tensor* beta, double epsilon) {
|
||||
using namespace at;
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
|
||||
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
|
||||
HostApplyLayerNorm(output->DATA_PTR<scalar_t_out>(),
|
||||
mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(),
|
||||
input->DATA_PTR<scalar_t_in>(), n1, n2, epsilon,
|
||||
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);)
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar,
|
||||
at::Tensor* input, int n1, int n2, const V* gamma,
|
||||
const V* beta, double epsilon, T* grad_input,
|
||||
V* grad_gamma, V* grad_beta) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
if (gamma != NULL && beta != NULL) {
|
||||
// compute grad_gamma(j) and grad_beta(j)
|
||||
const int part_size = 16;
|
||||
const dim3 threads2(32, 4, 1);
|
||||
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
|
||||
const int nshared2_a =
|
||||
2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
|
||||
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
|
||||
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
|
||||
at::Tensor part_grad_gamma = at::empty(
|
||||
{part_size, n2}, input->options().dtype(at::ScalarType::Float));
|
||||
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
|
||||
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
|
||||
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon),
|
||||
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>());
|
||||
|
||||
const dim3 threads3(32, 8, 1);
|
||||
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
|
||||
const int nshared3 = threads3.x * threads3.y * sizeof(U);
|
||||
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
|
||||
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size,
|
||||
n1, n2, grad_gamma, grad_beta);
|
||||
}
|
||||
|
||||
// compute grad_input
|
||||
const uint64_t maxGridY =
|
||||
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
|
||||
const dim3 threads1(32, 4, 1);
|
||||
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
|
||||
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
|
||||
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon), gamma,
|
||||
grad_input);
|
||||
}
|
||||
|
||||
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean,
|
||||
at::Tensor* invvar, at::Tensor* input, int n1,
|
||||
int n2,
|
||||
#ifdef VERSION_GE_1_1
|
||||
at::IntArrayRef normalized_shape,
|
||||
#else
|
||||
at::IntList normalized_shape,
|
||||
#endif
|
||||
at::Tensor* gamma, at::Tensor* beta,
|
||||
double epsilon, at::Tensor* grad_input,
|
||||
at::Tensor* grad_gamma, at::Tensor* grad_beta) {
|
||||
using namespace at;
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
|
||||
input->scalar_type(), gamma->scalar_type(),
|
||||
"cuda_layer_norm_gradient_kernel",
|
||||
HostLayerNormGradient(
|
||||
dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(),
|
||||
invvar->DATA_PTR<float>(), input, n1, n2,
|
||||
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
|
||||
// if gamma Tensor is NULL on input.
|
||||
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon,
|
||||
grad_input->DATA_PTR<scalar_t_in>(),
|
||||
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);)
|
||||
}
|
97
extensions/csrc/cuda/moe_cuda.cpp
Normal file
@@ -0,0 +1,97 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
|
||||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
||||
int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask, torch::Tensor dest_idx) {
|
||||
CHECK_INPUT(batch_tokens);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
||||
return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx);
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_backward(int s, int ec, int h,
|
||||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
CHECK_INPUT(expert_grad);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
||||
return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx);
|
||||
}
|
||||
|
||||
torch::Tensor moe_combine_forward(int s, int e, int c, int h,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
CHECK_INPUT(expert_tokens);
|
||||
CHECK_INPUT(logits);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
||||
return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask,
|
||||
dest_idx);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
|
||||
torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
CHECK_INPUT(tokens_grad);
|
||||
CHECK_INPUT(logits);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
||||
return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens,
|
||||
logits, mask, dest_idx);
|
||||
}
|
||||
|
||||
torch::Tensor moe_cumsum(torch::Tensor mask) {
|
||||
CHECK_INPUT(mask);
|
||||
return cumsum_sub_one_in_dim0(mask);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0");
|
||||
m.def("dispatch_forward", &moe_dispatch_forward,
|
||||
"Forward operation in MoE dispatch function");
|
||||
m.def("dispatch_backward", &moe_dispatch_backward,
|
||||
"Backward operation in MoE dispatch function");
|
||||
m.def("combine_forward", &moe_combine_forward,
|
||||
"Combine operation in MoE combine function");
|
||||
m.def("combine_backward", &moe_combine_backward,
|
||||
"Combine operation in MoE combine function");
|
||||
}
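Taken together, dispatch_forward scatters each routed token row into the expert buffer addressed by dest_idx, and combine_forward gathers it back weighted by the gating logits. Below is a CPU reference of the top-1 dispatch semantics as we read the CUDA kernel in moe_cuda_kernel.cu; the flattened shapes, zero-initialization of the expert buffer, and the function name are our assumptions, and the real op also accepts a second routing mask for top-2.

#include <vector>

// Illustrative reference only -- mirrors moe_dpch_fwd_kernel's copy of
// batch_tokens[i, :] into expert_input[dest_idx[i], :] whenever mask[i] != 0.
std::vector<float> dispatch_forward_reference(
    const std::vector<float> &batch_tokens,  // (s, h), row-major
    const std::vector<int> &mask,            // (s,)
    const std::vector<int> &dest_idx,        // (s,)
    int s, int ec, int h) {
  std::vector<float> expert_input((size_t)ec * h, 0.f);
  for (int i = 0; i < s; ++i) {
    if (mask[i] == 0) continue;  // token not routed
    for (int j = 0; j < h; ++j)
      expert_input[(size_t)dest_idx[i] * h + j] =
          batch_tokens[(size_t)i * h + j];
  }
  return expert_input;
}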
|
659
extensions/csrc/cuda/moe_cuda_kernel.cu
Normal file
@@ -0,0 +1,659 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include "block_reduce.h"
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, pack);
|
||||
BlockStore(ts_store).Store(src_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row1 + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row2 + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack1[pack_size], pack2[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
|
||||
BlockLoad(ts_load).Load(dst_row2 + idx, pack2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack1[i] += pack2[i];
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row + idx, pack1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
|
||||
const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack[i] *= weight;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
|
||||
T *weight_grad, const T weight, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T grad[pack_size], tokens[pack_size];
|
||||
float thread_sum = 0;
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, grad);
|
||||
BlockLoad(ts_load).Load(tks_row + idx, tokens);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
thread_sum += grad[i] * tokens[i];
|
||||
grad[i] *= weight;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row + idx, grad);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 1>(&thread_sum);
|
||||
|
||||
if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
|
||||
const T weight1, const T weight2,
|
||||
const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack1[pack_size], pack2[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row1 + idx, pack1);
|
||||
BlockLoad(ts_load).Load(src_row2 + idx, pack2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
|
||||
T *tks_row1, T *tks_row2, T *weight_grad1,
|
||||
T *weight_grad2, const T weight1,
|
||||
const T weight2, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size],
|
||||
sgrad2[pack_size];
|
||||
float thread_sum[2] = {0, 0};
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, grad);
|
||||
BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
|
||||
BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
thread_sum[0] += grad[i] * tokens1[i];
|
||||
thread_sum[1] += grad[i] * tokens2[i];
|
||||
sgrad1[i] = weight1 * grad[i];
|
||||
sgrad2[i] = weight2 * grad[i];
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row1 + idx, sgrad1);
|
||||
BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 2>(thread_sum);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
*weight_grad1 = static_cast<T>(thread_sum[0]);
|
||||
else if (threadIdx.x == 1)
|
||||
*weight_grad2 = static_cast<T>(thread_sum[1]);
|
||||
}
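// Note (restating the code above): the two weight gradients come from a
// two-lane blockReduce over the thread-local partial dot products; thread 0
// then writes *weight_grad1 and thread 1 writes *weight_grad2.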
|
||||
|
||||
// DISPATCH KERNELS --------------------------------
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols, const int indicator1,
|
||||
const int indicator2) {
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
||||
cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols, const int indicator1,
|
||||
const int indicator2) {
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
||||
cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
|
||||
int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int h) {
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_fwd_selector<T, block_size, pack_size>(
|
||||
batch_tokens + (row * h), expert_input + (dest1[row] * h),
|
||||
expert_input + (dest2[row] * h), h, mask1[row], indicator2);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2,
|
||||
const int h) {
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_bwd_selector<T, block_size, pack_size>(
|
||||
tokens_grad + (row * h), expert_grad + (dest1[row] * h),
|
||||
expert_grad + (dest2[row] * h), h, mask1[row], indicator2);
|
||||
}
|
||||
|
||||
// COMBINE KERNELS --------------------------------
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
||||
const int cols, const T weight1,
|
||||
const T weight2, const int indicator1,
|
||||
const int indicator2) {
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
||||
weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(src_row1, dst_row, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(src_row2, dst_row, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
||||
const int cols, T *tks_row1, T *tks_row2,
|
||||
T *wt_grad1, T *wt_grad2, const T weight1,
|
||||
const T weight2, const int indicator1,
|
||||
const int indicator2) {
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
||||
tks_row1, tks_row2, wt_grad1,
|
||||
wt_grad2, weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(src_row1, dst_row, tks_row1,
|
||||
wt_grad1, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(src_row2, dst_row, tks_row2,
|
||||
wt_grad2, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
|
||||
T *logits, int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int e, const int c,
|
||||
const int h) {
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e);
|
||||
moe_cb_fwd_selector<T, block_size, pack_size>(
|
||||
expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
|
||||
combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row],
|
||||
indicator2);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
|
||||
T *logits, T *logits_grad, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2,
|
||||
const int e, const int c, const int h) {
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
|
||||
moe_cb_bwd_selector<T, block_size, pack_size>(
|
||||
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
|
||||
tokens_grad + (row * h), h, tks + (dest1[row] * h),
|
||||
tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1],
|
||||
row_log[eid2], mask1[row], indicator2);
|
||||
}
|
||||
|
||||
// CUMSUM KERNEL --------------------------------
|
||||
|
||||
template <int block_size, int pack_size>
|
||||
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
|
||||
const int e) {
|
||||
assert(s % pack_size == 0);
|
||||
constexpr int bpack_size = block_size * pack_size;
|
||||
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
|
||||
__shared__ int temp[block_size + 1];
|
||||
int pack[pack_size];
|
||||
|
||||
for (int idx = 0; idx < s; idx += bpack_size) {
|
||||
int offset = 1;
|
||||
|
||||
if (idx + tps < s) {
|
||||
temp[tid] = inputs[tps * e + bid];
|
||||
#pragma unroll
|
||||
for (int i = 1; i < pack_size; ++i) {
|
||||
pack[i] = inputs[(tps + i) * e + bid];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 1; i < pack_size; ++i) {
|
||||
temp[tid] += pack[i];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = block_size >> 1; i > 0; i >>= 1) {
|
||||
__syncthreads();
|
||||
if (tid < i) {
|
||||
int j = offset * (2 * tid + 1) - 1;
|
||||
temp[j + offset] += temp[j];
|
||||
}
|
||||
offset <<= 1;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
temp[block_size] = temp[block_size - 1];
|
||||
temp[block_size - 1] = 0;
|
||||
}
|
||||
|
||||
for (int i = 1; i < block_size; i <<= 1) {
|
||||
offset >>= 1;
|
||||
__syncthreads();
|
||||
if (tid < i) {
|
||||
int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j];
|
||||
temp[j] = temp[k];
|
||||
temp[k] += ts;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) temp[0] = temp[block_size];
|
||||
__syncthreads();
|
||||
|
||||
if (idx + tps < s) {
|
||||
temp[tid + 1] += last_sum;
|
||||
#pragma unroll
|
||||
for (int i = pack_size - 1; i > 0; --i) {
|
||||
outputs[(tps + i) * e + bid] = temp[tid + 1];
|
||||
temp[tid + 1] -= pack[i];
|
||||
}
|
||||
outputs[tps * e + bid] = temp[tid + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
last_sum += temp[0];
|
||||
inputs += bpack_size * e;
|
||||
outputs += bpack_size * e;
|
||||
}
|
||||
}
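// Note (interpretation of the loop above): this appears to be a work-efficient
// (Blelloch-style) scan in shared memory, run per expert column (bid indexes
// the expert, stride e). Starting last_sum at -1 turns the running inclusive
// sum into "cumulative sum minus one", which is what cumsum_sub_one_in_dim0
// below exposes.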
|
||||
|
||||
// LAUNCH FUNCTIONS --------------------------------
|
||||
|
||||
template <typename T>
|
||||
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2, const int s,
|
||||
const int h) {
|
||||
if (h < 256)
|
||||
moe_dpch_fwd_kernel<T, 32, 4>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 512)
|
||||
moe_dpch_fwd_kernel<T, 32, 8>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 1024)
|
||||
moe_dpch_fwd_kernel<T, 32, 16>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 2048)
|
||||
moe_dpch_fwd_kernel<T, 64, 16>
|
||||
<<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else
|
||||
moe_dpch_fwd_kernel<T, 128, 16>
|
||||
<<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
|
||||
int *dest1, int *dest2, const int s, const int h) {
|
||||
if (h < 256)
|
||||
moe_dpch_bwd_kernel<T, 32, 4>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 512)
|
||||
moe_dpch_bwd_kernel<T, 32, 8>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 1024)
|
||||
moe_dpch_bwd_kernel<T, 32, 16>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 2048)
|
||||
moe_dpch_bwd_kernel<T, 64, 16>
|
||||
<<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else
|
||||
moe_dpch_bwd_kernel<T, 128, 16>
|
||||
<<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
|
||||
int *mask1, int *mask2, int *dest1, int *dest2,
|
||||
const int s, const int e, const int c, const int h) {
|
||||
if (h < 256)
|
||||
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 512)
|
||||
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 1024)
|
||||
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 2048)
|
||||
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else
|
||||
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1,
|
||||
dest2, e, c, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
|
||||
T *logits_grad, int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int s, const int e, const int c,
|
||||
const int h) {
|
||||
if (h < 256)
|
||||
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
|
||||
logits, logits_grad, mask1, mask2,
|
||||
dest1, dest2, e, c, h);
|
||||
else // if (h < 512)
|
||||
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
|
||||
logits, logits_grad, mask1, mask2,
|
||||
dest1, dest2, e, c, h);
|
||||
// else if (h < 1024)
|
||||
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
|
||||
// dest1, dest2, e, c, h);
|
||||
// else
|
||||
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
|
||||
// dest1, dest2, e, c, h);
|
||||
}
|
||||
|
||||
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
  if (s <= 256)
    cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
  else if (s <= 512)
    cumsum_kernel<512, 1><<<e, 512>>>(inputs, outputs, s, e);
  else if (s <= 1024)
    cumsum_kernel<1024, 1><<<e, 1024>>>(inputs, outputs, s, e);
  else if (s <= 2048)
    cumsum_kernel<1024, 2><<<e, 1024>>>(inputs, outputs, s, e);
  else
    cumsum_kernel<1024, 4><<<e, 1024>>>(inputs, outputs, s, e);
}
|
||||
|
||||
// API FUNCTIONS --------------------------------
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                        \
  switch (TYPE) {                                                        \
    case at::ScalarType::Float: {                                        \
      using scalar_t = float;                                            \
      __VA_ARGS__;                                                       \
      break;                                                             \
    }                                                                    \
    case at::ScalarType::Half: {                                         \
      using scalar_t = at::Half;                                         \
      __VA_ARGS__;                                                       \
      break;                                                             \
    }                                                                    \
    default:                                                             \
      AT_ERROR(#NAME, " not implemented yet for specific data type.");   \
  }
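// Example usage (illustrative sketch; `my_launch`, `t` and `n` are placeholder
// names, not part of this file):
//   DISPATCH_FLOAT_AND_HALF(t.scalar_type(), "my op",
//                           my_launch<scalar_t>(t.data<scalar_t>(), n));
// The macro binds `scalar_t` to float or at::Half based on the runtime dtype
// and raises AT_ERROR for any other dtype.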
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros(
|
||||
{ec, h},
|
||||
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
batch_tokens.scalar_type(), "moe dispatch forward",
|
||||
moe_dpch_fwd_launch<scalar_t>(
|
||||
batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
|
||||
|
||||
return res;
|
||||
}
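// Shape note (inferred from the indexing above): batch_tokens is expected to
// be [s, h]; mask and dest_idx are int32 tensors of shape [k, s] with k in
// {1, 2}; the returned buffer is [ec, h], where ec presumably covers all rows
// addressed by dest_idx (expert count times capacity).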
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
|
||||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros(
|
||||
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
expert_grad.scalar_type(), "moe dispatch backward",
|
||||
moe_dpch_bwd_launch<scalar_t>(
|
||||
res.data<scalar_t>(), expert_grad.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
assert(h % 16 == 0);
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
auto res = torch::zeros(
|
||||
{s, h},
|
||||
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
expert_tokens.scalar_type(), "moe combine forward",
|
||||
moe_cb_fwd_launch<scalar_t>(
|
||||
expert_tokens.data<scalar_t>(), res.data<scalar_t>(),
|
||||
logits.data<scalar_t>(), mask[0].data<int>(),
|
||||
k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
|
||||
h));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
||||
int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
assert(h % 16 == 0);
|
||||
assert(tokens_grad.dtype() == expert_tokens.dtype());
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
auto egrad = torch::zeros(
|
||||
{e * c, h},
|
||||
torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
|
||||
wgrad = torch::zeros(
|
||||
{s, e}, torch::dtype(logits.dtype()).device(logits.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tokens_grad.scalar_type(), "moe combine backward",
|
||||
moe_cb_bwd_launch<scalar_t>(
|
||||
tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(),
|
||||
expert_tokens.data<scalar_t>(), logits.data<scalar_t>(),
|
||||
wgrad.data<scalar_t>(), mask[0].data<int>(),
|
||||
k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
|
||||
h));
|
||||
|
||||
return {egrad, wgrad};
|
||||
}
|
||||
|
||||
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
  assert(mask.dim() == 2);
  assert(mask.dtype() == torch::kInt32);

  const int s = mask.size(0), e = mask.size(1);
  auto res =
      torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
  cumsum_launch(mask.data<int>(), res.data<int>(), s, e);

  return res;
}
|
146
extensions/csrc/cuda/multi_tensor_adam.cu
Normal file
@@ -0,0 +1,146 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
|
||||
/* Copyright 2020 The Microsoft DeepSpeed Team
|
||||
Copyright NVIDIA/apex
|
||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
typedef enum {
  ADAM_MODE_0 = 0,  // L2 regularization mode
  ADAM_MODE_1 = 1   // Decoupled weight decay mode (AdamW)
} adamMode_t;
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T_g, typename T_p>
|
||||
struct AdamFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
|
||||
const float beta1, const float beta2, const float beta1_correction,
|
||||
const float beta2_correction, const float epsilon, const float lr,
|
||||
adamMode_t mode, const float decay, const float div_scale) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
|
||||
// potentially use to pass in list of scalar
|
||||
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
|
||||
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
T_g *g = (T_g *)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
||||
T_p *p = (T_p *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
T_p *m = (T_p *)tl.addresses[2][tensor_loc];
|
||||
m += chunk_idx * chunk_size;
|
||||
|
||||
T_p *v = (T_p *)tl.addresses[3][tensor_loc];
|
||||
v += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// see note in multi_tensor_scale_kernel.cu
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
r_g[ii] = g[i];
|
||||
r_p[ii] = p[i];
|
||||
r_m[ii] = m[i];
|
||||
r_v[ii] = v[i];
|
||||
} else {
|
||||
r_g[ii] = MATH_T(0);
|
||||
r_p[ii] = MATH_T(0);
|
||||
r_m[ii] = MATH_T(0);
|
||||
r_v[ii] = MATH_T(0);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
if (div_scale > 0) r_g[ii] /= div_scale;
|
||||
|
||||
if (mode == ADAM_MODE_0) { // L2
|
||||
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
MATH_T update = next_m_unbiased / denom;
|
||||
r_p[ii] = r_p[ii] - (lr * update);
|
||||
} else { // weight decay
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
r_p[ii] = r_p[ii] - (lr * update);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
p[i] = r_p[ii];
|
||||
m[i] = r_m[ii];
|
||||
v[i] = r_v[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr, const float beta1,
|
||||
const float beta2, const float epsilon,
|
||||
const int step, const int mode,
|
||||
const int bias_correction, const float weight_decay,
|
||||
const float div_scale) {
|
||||
using namespace at;
|
||||
|
||||
// Handle bias correction mode
|
||||
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
||||
if (bias_correction == 1) {
|
||||
bias_correction1 = 1 - std::pow(beta1, step);
|
||||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF_FOR_G_P(
|
||||
tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0,
|
||||
"adam",
|
||||
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
|
||||
beta2, bias_correction1, bias_correction2, epsilon,
|
||||
lr, (adamMode_t)mode, weight_decay, div_scale);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
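// Recap of the update applied by AdamFunctor above (accumulated in MATH_T):
//   m <- beta1 * m + (1 - beta1) * g
//   v <- beta2 * v + (1 - beta2) * g * g
//   m_hat = m / (1 - beta1^step),  v_hat = v / (1 - beta2^step)
//     (the corrections are 1 when bias_correction != 1)
//   ADAM_MODE_0: g is first replaced by g + decay * p (L2), then
//                p <- p - lr * m_hat / (sqrt(v_hat) + eps)
//   ADAM_MODE_1: p <- p - lr * (m_hat / (sqrt(v_hat) + eps) + decay * p)
// When div_scale > 0, gradients are divided by div_scale before the update.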
|
130
extensions/csrc/cuda/multi_tensor_apply.cuh
Normal file
@@ -0,0 +1,130 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
|
||||
/* Copyright 2020 The Microsoft DeepSpeed Team
|
||||
Copyright NVIDIA/apex
|
||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <assert.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "compat.h"
|
||||
|
||||
// #include <iostream>
|
||||
|
||||
// This header is the one-stop shop for all your multi-tensor apply needs.
|
||||
|
||||
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
|
||||
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
|
||||
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
|
||||
|
||||
template <int n>
|
||||
struct TensorListMetadata {
|
||||
void *addresses[n][depth_to_max_tensors[n - 1]];
|
||||
int sizes[depth_to_max_tensors[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a
|
||||
// full int.
|
||||
int start_tensor_this_launch;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
__global__ void multi_tensor_apply_kernel(int chunk_size,
|
||||
volatile int *noop_flag, T tl,
|
||||
U callable, ArgTypes... args) {
|
||||
// Hand the chunk information to the user-supplied functor to process however
|
||||
// it likes.
|
||||
callable(chunk_size, noop_flag, tl, args...);
|
||||
}
|
||||
|
||||
template <int depth, typename T, typename... ArgTypes>
|
||||
void multi_tensor_apply(
|
||||
int block_size, int chunk_size, const at::Tensor &noop_flag,
|
||||
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
|
||||
ArgTypes... args) {
|
||||
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
|
||||
int len0 = tensor_lists[0].size();
|
||||
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
|
||||
auto ref_device = tensor_lists[0][0].device();
|
||||
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
|
||||
for (int l = 0; l < tensor_lists.size();
|
||||
l++) // No range-based for because I need indices
|
||||
{
|
||||
TORCH_CHECK(tensor_lists[l].size() == len0,
|
||||
"Size mismatch among tensor lists");
|
||||
for (int t = 0; t < tensor_lists[l].size(); t++) {
|
||||
// TODO: Print which tensor fails.
|
||||
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
|
||||
#ifdef VERSION_GE_1_5
|
||||
contiguous_memory =
|
||||
(contiguous_memory ||
|
||||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
|
||||
#endif
|
||||
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
|
||||
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
|
||||
"A tensor was not on the same device as the first tensor");
|
||||
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
|
||||
"Size mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
|
||||
TensorListMetadata<depth> tl;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
tl.start_tensor_this_launch = 0;
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
for (int t = 0; t < ntensors; t++) {
|
||||
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
|
||||
loc_tensor_info++;
|
||||
|
||||
int chunks_this_tensor =
|
||||
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
|
||||
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
|
||||
// std::cout << chunks_this_tensor << std::endl;
|
||||
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
|
||||
tl.block_to_chunk[loc_block_info] = chunk;
|
||||
loc_block_info++;
|
||||
|
||||
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
||||
chunk == chunks_this_tensor - 1);
|
||||
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
|
||||
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
|
||||
if (tensors_full || blocks_full || last_chunk) {
|
||||
// using accscalar_t = acc_type<scalar_t, true>;
|
||||
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
|
||||
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Reset. The control flow possibilities here make my brain hurt.
|
||||
loc_block_info = 0;
|
||||
if (chunk == chunks_this_tensor - 1) {
|
||||
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3
|
||||
// << std::endl;
|
||||
loc_tensor_info = 0;
|
||||
tl.start_tensor_this_launch = t + 1;
|
||||
} else {
|
||||
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3
|
||||
// << std::endl;
|
||||
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
|
||||
loc_tensor_info = 1;
|
||||
tl.start_tensor_this_launch = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
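// Summary of the chunking scheme above: each (tensor, chunk) pair is appended
// to the TensorListMetadata staging struct, and a kernel launch is flushed
// whenever the tensor slots or block slots fill up, or the last chunk of the
// last tensor is reached. After a flush mid-tensor, the current tensor's entry
// is copied back to slot 0 so its remaining chunks go into the next launch.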
|
382
extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
Normal file
@@ -0,0 +1,382 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
||||
int src_offset) {
|
||||
typedef
|
||||
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
template <typename x_t>
|
||||
struct L2NormFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
|
||||
float *output, float *output_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
|
||||
x += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float vals[ILP]; // = {0}; // this probably works too but I want to be
|
||||
// sure...
|
||||
x_t r_x[ILP];
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
vals[i] = 0.f;
|
||||
r_x[i] = 0;
|
||||
}
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_x, x, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
float next = static_cast<float>(r_x[ii]);
|
||||
vals[ii] += next * next;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
float next = static_cast<float>(x[i]);
|
||||
vals[ii] += next * next;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float val = 0.f;
|
||||
for (int i = 0; i < ILP; i++) val += vals[i];
|
||||
|
||||
float final = reduce_block_into_lanes(s_vals, val);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
if (!isfinite(final))
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
output[blockIdx.x] += final;
|
||||
if (per_tensor)
|
||||
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
|
||||
max_chunks_per_tensor +
|
||||
chunk_idx] = final;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Probably better to template, but since we are not likely to support other
|
||||
// norm
|
||||
template <typename x_t>
|
||||
struct MaxNormFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
|
||||
float *output, float *output_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
|
||||
x += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float vals[ILP]; // = {0}; // this probably works too but I want to be
|
||||
// sure...
|
||||
x_t r_x[ILP];
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
vals[i] = 0.f;
|
||||
r_x[i] = 0;
|
||||
}
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_x, x, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
float next = static_cast<float>(r_x[ii]);
|
||||
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
float next = static_cast<float>(x[i]);
|
||||
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float val = 0.f;
|
||||
for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
|
||||
|
||||
float final = reduce_block_into_lanes_max_op(s_vals, val);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
if (!isfinite(final))
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
|
||||
if (per_tensor)
|
||||
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
|
||||
max_chunks_per_tensor +
|
||||
chunk_idx] = final;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
__global__ void cleanup(float *output, float *output_per_tensor, float *ret,
|
||||
float *ret_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
__shared__ float vals[512];
|
||||
|
||||
if (blockIdx.x == 0) {
|
||||
float val = 0;
|
||||
if (threadIdx.x < 320) val = output[threadIdx.x];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0) *ret = sqrt(final);
|
||||
}
|
||||
|
||||
if (per_tensor) {
|
||||
float *output_this_tensor =
|
||||
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
|
||||
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val += output_this_tensor[i];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
|
||||
float *ret_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor, int norm_type,
|
||||
float alpha, float beta) {
|
||||
__shared__ float vals[512];
|
||||
|
||||
if (blockIdx.x == 0) {
|
||||
float val = 0;
|
||||
if (threadIdx.x < 320) val = output[threadIdx.x];
|
||||
|
||||
if (norm_type == 0) {
|
||||
float final = reduce_block_into_lanes_max_op(vals, val);
|
||||
if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
|
||||
} else {
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
|
||||
}
|
||||
}
|
||||
|
||||
if (per_tensor) {
|
||||
float *output_this_tensor =
|
||||
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
|
||||
|
||||
if (norm_type == 0) {
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));
|
||||
|
||||
float final = reduce_block_into_lanes_max_op(vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
ret_per_tensor[blockIdx.x] =
|
||||
alpha * ret_per_tensor[blockIdx.x] + beta * final;
|
||||
} else {
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val += output_this_tensor[i];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] *
|
||||
ret_per_tensor[blockIdx.x] +
|
||||
beta * final);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python) {
|
||||
bool per_tensor =
|
||||
per_tensor_python.has_value() ? per_tensor_python.value() : false;
|
||||
|
||||
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
|
||||
auto output = at::zeros({320}, float_options);
|
||||
|
||||
at::Tensor output_per_tensor;
|
||||
at::Tensor ret_per_tensor;
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
int max_chunks_per_tensor = -1;
|
||||
|
||||
if (per_tensor) {
|
||||
for (int t = 0; t < ntensors; t++) {
|
||||
int max_chunks_this_tensor =
|
||||
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
if (max_chunks_this_tensor > max_chunks_per_tensor)
|
||||
max_chunks_per_tensor = max_chunks_this_tensor;
|
||||
}
|
||||
output_per_tensor =
|
||||
at::zeros({ntensors * max_chunks_per_tensor}, float_options);
|
||||
ret_per_tensor = at::empty({ntensors}, float_options);
|
||||
} else {
|
||||
ret_per_tensor = at::empty({0}, float_options);
|
||||
}
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
|
||||
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
|
||||
per_tensor, max_chunks_per_tensor);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// This involves one more small kernel launch, but it will be negligible end to
|
||||
// end. I could get rid of these by hacking the functor + multi tensor harness
|
||||
// with persistence logic, but keeping it simple for now
|
||||
auto ret = at::empty({1}, output.options());
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
|
||||
output.DATA_PTR<float>(),
|
||||
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
|
||||
ret.DATA_PTR<float>(),
|
||||
per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr, per_tensor,
|
||||
max_chunks_per_tensor);
|
||||
|
||||
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
|
||||
}
|
||||
|
||||
// Compute and update grad norm
|
||||
// Here we use a per-tensor norm, and blend the new norm (n) and old norm (gn) by
|
||||
// L-2: gn = sqrt(a * gn^2 + b * n^2)
|
||||
// L-inf: gn = a * gn + b * n
|
||||
void multi_tensor_norm_out_cuda(
|
||||
int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor out,
|
||||
const float alpha, const float beta, const int norm_type) {
|
||||
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
|
||||
TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(),
|
||||
"noop flag should be on the same device as tensors");
|
||||
// we don't need global, thus we use empty here
|
||||
auto output = at::empty({320}, float_options);
|
||||
|
||||
at::Tensor output_per_tensor;
|
||||
at::Tensor ret_per_tensor;
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
int max_chunks_per_tensor = -1;
|
||||
|
||||
for (int t = 0; t < ntensors; t++) {
|
||||
int max_chunks_this_tensor =
|
||||
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
if (max_chunks_this_tensor > max_chunks_per_tensor)
|
||||
max_chunks_per_tensor = max_chunks_this_tensor;
|
||||
}
|
||||
|
||||
// Although it is a single write then read, it still needs to be zeroed,
|
||||
// since the trailing elements also participate in the cleanup.
|
||||
output_per_tensor =
|
||||
at::zeros({ntensors * max_chunks_per_tensor}, float_options);
|
||||
|
||||
if (norm_type == 0) {
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
MaxNormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
|
||||
output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
|
||||
} else {
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
|
||||
output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
|
||||
}
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// This involves one more small kernel launch, but it will be negligible end to
|
||||
// end. I could get rid of these by hacking the functor + multi tensor harness
|
||||
// with persistence logic, but keeping it simple for now
|
||||
auto ret = at::empty({1}, output.options());
|
||||
|
||||
// Adding the following device guard since it happens sometimes that the
|
||||
// tensors are on one device and the cuda stream is on another device which
|
||||
// results in ILLEGAL MEM ACCESS error.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
cleanup_v2<<<ntensors, 512, 0, stream>>>(
|
||||
output.DATA_PTR<float>(), output_per_tensor.DATA_PTR<float>(),
|
||||
ret.DATA_PTR<float>(), out.DATA_PTR<float>(), true, max_chunks_per_tensor,
|
||||
norm_type, alpha, beta);
|
||||
|
||||
return;
|
||||
}
|
354
extensions/csrc/cuda/multi_tensor_lamb.cu
Normal file
@@ -0,0 +1,354 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
||||
int src_offset) {
|
||||
typedef
|
||||
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
typedef enum {
|
||||
MOMENT_MODE_0 = 0, // L2 regularization mode
|
||||
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
|
||||
} adamMode_t;
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python);
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T>
|
||||
struct LAMBStage1Functor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
|
||||
const float beta1, const float beta2, const float beta3,
|
||||
const float beta1_correction, const float beta2_correction,
|
||||
const float epsilon, adamMode_t mode, const float decay,
|
||||
const float *global_grad_norm, const float max_global_grad_norm) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
float clipped_global_grad_norm =
|
||||
(*global_grad_norm) > max_global_grad_norm
|
||||
? (*global_grad_norm) / max_global_grad_norm
|
||||
: 1.0f;
|
||||
|
||||
T *g = (T *)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
||||
T *p = (T *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
T *m = (T *)tl.addresses[2][tensor_loc];
|
||||
m += chunk_idx * chunk_size;
|
||||
|
||||
T *v = (T *)tl.addresses[3][tensor_loc];
|
||||
v += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) &&
|
||||
is_aligned(p) && is_aligned(m) && is_aligned(v)) {
|
||||
T l_g[ILP];
|
||||
T l_p[ILP];
|
||||
T l_m[ILP];
|
||||
T l_v[ILP];
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(l_g, g, 0, i_start);
|
||||
if (decay != 0) load_store(l_p, p, 0, i_start);
|
||||
load_store(l_m, m, 0, i_start);
|
||||
load_store(l_v, v, 0, i_start);
|
||||
// unpack
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_g[ii] = l_g[ii];
|
||||
if (decay == 0) {
|
||||
r_p[ii] = MATH_T(0);
|
||||
} else {
|
||||
r_p[ii] = l_p[ii];
|
||||
}
|
||||
r_m[ii] = l_m[ii];
|
||||
r_v[ii] = l_v[ii];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
if (mode == MOMENT_MODE_0) {
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
// L2 on scaled grad
|
||||
scaled_grad = scaled_grad + decay * r_p[ii];
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = next_m_unbiased / denom;
|
||||
} else {
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
l_p[ii] = r_p[ii];
|
||||
l_m[ii] = r_m[ii];
|
||||
l_v[ii] = r_v[ii];
|
||||
}
|
||||
// store
|
||||
load_store(g, l_p, i_start, 0);
|
||||
load_store(m, l_m, i_start, 0);
|
||||
load_store(v, l_v, i_start, 0);
|
||||
}
|
||||
} else {
|
||||
// see note in multi_tensor_scale_kernel.cu
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
r_g[ii] = g[i];
|
||||
// special ?optimization? for lamb stage 1
|
||||
if (decay == 0) {
|
||||
r_p[ii] = MATH_T(0);
|
||||
} else {
|
||||
r_p[ii] = p[i];
|
||||
}
|
||||
r_m[ii] = m[i];
|
||||
r_v[ii] = v[i];
|
||||
} else {
|
||||
r_g[ii] = MATH_T(0);
|
||||
r_p[ii] = MATH_T(0);
|
||||
r_m[ii] = MATH_T(0);
|
||||
r_v[ii] = MATH_T(0);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
if (mode == MOMENT_MODE_0) {
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
// L2 on scaled grad
|
||||
scaled_grad = scaled_grad + decay * r_p[ii];
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = next_m_unbiased / denom;
|
||||
} else {
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
g[i] = r_p[ii];
|
||||
m[i] = r_m[ii];
|
||||
v[i] = r_v[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
|
||||
// It computes new parameter value.
|
||||
template <typename T>
|
||||
struct LAMBStage2Functor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
|
||||
const float *per_tensor_param_norm, const float *per_tensor_update_norm,
|
||||
const float learning_rate, const float decay, bool use_nvlamb) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
MATH_T ratio = learning_rate;
|
||||
// nvlamb: apply adaptive learning rate to all parameters
|
||||
// otherwise, only apply to those with non-zero weight decay
|
||||
if (use_nvlamb || (decay != 0.0)) {
|
||||
float param_norm = per_tensor_param_norm[tensor_num];
|
||||
float update_norm = per_tensor_update_norm[tensor_num];
|
||||
ratio = (update_norm != 0.0f && param_norm != 0.0f)
|
||||
? learning_rate * (param_norm / update_norm)
|
||||
: learning_rate;
|
||||
}
|
||||
|
||||
T *update = (T *)tl.addresses[0][tensor_loc];
|
||||
update += chunk_idx * chunk_size;
|
||||
|
||||
T *p = (T *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) &&
|
||||
is_aligned(update)) {
|
||||
T r_p[ILP];
|
||||
T r_update[ILP];
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_p, p, 0, i_start);
|
||||
load_store(r_update, update, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_p[ii] = static_cast<MATH_T>(r_p[ii]) -
|
||||
(ratio * static_cast<MATH_T>(r_update[ii]));
|
||||
}
|
||||
load_store(p, r_p, i_start, 0);
|
||||
}
|
||||
} else {
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_update[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
r_p[ii] = p[i];
|
||||
r_update[ii] = update[i];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
p[i] = r_p[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr, const float beta1,
|
||||
const float beta2, const float epsilon,
|
||||
const int step, const int bias_correction,
|
||||
const float weight_decay, const int grad_averaging,
|
||||
const int mode, at::Tensor global_grad_norm,
|
||||
const float max_grad_norm,
|
||||
at::optional<bool> use_nvlamb_python) {
|
||||
using namespace at;
|
||||
// Master weights and 32-bit momentum (potentially changing) are not handled by
|
||||
// this, so we assume all tensors share the same type.
|
||||
|
||||
bool use_nvlamb =
|
||||
use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
|
||||
|
||||
// Handle bias correction mode
|
||||
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
||||
if (bias_correction == 1) {
|
||||
bias_correction1 = 1 - std::pow(beta1, step);
|
||||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
// Handle grad averaging mode
|
||||
float beta3 = 1.0f;
|
||||
if (grad_averaging == 1) beta3 = 1 - beta1;
|
||||
|
||||
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
|
||||
tensor_lists.begin() + 1);
|
||||
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1,
|
||||
tensor_lists.begin() + 2);
|
||||
|
||||
// Compute per tensor param norm
|
||||
auto param_norm_tuple =
|
||||
multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
|
||||
|
||||
// We now modify grad in place to store the update before computing its norm.
|
||||
// Generally this is not an issue since people modify grad in the step() method all
|
||||
// the time. We could also grab a list of empty tensors to avoid this, but I'd like
|
||||
// to save space/CPU code.
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
|
||||
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
|
||||
beta3, // 1-beta1 or 1 depends on averaging mode
|
||||
bias_correction1, bias_correction2, epsilon,
|
||||
(adamMode_t)mode, weight_decay,
|
||||
global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
|
||||
|
||||
// Compute update norms
|
||||
auto update_norm_tuple =
|
||||
multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
|
||||
|
||||
std::vector<std::vector<at::Tensor>> grad_param_list(
|
||||
tensor_lists.begin(), tensor_lists.begin() + 2);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
|
||||
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list,
|
||||
LAMBStage2Functor<scalar_t_0>(),
|
||||
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
|
||||
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
|
||||
lr, weight_decay, use_nvlamb);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
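// Recap of the two-stage LAMB flow above: stage 1 overwrites each gradient
// in place with the Adam-style update (using the clipped global grad norm),
// stage 2 then applies p <- p - ratio * update, where ratio is lr scaled by
// the per-tensor trust ratio param_norm / update_norm (for tensors with weight
// decay, or for all tensors when use_nvlamb is set).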
|
125
extensions/csrc/cuda/multi_tensor_scale_kernel.cu
Normal file
@@ -0,0 +1,125 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
|
||||
#include <sstream>
|
||||
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
||||
int src_offset) {
|
||||
typedef
|
||||
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
template <typename in_t, typename out_t>
|
||||
struct ScaleFunctor {
|
||||
__device__ __forceinline__ void operator()(int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<2> &tl,
|
||||
float scale) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
in_t *in = (in_t *)tl.addresses[0][tensor_loc];
|
||||
in += chunk_idx * chunk_size;
|
||||
|
||||
out_t *out = (out_t *)tl.addresses[1][tensor_loc];
|
||||
out += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
bool finite = true;
|
||||
in_t r_in[ILP];
|
||||
out_t r_out[ILP];
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) &&
|
||||
is_aligned(out)) {
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_in, in, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
|
||||
finite = finite && isfinite(r_in[ii]);
|
||||
}
|
||||
// store
|
||||
load_store(out, r_out, i_start, 0);
|
||||
}
|
||||
} else {
|
||||
// Non-divergent exit condition for __syncthreads, not necessary here
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_in[ii] = 0;
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) r_in[ii] = in[i];
|
||||
}
|
||||
// note for clarification to future michael:
|
||||
// From a pure memory dependency perspective, there's likely no point
|
||||
// unrolling the write loop, since writes just fire off once their LDGs
|
||||
// arrive. Put another way, the STGs are dependent on the LDGs, but not
|
||||
// on each other. There is still compute ILP benefit from unrolling the
|
||||
// loop though.
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
|
||||
finite = finite && isfinite(r_in[ii]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) out[i] = r_out[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!finite)
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float scale) {
|
||||
using namespace at;
|
||||
// The output (downscaled) type is always float.
|
||||
// If build times suffer, think about where to put this dispatch,
|
||||
// and what logic should be moved out of multi_tensor_apply.
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
|
||||
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
ScaleFunctor<scalar_t_0, scalar_t_1>(),
|
||||
scale);))
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
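A typical host-side use of multi_tensor_scale_cuda above is unscaling fp16 gradients into fp32 buffers while watching the noop flag for overflow. The helper below is an illustrative sketch; the chunk size, helper name, and surrounding logic are assumptions, not the extension's actual call site.

// Editor's sketch only; assumes the declaration from the file above is visible.
#include <ATen/ATen.h>
#include <vector>

void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
                             float scale);

bool unscale_grads_example(const std::vector<at::Tensor>& fp16_grads,
                           const std::vector<at::Tensor>& fp32_out,
                           float inv_scale) {
  // ScaleFunctor sets noop_flag to 1 when it sees a non-finite value.
  at::Tensor noop_flag =
      at::zeros({1}, fp16_grads[0].options().dtype(at::kInt));
  std::vector<std::vector<at::Tensor>> lists = {fp16_grads, fp32_out};
  multi_tensor_scale_cuda(/*chunk_size=*/65536, noop_flag, lists, inv_scale);
  // A non-zero flag means the optimizer step should be skipped.
  return noop_flag.item<int>() == 0;
}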
167
extensions/csrc/cuda/multi_tensor_sgd_kernel.cu
Normal file
@@ -0,0 +1,167 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <assert.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "compat.h"
|
||||
#include "multi_tensor_apply.cuh"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
/**
|
||||
* Perform fused SGD on multiple buffers
|
||||
* N: number of tensors
|
||||
* tl[0] : gradients
|
||||
* tl[1] : weights
|
||||
* tl[2] : momentum buffers
|
||||
* tl[3] : fp16 weights (if appropriate)
|
||||
* wd : weight_decay (scalar)
|
||||
* momentum : momentum (scalar)
|
||||
* dampening : momentum dampening (scalar)
|
||||
* lr : learning rate (scalar)
|
||||
* nesterov : enable nesterov (bool)
|
||||
* first run : necessary for proper momentum handling & init
|
||||
* wd_after_momentum : apply weight decay _after_ momentum instead of before
|
||||
**/
|
||||
template <typename T_grad, typename T_weight>
|
||||
struct SGDFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
|
||||
float wd, float momentum, float dampening, float lr, bool nesterov,
|
||||
bool first_run, bool wd_after_momentum, float scale) {
|
||||
// Early exit if we don't need to do anything
|
||||
if (*noop_gmem) return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
|
||||
grad_in += chunk_idx * chunk_size;
|
||||
|
||||
T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
|
||||
weight_in += chunk_idx * chunk_size;
|
||||
|
||||
T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
|
||||
mom_in += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// Non-divergent exit condition for the __syncthreads
|
||||
float incoming_grads[ILP];
|
||||
float incoming_weights[ILP];
|
||||
float incoming_moms[ILP];
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
incoming_grads[ii] = 0;
|
||||
incoming_weights[ii] = 0;
|
||||
incoming_moms[ii] = 0;
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
|
||||
incoming_weights[ii] = static_cast<float>(weight_in[i]);
|
||||
incoming_moms[ii] = static_cast<float>(mom_in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// note for clarification to future michael:
|
||||
// From a pure memory dependency perspective, there's likely no point unrolling
|
||||
// the write loop, since writes just fire off once their LDGs arrive.
|
||||
// Put another way, the STGs are dependent on the LDGs, but not on each other.
|
||||
// There is still compute ILP benefit from unrolling the loop though.
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
// apply weight decay before momentum if necessary
|
||||
if (wd != 0.f && !wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
|
||||
if (momentum != 0.f) {
|
||||
if (!first_run)
|
||||
incoming_moms[ii] = incoming_moms[ii] * momentum +
|
||||
(1.f - dampening) * incoming_grads[ii];
|
||||
else // initialize momentums to current incoming grads
|
||||
incoming_moms[ii] = incoming_grads[ii];
|
||||
|
||||
if (nesterov)
|
||||
incoming_grads[ii] += momentum * incoming_moms[ii];
|
||||
else
|
||||
incoming_grads[ii] = incoming_moms[ii];
|
||||
}
|
||||
|
||||
// Apply WD after momentum if desired
|
||||
if (wd != 0.f && wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
|
||||
// adjust the weight and write out
|
||||
weight_in[i] += (-lr * incoming_grads[ii]);
|
||||
|
||||
// also write out the new momentum
|
||||
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float wd, float momentum, float dampening, float lr,
|
||||
bool nesterov, bool first_run,
|
||||
bool wd_after_momentum, float scale) {
|
||||
auto num_tensors = tensor_lists.size();
|
||||
auto grad_type = tensor_lists[0][0].scalar_type();
|
||||
auto weight_type = tensor_lists[1][0].scalar_type();
|
||||
|
||||
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
|
||||
"expected noop flag to be on the same device as tensors");
|
||||
|
||||
// We have 3 possibilities to handle here, in terms of
|
||||
// grad_type, param_type, momentum_type
|
||||
// 1. fp16, fp16, fp16
|
||||
// 2. fp32, fp32, fp32
|
||||
// 3. fp16, fp32, fp32
|
||||
// It's easier to hardcode these possibilities than to use
|
||||
// switches etc. to handle the cross-product of cases where
|
||||
// we don't want the majority of them.
|
||||
|
||||
// Case 1. fp16, fp16, fp16
|
||||
if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Half && num_tensors == 3) {
|
||||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<at::Half, at::Half>(), wd, momentum,
|
||||
dampening, lr, nesterov, first_run, wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 2. fp32, fp32, fp32
|
||||
else if (grad_type == at::ScalarType::Float &&
|
||||
weight_type == at::ScalarType::Float && num_tensors == 3) {
|
||||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<float, float>(), wd, momentum, dampening,
|
||||
lr, nesterov, first_run, wd_after_momentum, scale);
|
||||
}
|
||||
// Case 3. fp16, fp32, fp32
|
||||
else if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Float && num_tensors == 3) {
|
||||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<at::Half, float>(), wd, momentum,
|
||||
dampening, lr, nesterov, first_run, wd_after_momentum,
|
||||
scale);
|
||||
} else {
|
||||
AT_ERROR(
|
||||
"multi_tensor_sgd only supports some combinations of gradient & weight "
|
||||
"types. Given: ",
|
||||
"gradient: ", grad_type, ", weight: ", weight_type,
|
||||
", num_lists: ", num_tensors);
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
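For reference, the per-element update SGDFunctor above applies can be written as the following scalar sketch (illustrative only; the fused kernel performs the same arithmetic across all chunks of all tensors in one launch):

// Editor's sketch only; mirrors SGDFunctor's weight-decay / momentum / nesterov ordering.
float sgd_momentum_reference(float& weight, float grad, float& mom, float wd,
                             float momentum, float dampening, float lr,
                             bool nesterov, bool first_run,
                             bool wd_after_momentum, float scale) {
  grad *= scale;  // gradients arrive pre-scaled (e.g. by a loss scaler)
  if (wd != 0.f && !wd_after_momentum) grad += wd * weight;
  if (momentum != 0.f) {
    // On the first run the momentum buffer is initialized to the gradient.
    mom = first_run ? grad : mom * momentum + (1.f - dampening) * grad;
    grad = nesterov ? grad + momentum * mom : mom;
  }
  if (wd != 0.f && wd_after_momentum) grad += wd * weight;
  weight -= lr * grad;
  return weight;
}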
70
extensions/csrc/cuda/scaled_masked_softmax.cpp
Normal file
@@ -0,0 +1,70 @@
/* This code is from NVIDIA Megatron, with minor changes. */

#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
                       float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
                             int attn_heads);

torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
                  float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
                        int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
                                  attn_heads);
}

}  // end namespace scaled_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");

  m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");

  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::
            get_batch_per_block,
        "Return Batch per block size.");
}
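The binding above exposes a scaled, explicitly masked softmax. A single-row reference of the forward math (an editor's sketch, not part of the extension, mirroring the -10000 fill used by the CUDA kernel in scaled_masked_softmax.h):

// Editor's sketch only; the fused kernel does this per attention row in registers.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

std::vector<float> scaled_masked_softmax_row(const std::vector<float>& x,
                                             const std::vector<uint8_t>& mask,
                                             float scale) {
  std::vector<float> y(x.size());
  float max_val = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < x.size(); ++i) {
    // mask == 1 marks positions to suppress; they get a large negative logit.
    y[i] = (mask[i] != 1) ? x[i] * scale : -10000.0f;
    max_val = std::max(max_val, y[i]);
  }
  float sum = 0.f;
  for (float& v : y) {
    v = std::exp(v - max_val);
    sum += v;
  }
  for (float& v : y) v /= sum;
  return y;
}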
538
extensions/csrc/cuda/scaled_masked_softmax.h
Normal file
@@ -0,0 +1,538 @@
|
||||
/* This code is from NVIDIA Megatron, with minor changes. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cfloat>
|
||||
#include <limits>
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*((half2 *)dst) = *((half2 *)src);
|
||||
}
|
||||
|
||||
int log2_ceil(int value) {
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Add {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Max {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a < b ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T
|
||||
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
|
||||
unsigned int mask = 0xffffffff) {
|
||||
#if CUDA_VERSION >= 9000
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
#else
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
|
||||
template <typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
}
|
||||
}
|
||||
}
|
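// Editor's note (descriptive comment, not in the original): warp_reduce is a
// butterfly reduction. For WARP_SIZE = 32 each lane XOR-shuffles with offsets
// 16, 8, 4, 2, 1, so after five steps every lane of the warp holds the fully
// reduced value for each of its WARP_BATCH accumulators.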
||||
|
||||
/*
|
||||
* Extended softmax (from native aten pytorch) with following additional
|
||||
* features 1) input scaling 2) Explicit masking
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_forward(
|
||||
output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
|
||||
int micro_batch_size, int element_count, int pad_batches) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch =
|
||||
(blockDim.y *
|
||||
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
|
||||
threadIdx.y) *
|
||||
WARP_BATCH;
|
||||
int pad_first_batch = 0;
|
||||
if (pad_batches != 1) { // bert style
|
||||
pad_first_batch =
|
||||
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
|
||||
WARP_BATCH;
|
||||
} else { // gpt2 style
|
||||
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
}
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to be computed within this warp.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < batch_element_count) {
|
||||
int itr_idx = i * element_count + it * WARP_SIZE;
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
|
||||
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (temp_mask[element] != 1) {
|
||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -10000.0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] =
|
||||
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
acc_t sum[WARP_BATCH]{0.0f};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
dst + i * element_count + it * WARP_SIZE, out);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_backward(
|
||||
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
||||
int micro_batch_size, int element_count) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to be computed within this warp.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
// the first element to process by the current thread
|
||||
int thread_offset =
|
||||
first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_grad, grad + i * element_count + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_output, output + i * element_count + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
grad_reg[i][it + element] =
|
||||
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] =
|
||||
(output_t)(scale * (grad_reg[i][it + element] -
|
||||
output_reg[i][it + element] * sum[i]));
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
gradInput + i * element_count + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end of anonymous namespace
|
||||
|
||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
||||
int attn_heads) {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
constexpr int threads_per_block = 128;
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
|
||||
return batches_per_block;
|
||||
}
|
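// Editor's note (descriptive comment, not in the original): for example, with
// key_seq_len = 1024 this gives log2_elements = 10, warp_size = 32 and
// batches_per_warp = 1, so a 128-thread block contains 4 warps and
// get_batch_per_block returns 4 rows per block.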
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
|
||||
const uint8_t *mask,
|
||||
const input_t scale,
|
||||
int query_seq_len, int key_seq_len,
|
||||
int batches, int attn_heads,
|
||||
int pad_batches) {
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
|
||||
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
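// Editor's note (descriptive comment, not in the original): the forward
// dispatch launches a grid of (query_seq_len / batches_per_block, attn_heads,
// batches) blocks, so the TORCH_INTERNAL_ASSERT above requiring query_seq_len
// to be divisible by batches_per_block is what keeps every row covered.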
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
const acc_t scale,
|
||||
int query_seq_len, int key_seq_len,
|
||||
int batches, int attn_heads) {
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
int blocks = batch_count / batches_per_block;
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
89
extensions/csrc/cuda/scaled_masked_softmax_cuda.cu
Normal file
@@ -0,0 +1,89 @@
/* This code is from NVIDIA Megatron, with minor changes. */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "scaled_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
                             int attn_heads) {
  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
                       float scale_factor) {
  // input is a 4D tensor with dimensions [batches, attn_heads, seq_len,
  // seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results = torch::empty(
      {batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax intermediate result pointers
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* mask_ptr = static_cast<void*>(mask.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
          query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output_grads is a 4D tensor with dimensions [batches, attn_heads, seq_len,
  // seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, query_seq_len, key_seq_len, batches, attn_heads););

  // The backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
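The backward binding above works entirely in place on output_grads. A single-row reference of the math it dispatches (an editor's sketch, not part of the diff): grad_input = scale * y * (dy - sum(dy * y)), written over the incoming gradient buffer, which is why bwd_cuda simply returns output_grads.

// Editor's sketch only; y is the saved softmax output, dy the incoming gradient.
#include <vector>

void scaled_softmax_backward_row(std::vector<float>& dy,
                                 const std::vector<float>& y, float scale) {
  float dot = 0.f;
  for (size_t i = 0; i < dy.size(); ++i) dot += dy[i] * y[i];
  for (size_t i = 0; i < dy.size(); ++i)
    dy[i] = scale * y[i] * (dy[i] - dot);  // overwrites dy, like the kernel
}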
54
extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp
Normal file
@@ -0,0 +1,54 @@
/* This code is from NVIDIA Megatron, with minor changes. */

#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

}  // end namespace scaled_upper_triang_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
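The "upper triangular" variant applies an implicit causal mask: for attention row i only keys j <= i participate, and later positions come out as zero (which is what copy_zero_vector does in the forward kernel of scaled_upper_triang_masked_softmax.h). A single-row sketch of that forward (illustrative only, not part of the extension):

// Editor's sketch only; row is the query position, x holds that row's logits.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

std::vector<float> causal_scaled_softmax_row(const std::vector<float>& x,
                                             int row, float scale) {
  std::vector<float> y(x.size(), 0.0f);   // masked tail stays zero
  const int valid = row + 1;              // local_seq in the kernel
  float max_val = -std::numeric_limits<float>::infinity();
  for (int j = 0; j < valid; ++j) max_val = std::max(max_val, x[j] * scale);
  float sum = 0.f;
  for (int j = 0; j < valid; ++j) {
    y[j] = std::exp(x[j] * scale - max_val);
    sum += y[j];
  }
  for (int j = 0; j < valid; ++j) y[j] /= sum;
  return y;
}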
600
extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h
Normal file
@@ -0,0 +1,600 @@
|
||||
/* This code is from NVIDIA Megatron, with minor changes. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cfloat>
|
||||
#include <limits>
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*((half2 *)dst) = *((half2 *)src);
|
||||
}
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_zero_vector(Datatype *dst);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(
|
||||
c10::BFloat16 *dst) {
|
||||
*dst = 0.0;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(
|
||||
c10::BFloat16 *dst) {
|
||||
*((float2 *)dst) = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) {
|
||||
*dst = 0.0;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) {
|
||||
*((float2 *)dst) = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
|
||||
int log2_ceil(int value) {
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Add {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Max {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a < b ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T
|
||||
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
|
||||
unsigned int mask = 0xffffffff) {
|
||||
#if CUDA_VERSION >= 9000
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
#else
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
|
||||
template <typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Extended softmax (from native aten pytorch) with following additional
|
||||
* features 1) input scaling 2) Implicit time (diagonal masking)
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
|
||||
output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size,
|
||||
int stride, int element_count) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
int first_batch =
|
||||
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
|
||||
blockIdx.x;
|
||||
int local_seq = blockIdx.x + 1;
|
||||
int warp_iteration_limit =
|
||||
(local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to be computed within this warp.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_data, src + i * element_count * stride + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if ((element_index + element) < batch_element_count) {
|
||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] =
|
||||
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
acc_t sum[WARP_BATCH]{0.0f};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
if (it < warp_iteration_limit) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < local_seq) {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < local_seq) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
} else {
|
||||
out[element] = 0;
|
||||
}
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
dst + i * element_count * stride + it * WARP_SIZE, out);
|
||||
} else if (element_index < element_count) {
|
||||
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
dst + i * element_count * stride + it * WARP_SIZE);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
|
||||
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
||||
int micro_batch_size, int stride, int element_count) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
int first_batch =
|
||||
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
|
||||
blockIdx.x;
|
||||
int local_seq = blockIdx.x + 1;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||
// many batches have to be computed within this warp.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
// the first element to process by the current thread
|
||||
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_output, output + i * element_count * stride + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < batch_element_count) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (element_index + element < batch_element_count) {
|
||||
grad_reg[i][it + element] =
|
||||
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] =
|
||||
(output_t)(scale * (grad_reg[i][it + element] -
|
||||
output_reg[i][it + element] * sum[i]));
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
gradInput + i * element_count * stride + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // end of anonymous namespace
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_upper_triang_masked_softmax_forward(
|
||||
output_t *dst, const input_t *src, const input_t scale,
|
||||
int softmax_elements, int softmax_elements_stride, int attn_batches) {
|
||||
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
|
||||
if (softmax_elements == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(softmax_elements);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int seq_len = softmax_elements;
|
||||
int batch_count = attn_batches * seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||
|
||||
int blocks_per_seq = attn_batches / batches_per_block;
|
||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||
acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, scale, batch_count, softmax_elements_stride,
|
||||
softmax_elements);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_upper_triang_masked_softmax_backward(
|
||||
output_t *grad_input, input_t *grad, const input_t *output,
|
||||
const acc_t scale, int softmax_elements, int softmax_elements_stride,
|
||||
int attn_batches) {
|
||||
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
|
||||
if (softmax_elements == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(softmax_elements);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int seq_len = softmax_elements;
|
||||
int batch_count = attn_batches * seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||
|
||||
int blocks_per_seq = attn_batches / batches_per_block;
|
||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||
acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count,
|
||||
softmax_elements_stride, softmax_elements);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,75 @@
/*This code from NVIDIA Megatron:
 *     with minor changes. */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t,
                                                          float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr), scale_factor, seq_len,
          seq_len, attn_batches););
  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 3d tensor with dimensions [attn_batches, seq_len,
  // seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t,
                                                           float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, seq_len, seq_len, attn_batches););

  // backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_upper_triang_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
279
extensions/csrc/cuda/type_shim.h
Normal file
@@ -0,0 +1,279 @@
|
||||
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
|
||||
/* Copyright 2020 The Microsoft DeepSpeed Team
|
||||
Copyright NVIDIA/apex
|
||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "compat.h"
|
||||
|
||||
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||
switch (TYPEIN) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_in = float; \
|
||||
switch (TYPEOUT) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_in = at::Half; \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t_in = at::BFloat16; \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
||||
}
|
||||
|
||||
// Forward/backward compatibility hack around
|
||||
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
|
||||
// pending more future-proof guidance from upstream.
|
||||
// struct TypeShim
|
||||
// {
|
||||
// const at::Type& payload;
|
||||
// TypeShim(const at::Type& type) : payload(type) {}
|
||||
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
|
||||
// operator const at::Type&(){ return payload; };
|
||||
// // Enable dispatch switch statements to take *this directly for post-3aeb78
|
||||
// //operator at::ScalarType(){ return payload.; };
|
||||
// };
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Byte: { \
|
||||
using scalar_t_##LEVEL = uint8_t; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Double: { \
|
||||
using scalar_t_##LEVEL = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Double: { \
|
||||
using scalar_t_##LEVEL = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
|
||||
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::Float && \
|
||||
PTYPE == at::ScalarType::Half) { \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::Half && \
|
||||
PTYPE == at::ScalarType::Float) { \
|
||||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
|
||||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::Float && \
|
||||
PTYPE == at::ScalarType::BFloat16) { \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::BFloat16 && \
|
||||
PTYPE == at::ScalarType::Float) { \
|
||||
using g_scalar_t_##LEVEL = at::BFloat16; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::BFloat16 && \
|
||||
PTYPE == at::ScalarType::BFloat16) { \
|
||||
using g_scalar_t_##LEVEL = at::BFloat16; \
|
||||
using p_scalar_t_##LEVEL = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
||||
"'"); \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T reduce_block_into_lanes(
|
||||
T *x, T val, int lanes = 1,
|
||||
bool share_result = false) // lanes is intended to be <= 32.
|
||||
{
|
||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int blockSize =
|
||||
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||
|
||||
if (blockSize >= 64) {
|
||||
x[tid] = val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||
if (tid < i) x[tid] = x[tid] + x[tid + i];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
T final;
|
||||
|
||||
if (tid < 32) {
|
||||
if (blockSize >= 64)
|
||||
final = x[tid] + x[tid + 32];
|
||||
else
|
||||
final = val;
|
||||
// __SYNCWARP();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 16; i >= lanes; i >>= 1)
|
||||
final = final + __shfl_down_sync(0xffffffff, final, i);
|
||||
}
|
||||
|
||||
if (share_result) {
|
||||
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||
// Make sure the smem result is visible to all warps.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
return final;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
|
||||
T *x, T val, int lanes = 1,
|
||||
bool share_result = false) // lanes is intended to be <= 32.
|
||||
{
|
||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int blockSize =
|
||||
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||
|
||||
if (blockSize >= 64) {
|
||||
x[tid] = val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
T final;
|
||||
|
||||
if (tid < 32) {
|
||||
if (blockSize >= 64)
|
||||
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
||||
else
|
||||
final = val;
|
||||
// __SYNCWARP();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 16; i >= lanes; i >>= 1)
|
||||
final =
|
||||
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
||||
}
|
||||
|
||||
if (share_result) {
|
||||
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||
// Make sure the smem result is visible to all warps.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
return final;
|
||||
}
|
190
extensions/csrc/scaled_softmax.py
Normal file
@@ -0,0 +1,190 @@
# This code from NVIDIA Megatron:
#     with minor changes.

import enum

import torch
import torch.nn as nn

from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader

try:
    from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
except ImportError:
    scaled_masked_softmax = None
    scaled_upper_triang_masked_softmax = None


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:

    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        # build and load the kernel if it is not pre-built
        global scaled_upper_triang_masked_softmax
        if scaled_upper_triang_masked_softmax is None:
            scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:

    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale_t = torch.tensor([scale])

        # build and load the kernel if it is not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: Flag to indicate whether the input is in fp16 data format.
        input_in_bf16: Flag to indicate whether the input is in bf16 data format.
        attn_mask_type: Attention mask type (pad or causal).
        scaled_masked_softmax_fusion: Flag to indicate whether to use the fused softmax kernel.
        mask_func: Mask function to be applied.
        softmax_in_fp32: If True, softmax is performed at fp32 precision.
        scale: Scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type.value > 1:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type.value > 1:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            return ScaledMaskedSoftmax.apply(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    def get_batch_per_block(self, sq, sk, b, np):
        # build and load the kernel if it is not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
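A minimal usage sketch of the module above (not part of the patch): it assumes a CUDA device with the fused softmax kernel available, and the additive mask function is a hypothetical example; with scaled_masked_softmax_fusion=False or an fp32 input, the same call simply takes the plain PyTorch softmax fallback.

# Usage sketch (assumptions: CUDA device, fused kernel loadable).
import torch

def example_mask_func(scores, mask):
    # hypothetical additive mask: fill masked positions before softmax
    return scores.masked_fill(mask, -10000.0)

fused_softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.causal,
    scaled_masked_softmax_fusion=True,
    mask_func=example_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

b, heads, sq, sk = 4, 8, 128, 128  # [b, np, sq, sk] as documented above
scores = torch.randn(b, heads, sq, sk, dtype=torch.float16, device="cuda")
mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool, device="cuda")
probs = fused_softmax(scores, mask)  # (b, heads, sq, sk) attention probabilities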
106
extensions/cuda_extension.py
Normal file
@@ -0,0 +1,106 @@
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List

from .base_extension import _Extension
from .cpp_extension import _CppExtension
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list

__all__ = ["_CudaExtension"]

# Some constants for installation checks
MIN_PYTORCH_VERSION_MAJOR = 1
MIN_PYTORCH_VERSION_MINOR = 10


class _CudaExtension(_CppExtension):
    @abstractmethod
    def nvcc_flags(self) -> List[str]:
        """
        This function should return a list of nvcc compilation flags for extensions.
        """

    def is_hardware_available(self) -> bool:
        # cuda extension can only be built if cuda is available
        try:
            import torch

            cuda_available = torch.cuda.is_available()
        except Exception:
            cuda_available = False
        return cuda_available

    def assert_hardware_compatible(self) -> None:
        from torch.utils.cpp_extension import CUDA_HOME

        if not CUDA_HOME:
            raise AssertionError(
                "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"
            )
        check_system_pytorch_cuda_match(CUDA_HOME)
        check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)

    def get_cuda_home_include(self):
        """
        Return the include path inside the CUDA home.
        """
        from torch.utils.cpp_extension import CUDA_HOME

        if CUDA_HOME is None:
            raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
        cuda_include = os.path.join(CUDA_HOME, "include")
        return cuda_include

    def build_jit(self):
        from torch.utils.cpp_extension import CUDA_HOME, load

        set_cuda_arch_list(CUDA_HOME)

        # get build dir
        build_directory = _Extension.get_jit_extension_folder_path()
        build_directory = Path(build_directory)
        build_directory.mkdir(parents=True, exist_ok=True)

        # check if the kernel has been built
        compiled_before = False
        kernel_file_path = build_directory.joinpath(f"{self.name}.o")
        if kernel_file_path.exists():
            compiled_before = True

        # load the kernel
        if compiled_before:
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
        else:
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")

        build_start = time.time()
        op_kernel = load(
            name=self.name,
            sources=self.strip_empty_entries(self.sources_files()),
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
            extra_cflags=self.cxx_flags(),
            extra_cuda_cflags=self.nvcc_flags(),
            extra_ldflags=[],
            build_directory=str(build_directory),
        )
        build_duration = time.time() - build_start

        if compiled_before:
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
        else:
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")

        return op_kernel

    def build_aot(self) -> "CUDAExtension":
        from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension

        set_cuda_arch_list(CUDA_HOME)
        return CUDAExtension(
            name=self.prebuilt_import_path,
            sources=self.strip_empty_entries(self.sources_files()),
            include_dirs=self.strip_empty_entries(self.include_dirs()),
            extra_compile_args={
                "cxx": self.strip_empty_entries(self.cxx_flags()),
                "nvcc": self.strip_empty_entries(self.nvcc_flags()),
            },
        )
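To make the JIT path above concrete, here is a hedged sketch of how a concrete subclass is typically driven; MyKernelCudaExtension and its source file names are hypothetical, and the sketch assumes the csrc_abs_path/strip_empty_entries helpers inherited from _CppExtension, as the real subclasses later in this diff do.

# Hypothetical subclass: fills in the abstract hooks used by build_jit()/build_aot().
class MyKernelCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="my_kernel_cuda")

    def sources_files(self):
        return [self.csrc_abs_path(f) for f in ["cuda/my_kernel.cpp", "cuda/my_kernel_kernel.cu"]]

    def include_dirs(self):
        return [self.get_cuda_home_include()]

    def cxx_flags(self):
        return ["-O3"]

    def nvcc_flags(self):
        return ["-O3", "--use_fast_math"]


ext = MyKernelCudaExtension()
if ext.is_hardware_available():
    ext.assert_hardware_compatible()
    kernel = ext.build_jit()  # compiled once, then reused from the shared cache folder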
20
extensions/flash_attention/__init__.py
Normal file
@@ -0,0 +1,20 @@
from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
from .flash_attention_npu import FlashAttentionNpuExtension
from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension

try:
    import flash_attention  # noqa

    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

try:
    import xformers  # noqa

    HAS_MEM_EFF_ATTN = True
except ImportError:
    HAS_MEM_EFF_ATTN = False


__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"]
93
extensions/flash_attention/flash_attention_dao_cuda.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from ..base_extension import _Extension
|
||||
|
||||
|
||||
class FlashAttentionDaoCudaExtension(_Extension):
|
||||
def __init__(self):
|
||||
super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)
|
||||
|
||||
def is_hardware_available(self) -> bool:
|
||||
# cuda extension can only be built if cuda is available
|
||||
try:
|
||||
import torch
|
||||
|
||||
cuda_available = torch.cuda.is_available()
|
||||
except:
|
||||
cuda_available = False
|
||||
return cuda_available
|
||||
|
||||
def assert_hardware_compatible(self) -> None:
|
||||
pass
|
||||
|
||||
def build_aot(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'."
|
||||
)
|
||||
|
||||
def build_jit(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
|
||||
)
|
||||
|
||||
def load(self):
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
(
|
||||
"We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
|
||||
)
|
||||
)
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
def flash_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seq_len_info_q: "SeqLenInfo",
|
||||
seq_len_info_kv: "SeqLenInfo",
|
||||
origin_attn_mask: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: float = None,
|
||||
causal: bool = False,
|
||||
padded: bool = False,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch, q_seqlen, nheads, headdim)
|
||||
k: (batch, kv_seqlen, nheads, headdim)
|
||||
v: (batch, kv_seqlen, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
dropout_p: float. Dropout probability.
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
Return:
|
||||
attn_out: (batch, q_seqlen, nheads, headdim).
|
||||
"""
|
||||
# check if the input is in allowed dtypes
|
||||
if padded:
|
||||
if seq_len_info_kv is None:
|
||||
seq_len_info_kv = seq_len_info_q
|
||||
|
||||
attn_out = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
seq_len_info_q.cu_seqlens,
|
||||
seq_len_info_kv.cu_seqlens,
|
||||
seq_len_info_q.max_seqlen,
|
||||
seq_len_info_kv.max_seqlen,
|
||||
dropout_p,
|
||||
scale,
|
||||
causal,
|
||||
)
|
||||
else:
|
||||
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
|
||||
return attn_out
|
||||
|
||||
return flash_attention
|
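A hedged usage sketch of the callable returned by load() above (not part of the patch): it exercises only the dense, non-padded path, so the SeqLenInfo arguments can stay None; it assumes flash-attn is installed and fp16 tensors on a CUDA device.

# Usage sketch (assumptions: flash-attn installed, CUDA device, fp16 inputs).
import torch

attn_fn = FlashAttentionDaoCudaExtension().load()

b, s, h, d = 2, 1024, 16, 64
q = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
k = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
v = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")

# seq_len_info_q / seq_len_info_kv are only consulted when padded=True
out = attn_fn(q, k, v, None, None, causal=True)  # (b, s, h, d)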
73
extensions/flash_attention/flash_attention_npu.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from ..base_extension import _Extension
|
||||
|
||||
|
||||
class FlashAttentionNpuExtension(_Extension):
|
||||
def __init__(self):
|
||||
super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)
|
||||
|
||||
def is_hardware_available(self) -> bool:
|
||||
try:
|
||||
import torch_npu # noqa
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def assert_hardware_compatible(self) -> None:
|
||||
pass
|
||||
|
||||
def build_aot(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu."
|
||||
)
|
||||
|
||||
def build_jit(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu."
|
||||
)
|
||||
|
||||
def load(self):
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
def npu_sdpa_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seq_len_info_q=None,
|
||||
seq_len_info_kv=None,
|
||||
origin_attn_mask: torch.Tensor = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: float = 1.0,
|
||||
causal=None,
|
||||
padded=None,
|
||||
):
|
||||
"""
|
||||
The scaled dot product attention.
|
||||
|
||||
Arguments:
|
||||
q: (batch, q_seqlen, nheads, headdim)
|
||||
k: (batch, kv_seqlen, nheads, headdim)
|
||||
v: (batch, kv_seqlen, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
dropout_p: float. Dropout probability.
|
||||
scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1.
|
||||
Return:
|
||||
attn_out: (batch, q_seqlen, nheads, headdim).
|
||||
"""
|
||||
q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
|
||||
output = torch.nn.functional.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=origin_attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=origin_attn_mask is None,
|
||||
scale=scale,
|
||||
)
|
||||
output = rearrange(output, "b h s d -> b s (h d)")
|
||||
return output
|
||||
|
||||
return npu_sdpa_attention
|
94
extensions/flash_attention/flash_attention_xformers_cuda.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from ..base_extension import _Extension
|
||||
|
||||
|
||||
class FlashAttentionXformersCudaExtension(_Extension):
|
||||
def __init__(self):
|
||||
super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)
|
||||
|
||||
def is_hardware_available(self) -> bool:
|
||||
# cuda extension can only be built if cuda is available
|
||||
try:
|
||||
import torch
|
||||
|
||||
cuda_available = torch.cuda.is_available()
|
||||
except:
|
||||
cuda_available = False
|
||||
return cuda_available
|
||||
|
||||
def assert_hardware_compatible(self) -> None:
|
||||
pass
|
||||
|
||||
def build_aot(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
|
||||
)
|
||||
|
||||
def build_jit(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
|
||||
)
|
||||
|
||||
def load(self):
|
||||
try:
|
||||
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
|
||||
from xformers.ops.fmha.attn_bias import (
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
)
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
(
|
||||
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
|
||||
)
|
||||
)
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
allow_alibi = True
|
||||
for op in MemoryEfficientAttentionCutlassOp:
|
||||
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
|
||||
|
||||
def mem_eff_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seq_len_info_q: "SeqLenInfo",
|
||||
seq_len_info_kv: "SeqLenInfo",
|
||||
origin_attn_mask: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: float = None,
|
||||
causal: bool = False,
|
||||
padded: bool = False,
|
||||
):
|
||||
attn_bias = None
|
||||
if padded: # bert style
|
||||
if not causal:
|
||||
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
|
||||
else:
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
|
||||
elif causal: # gpt style
|
||||
attn_bias = LowerTriangularMask()
|
||||
|
||||
if bias is not None: # alibi / relative position embedding
|
||||
assert allow_alibi, "flash attention with bias is not supported in this system."
|
||||
assert causal, "attention with bias is only supported for causal attention so far."
|
||||
attn_bias = attn_bias.add_bias(bias)
|
||||
|
||||
if padded:
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
|
||||
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
|
||||
|
||||
# shape: (b*s, n, d)
|
||||
if padded:
|
||||
out = out.squeeze(0)
|
||||
|
||||
return out
|
||||
|
||||
return mem_eff_attention
|
3
extensions/layernorm/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .layernorm_cuda import LayerNormCudaExtension
|
||||
|
||||
__all__ = ["LayerNormCudaExtension"]
|
24
extensions/layernorm/layernorm_cuda.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from ..cuda_extension import _CudaExtension
|
||||
from ..utils import append_nvcc_threads, get_cuda_cc_flag
|
||||
|
||||
|
||||
class LayerNormCudaExtension(_CudaExtension):
|
||||
def __init__(self):
|
||||
super().__init__(name="layernorm_cuda")
|
||||
|
||||
def sources_files(self):
|
||||
ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]]
|
||||
return ret
|
||||
|
||||
def include_dirs(self):
|
||||
ret = [self.get_cuda_home_include()]
|
||||
return ret
|
||||
|
||||
def cxx_flags(self):
|
||||
return ["-O3"] + self.version_dependent_macros
|
||||
|
||||
def nvcc_flags(self):
|
||||
extra_cuda_flags = ["-maxrregcount=50"]
|
||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros
|
||||
return append_nvcc_threads(ret)
|
3
extensions/moe/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .moe_cuda import MoeCudaExtension
|
||||
|
||||
__all__ = ['MoeCudaExtension']
|
29
extensions/moe/moe_cuda.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from ..cuda_extension import _CudaExtension
|
||||
from ..utils import append_nvcc_threads, get_cuda_cc_flag
|
||||
|
||||
|
||||
class MoeCudaExtension(_CudaExtension):
|
||||
def __init__(self):
|
||||
super().__init__(name="moe_cuda")
|
||||
|
||||
def include_dirs(self):
|
||||
ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
|
||||
return ret
|
||||
|
||||
def sources_files(self):
|
||||
ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]]
|
||||
return ret
|
||||
|
||||
def cxx_flags(self):
|
||||
return ["-O3"] + self.version_dependent_macros
|
||||
|
||||
def nvcc_flags(self):
|
||||
extra_cuda_flags = [
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
]
|
||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags
|
||||
return append_nvcc_threads(ret)
|
3
extensions/optimizer/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .fused_optimizer_cuda import FusedOptimizerCudaExtension
|
||||
|
||||
__all__ = ['FusedOptimizerCudaExtension']
|
34
extensions/optimizer/fused_optimizer_cuda.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from ..cuda_extension import _CudaExtension
|
||||
from ..utils import get_cuda_cc_flag
|
||||
|
||||
|
||||
class FusedOptimizerCudaExtension(_CudaExtension):
|
||||
def __init__(self):
|
||||
super().__init__(name="fused_optim_cuda")
|
||||
|
||||
def sources_files(self):
|
||||
ret = [
|
||||
self.csrc_abs_path(fname)
|
||||
for fname in [
|
||||
"cuda/colossal_C_frontend.cpp",
|
||||
"cuda/multi_tensor_sgd_kernel.cu",
|
||||
"cuda/multi_tensor_scale_kernel.cu",
|
||||
"cuda/multi_tensor_adam.cu",
|
||||
"cuda/multi_tensor_l2norm_kernel.cu",
|
||||
"cuda/multi_tensor_lamb.cu",
|
||||
]
|
||||
]
|
||||
return ret
|
||||
|
||||
def include_dirs(self):
|
||||
ret = [self.get_cuda_home_include()]
|
||||
return ret
|
||||
|
||||
def cxx_flags(self):
|
||||
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
||||
return ["-O3"] + version_dependent_macros
|
||||
|
||||
def nvcc_flags(self):
|
||||
extra_cuda_flags = ["-lineinfo"]
|
||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||
return ["-O3", "--use_fast_math"] + extra_cuda_flags
|
4
extensions/softmax/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension
|
||||
from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension
|
||||
|
||||
__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension']
|
32
extensions/softmax/scaled_masked_softmax_cuda.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from ..cuda_extension import _CudaExtension
|
||||
from ..utils import append_nvcc_threads
|
||||
|
||||
|
||||
class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
|
||||
def __init__(self):
|
||||
super().__init__(name="scaled_masked_softmax_cuda")
|
||||
|
||||
def sources_files(self):
|
||||
ret = [
|
||||
self.csrc_abs_path(fname)
|
||||
for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"]
|
||||
]
|
||||
return ret
|
||||
|
||||
def include_dirs(self):
|
||||
return [self.get_cuda_home_include()]
|
||||
|
||||
def cxx_flags(self):
|
||||
return ["-O3"] + self.version_dependent_macros
|
||||
|
||||
def nvcc_flags(self):
|
||||
extra_cuda_flags = [
|
||||
"-std=c++14",
|
||||
"-std=c++17",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
|
||||
]
|
||||
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
|
||||
return append_nvcc_threads(ret)
|
@@ -0,0 +1,34 @@
|
||||
from ..cuda_extension import _CudaExtension
|
||||
from ..utils import append_nvcc_threads, get_cuda_cc_flag
|
||||
|
||||
|
||||
class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
|
||||
def __init__(self):
|
||||
super().__init__(name="scaled_upper_triangle_masked_softmax_cuda")
|
||||
|
||||
def include_dirs(self):
|
||||
return [self.get_cuda_home_include()]
|
||||
|
||||
def sources_files(self):
|
||||
ret = [
|
||||
self.csrc_abs_path(fname)
|
||||
for fname in [
|
||||
"cuda/scaled_upper_triang_masked_softmax.cpp",
|
||||
"cuda/scaled_upper_triang_masked_softmax_cuda.cu",
|
||||
]
|
||||
]
|
||||
return ret
|
||||
|
||||
def cxx_flags(self):
|
||||
return ["-O3"] + self.version_dependent_macros
|
||||
|
||||
def nvcc_flags(self):
|
||||
extra_cuda_flags = [
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
]
|
||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags
|
||||
return append_nvcc_threads(ret)
|
21
extensions/triton_extension.py
Normal file
@@ -0,0 +1,21 @@
from .base_extension import _Extension

__all__ = ["_TritonExtension"]


class _TritonExtension(_Extension):
    def __init__(self, name: str, priority: int = 1):
        super().__init__(name, support_aot=False, support_jit=True, priority=priority)

    def is_hardware_compatible(self) -> bool:
        # a Triton extension can only be used when CUDA is available
        try:
            import torch

            cuda_available = torch.cuda.is_available()
        except Exception:
            cuda_available = False
        return cuda_available

    def load(self):
        return self.build_jit()
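A hedged sketch of a concrete Triton-backed extension (not part of the patch); MyTritonExtension and the my_kernels module are hypothetical placeholders for whatever jitted function a real subclass would expose from build_jit().

class MyTritonExtension(_TritonExtension):
    def __init__(self):
        super().__init__(name="my_triton_kernel", priority=5)

    def build_jit(self):
        # Triton kernels are compiled lazily at first launch, so "building"
        # here only means importing and returning the jitted callable.
        from my_kernels import my_triton_kernel  # hypothetical module

        return my_triton_kernel


kernel = MyTritonExtension().load()  # load() simply delegates to build_jit()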
229
extensions/utils.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
|
||||
def print_rank_0(message: str) -> None:
|
||||
"""
|
||||
Print on only one process to avoid spamming.
|
||||
"""
|
||||
try:
|
||||
import torch.distributed as dist
|
||||
|
||||
if not dist.is_initialized():
|
||||
is_main_rank = True
|
||||
else:
|
||||
is_main_rank = dist.get_rank() == 0
|
||||
except ImportError:
|
||||
is_main_rank = True
|
||||
|
||||
if is_main_rank:
|
||||
print(message)
|
||||
|
||||
|
||||
def get_cuda_version_in_pytorch() -> List[int]:
|
||||
"""
|
||||
This function returns the CUDA version in the PyTorch build.
|
||||
|
||||
Returns:
|
||||
The CUDA version required by PyTorch, in the form of tuple (major, minor).
|
||||
"""
|
||||
import torch
|
||||
|
||||
try:
|
||||
torch_cuda_major = torch.version.cuda.split(".")[0]
|
||||
torch_cuda_minor = torch.version.cuda.split(".")[1]
|
||||
except:
|
||||
raise ValueError(
|
||||
"[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda"
|
||||
)
|
||||
return torch_cuda_major, torch_cuda_minor
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
|
||||
"""
|
||||
Get the System CUDA version from nvcc.
|
||||
|
||||
Args:
|
||||
cuda_dir (str): the directory for CUDA Toolkit.
|
||||
|
||||
Returns:
|
||||
The CUDA version required by PyTorch, in the form of tuple (major, minor).
|
||||
"""
|
||||
nvcc_path = os.path.join(cuda_dir, "bin/nvcc")
|
||||
|
||||
if cuda_dir is None:
|
||||
raise ValueError(
|
||||
f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly."
|
||||
)
|
||||
|
||||
# check for nvcc path
|
||||
if not os.path.exists(nvcc_path):
|
||||
raise FileNotFoundError(
|
||||
f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
|
||||
)
|
||||
|
||||
# parse the nvcc -v output to obtain the system cuda version
|
||||
try:
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
bare_metal_major = release[0]
|
||||
bare_metal_minor = release[1][0]
|
||||
except:
|
||||
raise ValueError(
|
||||
f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}"
|
||||
)
|
||||
|
||||
return bare_metal_major, bare_metal_minor
|
||||
|
||||
|
||||
def check_system_pytorch_cuda_match(cuda_dir):
|
||||
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
|
||||
torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()
|
||||
|
||||
if bare_metal_major != torch_cuda_major:
|
||||
raise Exception(
|
||||
f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
|
||||
f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})."
|
||||
"Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
|
||||
)
|
||||
|
||||
if bare_metal_minor != torch_cuda_minor:
|
||||
warnings.warn(
|
||||
f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
|
||||
"The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
|
||||
"If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def get_pytorch_version() -> List[int]:
|
||||
"""
|
||||
This function finds the PyTorch version.
|
||||
|
||||
Returns:
|
||||
A tuple of integers in the form of (major, minor, patch).
|
||||
"""
|
||||
import torch
|
||||
|
||||
torch_version = torch.__version__.split("+")[0]
|
||||
TORCH_MAJOR = int(torch_version.split(".")[0])
|
||||
TORCH_MINOR = int(torch_version.split(".")[1])
|
||||
TORCH_PATCH = int(torch_version.split(".")[2], 16)
|
||||
return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH
|
||||
|
||||
|
||||
def check_pytorch_version(min_major_version, min_minor_version) -> bool:
|
||||
"""
|
||||
Compare the current PyTorch version with the minimum required version.
|
||||
|
||||
Args:
|
||||
min_major_version (int): the minimum major version of PyTorch required
|
||||
min_minor_version (int): the minimum minor version of PyTorch required
|
||||
|
||||
Returns:
|
||||
A boolean value. The value is True if the current pytorch version is acceptable and False otherwise.
|
||||
"""
|
||||
# get pytorch version
|
||||
torch_major, torch_minor, _ = get_pytorch_version()
|
||||
|
||||
# raise an error if the installed PyTorch is older than the required minimum version
|
||||
if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
|
||||
raise RuntimeError(
|
||||
f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
|
||||
"The latest stable release can be obtained from https://pytorch.org/get-started/locally/"
|
||||
)
|
||||
|
||||
|
||||
def check_cuda_availability():
|
||||
"""
|
||||
Check if CUDA is available on the system.
|
||||
|
||||
Returns:
|
||||
A boolean value. True if CUDA is available and False otherwise.
|
||||
"""
|
||||
import torch
|
||||
|
||||
return torch.cuda.is_available()
|
||||
|
||||
|
||||
def set_cuda_arch_list(cuda_dir):
|
||||
"""
|
||||
This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
|
||||
Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'.
|
||||
"""
|
||||
cuda_available = check_cuda_availability()
|
||||
|
||||
# we only need to set this when CUDA is not available for cross-compilation
|
||||
if not cuda_available:
|
||||
warnings.warn(
|
||||
"\n[extension] PyTorch did not find available GPUs on this system.\n"
|
||||
"If your intention is to cross-compile, this is not an error.\n"
|
||||
"By default, Colossal-AI will cross-compile for \n"
|
||||
"1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
|
||||
"2. Volta (compute capability 7.0)\n"
|
||||
"3. Turing (compute capability 7.5),\n"
|
||||
"4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n"
|
||||
"\nIf you wish to cross-compile for a single specific architecture,\n"
|
||||
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
|
||||
)
|
||||
|
||||
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
|
||||
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
|
||||
|
||||
arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]
|
||||
|
||||
if int(bare_metal_major) == 11:
|
||||
if int(bare_metal_minor) == 0:
|
||||
arch_list.append("8.0")
|
||||
else:
|
||||
arch_list.append("8.0")
|
||||
arch_list.append("8.6")
|
||||
|
||||
arch_list_str = ";".join(arch_list)
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_cuda_cc_flag() -> List[str]:
|
||||
"""
|
||||
This function produces the cc flags for your GPU arch
|
||||
|
||||
Returns:
|
||||
The CUDA cc flags for compilation.
|
||||
"""
|
||||
|
||||
# only import torch when needed
|
||||
# this is to avoid importing torch when building on a machine without torch pre-installed
|
||||
# one case is to build wheel for pypi release
|
||||
import torch
|
||||
|
||||
cc_flag = []
|
||||
max_arch = "".join(str(i) for i in torch.cuda.get_device_capability())
|
||||
for arch in torch.cuda.get_arch_list():
|
||||
res = re.search(r"sm_(\d+)", arch)
|
||||
if res:
|
||||
arch_cap = res[1]
|
||||
if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch):
|
||||
cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"])
|
||||
return cc_flag
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
|
||||
"""
|
||||
This function appends the threads flag to your nvcc args.
|
||||
|
||||
Returns:
|
||||
The nvcc compilation flags including the threads flag.
|
||||
"""
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
|
||||
return nvcc_extra_args + ["--threads", "4"]
|
||||
return nvcc_extra_args
|
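To tie the flag helpers above together, here is a hedged sketch (not part of the patch) of how an extension's nvcc flag list is typically assembled; the baseline flags are illustrative only.

def example_nvcc_flags():
    flags = ["-O3", "--use_fast_math"]
    flags.extend(get_cuda_cc_flag())   # -gencode entries for the detected GPU archs
    return append_nvcc_threads(flags)  # appends "--threads 4" when nvcc is new enough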