mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-04 06:28:05 +00:00

[builder] use builder() for cpu adam and fused optim in setup.py (#2187)

This commit is contained in:
parent d42afd30f8
commit bc0e271e71

colossalai/kernel/op_builder/cpu_adam.py

@@ -1,8 +1,7 @@
 import os
-import sys
-from pathlib import Path
 
 from .builder import Builder
+from .utils import append_nvcc_threads
 
 
 class CPUAdamBuilder(Builder):

@@ -28,37 +27,35 @@ class CPUAdamBuilder(Builder):
         ]
 
     def include_paths(self):
-        import torch
         from torch.utils.cpp_extension import CUDA_HOME
         cuda_include = os.path.join(CUDA_HOME, "include")
         return [os.path.join(CPUAdamBuilder.BASE_DIR, "includes"), cuda_include]
 
-    def colossalai_src_path(self, code_path):
-        if os.path.isabs(code_path):
-            return code_path
-        else:
-            return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
-
     def strip_empty_entries(self, args):
         '''
         Drop any empty strings from the list of compile and link flags
         '''
         return [x for x in args if len(x) > 0]
 
-    def builder(self):
+    def builder(self, name) -> 'CUDAExtension':
+        """
+        get a CUDAExtension instance used for setup.py
+        """
         from torch.utils.cpp_extension import CUDAExtension
 
         return CUDAExtension(
-            name=self.name,
+            name=name,
             sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
             include_dirs=self.extra_include_paths,
             extra_compile_args={
-                'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags,
-                'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags
+                'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cuda_flags,
+                'nvcc':
+                    append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros +
+                                        self.extra_cuda_flags)
             })
 
     def load(self, verbose=True):
         """
 
         load and compile cpu_adam lib at runtime
 
         Args:
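
Aside: after this change, a single builder object serves both build modes. A minimal sketch, using only names that appear in this diff:

    from colossalai.kernel.op_builder import CPUAdamBuilder

    # ahead-of-time: produce a CUDAExtension for setup.py's ext_modules
    ext = CPUAdamBuilder().builder('colossalai._C.cpu_optim')

    # just-in-time: compile and load the same kernel at first use
    lib = CPUAdamBuilder().load()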

colossalai/kernel/op_builder/fused_optim.py

@@ -7,7 +7,7 @@ from .builder import Builder
 
 
 class FusedOptimBuilder(Builder):
-    NAME = "fused_optim"
+    NAME = 'fused_optim'
     BASE_DIR = "cuda_native/csrc"
 
     def __init__(self):

@@ -41,10 +41,10 @@ class FusedOptimBuilder(Builder):
         cuda_include = os.path.join(CUDA_HOME, "include")
         return [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), cuda_include]
 
-    def builder(self):
+    def builder(self, name):
         from torch.utils.cpp_extension import CUDAExtension
         return CUDAExtension(
-            name=self.name,
+            name=name,
             sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
             include_dirs=self.extra_include_paths,
             extra_compile_args={

colossalai/kernel/op_builder/utils.py (new file, 20 lines)

@@ -0,0 +1,20 @@
+import subprocess
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    from torch.utils.cpp_extension import CUDA_HOME
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
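
Aside: these helpers are the nvcc version probe previously inlined in setup.py, lifted out so the op builders can share it. Illustrative behavior, assuming nvcc -V reports release 11.3:

    from torch.utils.cpp_extension import CUDA_HOME
    from colossalai.kernel.op_builder.utils import append_nvcc_threads, get_cuda_bare_metal_version

    raw, major, minor = get_cuda_bare_metal_version(CUDA_HOME)    # major == '11', minor == '3'
    flags = append_nvcc_threads(['-O3', '--use_fast_math'])
    # -> ['-O3', '--use_fast_math', '--threads', '4'] on nvcc >= 11.2;
    #    older toolkits get the list back unchanged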

@@ -7,3 +7,4 @@ rich
 click
 fabric
 contexttimer
+ninja
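
Aside: ninja becomes a hard requirement presumably because the runtime load() path goes through torch.utils.cpp_extension's JIT build, which uses Ninja to drive compilation.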

setup.py (32 changed lines)

@@ -1,9 +1,10 @@
 import os
 import re
-import subprocess
 
 from setuptools import Extension, find_packages, setup
 
+from colossalai.kernel.op_builder.utils import get_cuda_bare_metal_version
+
 try:
     import torch
     from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension

@@ -26,17 +27,6 @@ if int(os.environ.get('NO_CUDA_EXT', '0')) == 1:
     build_cuda_ext = False
 
 
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
     torch_binary_major = torch.version.cuda.split(".")[0]

@@ -146,6 +136,11 @@ if build_cuda_ext:
             'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)
         })
 
+    #### fused optim kernels ###
+    from colossalai.kernel.op_builder import FusedOptimBuilder
+    ext_modules.append(FusedOptimBuilder().builder('colossalai._C.fused_optim'))
+
+    #### N-D parallel kernels ###
     cc_flag = []
     for arch in torch.cuda.get_arch_list():
         res = re.search(r'sm_(\d+)', arch)

@@ -154,14 +149,6 @@ if build_cuda_ext:
         if int(arch_cap) >= 60:
             cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
 
-    extra_cuda_flags = ['-lineinfo']
-
-    ext_modules.append(
-        cuda_ext_helper('colossalai._C.fused_optim', [
-            'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
-            'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
-        ], extra_cuda_flags + cc_flag))
-
     extra_cuda_flags = [
         '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
         '--expt-extended-lambda'

@@ -197,8 +184,9 @@ if build_cuda_ext:
             'kernels/general_kernels.cu', 'kernels/cuda_util.cu'
         ], extra_cuda_flags + cc_flag))
 
-    extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
-    ext_modules.append(cuda_ext_helper('colossalai._C.cpu_optim', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags))
+    ### Gemini Adam kernel ####
+    from colossalai.kernel.op_builder import CPUAdamBuilder
+    ext_modules.append(CPUAdamBuilder().builder('colossalai._C.cpu_optim'))
 
 setup(name='colossalai',
       version=get_version(),
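
Aside: the net effect on setup.py is that the two hand-rolled cuda_ext_helper registrations collapse into one-line builder calls. A condensed sketch of the resulting wiring (the cmdclass line is the standard torch.utils.cpp_extension convention; the other names are taken from this diff):

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension

    from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder

    ext_modules = []
    if build_cuda_ext:    # gated by the NO_CUDA_EXT check above
        ext_modules.append(FusedOptimBuilder().builder('colossalai._C.fused_optim'))
        ext_modules.append(CPUAdamBuilder().builder('colossalai._C.cpu_optim'))

    setup(name='colossalai',
          version=get_version(),
          ext_modules=ext_modules,
          cmdclass={'build_ext': BuildExtension})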

@@ -67,15 +67,14 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
     exp_avg_sq_copy = exp_avg_sq.clone()
 
     try:
-        import colossalai._C.cpu_optim
-        cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
-        print("use prebuilt CPUAdamOptimizer")
+        from colossalai._C import cpu_optim
     except:
-        from colossalai.kernel.op_builder.cpu_adam import CPUAdamBuilder
-        lib = CPUAdamBuilder().load()
-        cpu_adam_op = lib.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
+        from colossalai.kernel.op_builder import CPUAdamBuilder
+        cpu_optim = CPUAdamBuilder().load()
         print("build CPUAdamOptimizer at runtime")
 
+    cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
+
     cpu_adam_op.step(
         step,
         lr,
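
Aside: either branch of the try/except now leaves a module named cpu_optim in scope, so the optimizer is constructed once, after the fallback. The pattern in isolation (a sketch: the diff itself uses a bare except, and lr, beta1, beta2, eps, weight_decay, adamw are the test's hyperparameters):

    try:
        from colossalai._C import cpu_optim              # prebuilt by setup.py
    except ImportError:
        from colossalai.kernel.op_builder import CPUAdamBuilder
        cpu_optim = CPUAdamBuilder().load()              # compiled at first use

    cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)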