mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 15:32:22 +00:00
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047) * [npu] add npu device support * [npu] support low level zero * [test] update npu zero plugin test * [hotfix] fix import * [test] recover tests * [npu] gemini support npu (#5052) * [npu] refactor device utils * [gemini] support npu * [example] llama2+gemini support npu * [kernel] add arm cpu adam kernel (#5065) * [kernel] add arm cpu adam * [optim] update adam optimizer * [kernel] arm cpu adam remove bf16 support
This commit is contained in:
parent
8d56c9c389
commit
e5ce4c8ea6
@ -8,6 +8,7 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
__all__ = ["BaseGradScaler"]
|
__all__ = ["BaseGradScaler"]
|
||||||
|
|
||||||
@ -22,7 +23,7 @@ class BaseGradScaler(ABC):
|
|||||||
|
|
||||||
def __init__(self, initial_scale: float, verbose: bool):
|
def __init__(self, initial_scale: float, verbose: bool):
|
||||||
assert initial_scale > 0
|
assert initial_scale > 0
|
||||||
self._scale = torch.cuda.FloatTensor([initial_scale])
|
self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
|
||||||
self._verbose = verbose
|
self._verbose = verbose
|
||||||
|
|
||||||
if self._verbose:
|
if self._verbose:
|
||||||
|
@ -5,6 +5,8 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from .base_grad_scaler import BaseGradScaler
|
from .base_grad_scaler import BaseGradScaler
|
||||||
|
|
||||||
__all__ = ["DynamicGradScaler"]
|
__all__ = ["DynamicGradScaler"]
|
||||||
@ -37,12 +39,12 @@ class DynamicGradScaler(BaseGradScaler):
|
|||||||
):
|
):
|
||||||
super().__init__(initial_scale, verbose)
|
super().__init__(initial_scale, verbose)
|
||||||
if min_scale:
|
if min_scale:
|
||||||
self._min_scale = torch.cuda.FloatTensor([min_scale])
|
self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
|
||||||
else:
|
else:
|
||||||
self._min_scale = None
|
self._min_scale = None
|
||||||
|
|
||||||
if max_scale:
|
if max_scale:
|
||||||
self._max_scale = torch.cuda.FloatTensor([max_scale])
|
self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
|
||||||
else:
|
else:
|
||||||
self._max_scale = None
|
self._max_scale = None
|
||||||
|
|
||||||
@ -115,7 +117,7 @@ class DynamicGradScaler(BaseGradScaler):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
|
self._scale = state_dict["scale"].to(get_current_device())
|
||||||
self._growth_factor = state_dict["growth_factor"]
|
self._growth_factor = state_dict["growth_factor"]
|
||||||
self._backoff_factor = state_dict["backoff_factor"]
|
self._backoff_factor = state_dict["backoff_factor"]
|
||||||
self._hysteresis = state_dict["hysteresis"]
|
self._hysteresis = state_dict["hysteresis"]
|
||||||
|
@ -11,7 +11,7 @@ except:
|
|||||||
import torch
|
import torch
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from .region import Region
|
from .region import Region
|
||||||
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
|
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
|
||||||
|
@ -25,6 +25,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
|
|||||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
from colossalai.utils.device import IS_NPU_AVAILABLE
|
||||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||||
from colossalai.zero.gemini.memory_tracer import MemStats
|
from colossalai.zero.gemini.memory_tracer import MemStats
|
||||||
|
|
||||||
@ -37,6 +38,7 @@ PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
|
|||||||
|
|
||||||
ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
|
ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
|
||||||
|
|
||||||
|
|
||||||
def get_param_info(optim: Optimizer):
|
def get_param_info(optim: Optimizer):
|
||||||
# Get a backup of necessary information of parameters for future use, which includes:
|
# Get a backup of necessary information of parameters for future use, which includes:
|
||||||
# 1. A mapping from integer param_id to param32 shape.
|
# 1. A mapping from integer param_id to param32 shape.
|
||||||
@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer):
|
|||||||
start_index += len(group["params"])
|
start_index += len(group["params"])
|
||||||
|
|
||||||
return param_info
|
return param_info
|
||||||
|
|
||||||
|
|
||||||
class GeminiCheckpointIO(GeneralCheckpointIO):
|
class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -359,6 +363,8 @@ class GeminiPlugin(DPPluginBase):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
|
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
assert placement_policy == "static", "NPU only supports static placement policy"
|
||||||
self.gemini_config = dict(
|
self.gemini_config = dict(
|
||||||
chunk_config_dict=chunk_config_dict,
|
chunk_config_dict=chunk_config_dict,
|
||||||
chunk_init_device=(chunk_init_device or get_current_device()),
|
chunk_init_device=(chunk_init_device or get_current_device()),
|
||||||
@ -437,7 +443,7 @@ class GeminiPlugin(DPPluginBase):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def supported_devices(self) -> List[str]:
|
def supported_devices(self) -> List[str]:
|
||||||
return ["cuda"]
|
return ["cuda", "npu"]
|
||||||
|
|
||||||
def configure(
|
def configure(
|
||||||
self,
|
self,
|
||||||
|
@ -306,7 +306,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def supported_devices(self) -> List[str]:
|
def supported_devices(self) -> List[str]:
|
||||||
return ["cuda"]
|
return ["cuda", "npu"]
|
||||||
|
|
||||||
def configure(
|
def configure(
|
||||||
self,
|
self,
|
||||||
|
@ -11,7 +11,7 @@ import torch.distributed as dist
|
|||||||
|
|
||||||
from colossalai.context import Config
|
from colossalai.context import Config
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.utils import set_device, set_seed
|
from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
|
||||||
|
|
||||||
|
|
||||||
def launch(
|
def launch(
|
||||||
@ -47,12 +47,15 @@ def launch(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
warnings.warn("`config` is deprecated and will be removed soon.")
|
warnings.warn("`config` is deprecated and will be removed soon.")
|
||||||
|
|
||||||
|
if IS_NPU_AVAILABLE and backend == "nccl":
|
||||||
|
backend = "hccl"
|
||||||
|
|
||||||
# init default process group
|
# init default process group
|
||||||
init_method = f"tcp://[{host}]:{port}"
|
init_method = f"tcp://[{host}]:{port}"
|
||||||
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
|
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
|
||||||
|
|
||||||
# set cuda device
|
# set cuda device
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available() or IS_NPU_AVAILABLE:
|
||||||
# if local rank is not given, calculate automatically
|
# if local rank is not given, calculate automatically
|
||||||
set_device(local_rank)
|
set_device(local_rank)
|
||||||
|
|
||||||
|
@ -142,6 +142,7 @@ class Adam_Optimizer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
|
||||||
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
|
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
|
||||||
AVX_Data &data) {
|
AVX_Data &data) {
|
||||||
if (is_half) {
|
if (is_half) {
|
||||||
@ -159,6 +160,7 @@ class Adam_Optimizer {
|
|||||||
SIMD_STORE(ptr, data.data);
|
SIMD_STORE(ptr, data.data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
|
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
|
||||||
float weight_decay, bool bias_correction, torch::Tensor ¶ms,
|
float weight_decay, bool bias_correction, torch::Tensor ¶ms,
|
||||||
|
304
colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp
Normal file
304
colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
#include "cpu_adam_arm.h"
|
||||||
|
|
||||||
|
void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
|
||||||
|
void *_exp_avg_sq, size_t _param_size,
|
||||||
|
at::ScalarType param_dtype,
|
||||||
|
at::ScalarType grad_dtype,
|
||||||
|
at::ScalarType exp_avg_dtype,
|
||||||
|
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
|
||||||
|
size_t rounded_size = 0;
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
float betta1_minus1 = 1 - _betta1;
|
||||||
|
float betta2_minus1 = 1 - _betta2;
|
||||||
|
float step_size = -1 * _alpha / _bias_correction1;
|
||||||
|
float w_decay = -1 * _alpha * _weight_decay;
|
||||||
|
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
float32x4_t betta1_4 = simd_set(_betta1);
|
||||||
|
float32x4_t betta2_4 = simd_set(_betta2);
|
||||||
|
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
|
||||||
|
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
|
||||||
|
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
|
||||||
|
float32x4_t eps_4 = simd_set(_eps);
|
||||||
|
float32x4_t step_size_4 = simd_set(step_size);
|
||||||
|
float32x4_t weight_decay_4;
|
||||||
|
if (_weight_decay > 0) {
|
||||||
|
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
|
||||||
|
}
|
||||||
|
for (size_t t = 0; t < rounded_size; t += TILE) {
|
||||||
|
size_t copy_size = TILE;
|
||||||
|
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
|
||||||
|
size_t offset = copy_size + t;
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
|
||||||
|
float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
|
||||||
|
if (loss_scale > 0) {
|
||||||
|
float32x4_t loss_scale_vec = simd_set(loss_scale);
|
||||||
|
grad_4 = vdivq_f32(grad_4, loss_scale_vec);
|
||||||
|
}
|
||||||
|
float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
|
||||||
|
float32x4_t variance_4 =
|
||||||
|
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
|
||||||
|
float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
|
||||||
|
if (_weight_decay > 0 && !_adamw_mode) {
|
||||||
|
grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
|
||||||
|
}
|
||||||
|
momentum_4 = vmulq_f32(momentum_4, betta1_4);
|
||||||
|
momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
|
||||||
|
variance_4 = vmulq_f32(variance_4, betta2_4);
|
||||||
|
grad_4 = vmulq_f32(grad_4, grad_4);
|
||||||
|
variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
|
||||||
|
grad_4 = vsqrtq_f32(variance_4);
|
||||||
|
grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
|
||||||
|
grad_4 = vdivq_f32(momentum_4, grad_4);
|
||||||
|
if (_weight_decay > 0 && _adamw_mode) {
|
||||||
|
param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
|
||||||
|
}
|
||||||
|
param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
|
||||||
|
simd_store_offset(_params, param_dtype, param_4, i);
|
||||||
|
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
|
||||||
|
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (_param_size > rounded_size) {
|
||||||
|
for (size_t t = rounded_size; t < _param_size; t += TILE) {
|
||||||
|
size_t copy_size = TILE;
|
||||||
|
if ((t + TILE) > _param_size) copy_size = _param_size - t;
|
||||||
|
size_t offset = copy_size + t;
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (size_t k = t; k < offset; k++) {
|
||||||
|
float grad = scalar_load_offset(grads, grad_dtype, k);
|
||||||
|
if (loss_scale > 0) {
|
||||||
|
grad /= loss_scale;
|
||||||
|
}
|
||||||
|
float param = scalar_load_offset(_params, param_dtype, k);
|
||||||
|
float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
|
||||||
|
float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
|
||||||
|
if (_weight_decay > 0 && !_adamw_mode) {
|
||||||
|
grad = param * _weight_decay + grad;
|
||||||
|
}
|
||||||
|
momentum = momentum * _betta1;
|
||||||
|
momentum = grad * betta1_minus1 + momentum;
|
||||||
|
|
||||||
|
variance = variance * _betta2;
|
||||||
|
grad = grad * grad;
|
||||||
|
variance = grad * betta2_minus1 + variance;
|
||||||
|
|
||||||
|
grad = sqrt(variance);
|
||||||
|
grad = grad * _bias_correction2 + _eps;
|
||||||
|
grad = momentum / grad;
|
||||||
|
if (_weight_decay > 0 && _adamw_mode) {
|
||||||
|
param += w_decay * param;
|
||||||
|
}
|
||||||
|
param = grad * step_size + param;
|
||||||
|
|
||||||
|
scalar_store_offset(_params, param_dtype, param, k);
|
||||||
|
scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
|
||||||
|
scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
|
||||||
|
void *_exp_avg_sq, size_t _param_size,
|
||||||
|
at::ScalarType param_dtype,
|
||||||
|
at::ScalarType grad_dtype,
|
||||||
|
at::ScalarType exp_avg_dtype,
|
||||||
|
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
|
||||||
|
size_t rounded_size = 0;
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
float betta1_minus1 = 1 - _betta1;
|
||||||
|
float betta2_minus1 = 1 - _betta2;
|
||||||
|
float step_size = -1 * _alpha / _bias_correction1;
|
||||||
|
float w_decay = -1 * _alpha * _weight_decay;
|
||||||
|
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
float32x4_t betta1_4 = simd_set(_betta1);
|
||||||
|
float32x4_t betta2_4 = simd_set(_betta2);
|
||||||
|
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
|
||||||
|
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
|
||||||
|
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
|
||||||
|
float32x4_t eps_4 = simd_set(_eps);
|
||||||
|
float32x4_t step_size_4 = simd_set(step_size);
|
||||||
|
float32x4_t weight_decay_4;
|
||||||
|
if (_weight_decay > 0) {
|
||||||
|
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t t = 0; t < rounded_size; t += TILE) {
|
||||||
|
size_t copy_size = TILE;
|
||||||
|
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
|
||||||
|
size_t offset = copy_size + t;
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
|
||||||
|
float32x4_t grad_4[4];
|
||||||
|
float32x4_t momentum_4[4];
|
||||||
|
float32x4_t variance_4[4];
|
||||||
|
float32x4_t param_4[4];
|
||||||
|
#pragma unroll 4
|
||||||
|
for (int j = 0; j < 4; j++) {
|
||||||
|
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
|
||||||
|
if (loss_scale > 0) {
|
||||||
|
float32x4_t loss_scale_vec = simd_set(loss_scale);
|
||||||
|
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
|
||||||
|
}
|
||||||
|
momentum_4[j] =
|
||||||
|
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
|
||||||
|
variance_4[j] =
|
||||||
|
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
|
||||||
|
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
|
||||||
|
if (_weight_decay > 0 && !_adamw_mode) {
|
||||||
|
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
|
||||||
|
}
|
||||||
|
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
|
||||||
|
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
|
||||||
|
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
|
||||||
|
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
|
||||||
|
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
|
||||||
|
grad_4[j] = vsqrtq_f32(variance_4[j]);
|
||||||
|
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
|
||||||
|
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
|
||||||
|
if (_weight_decay > 0 && _adamw_mode) {
|
||||||
|
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
|
||||||
|
}
|
||||||
|
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
|
||||||
|
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
|
||||||
|
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
|
||||||
|
i + SIMD_WIDTH * j);
|
||||||
|
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
|
||||||
|
i + SIMD_WIDTH * j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (_param_size > rounded_size) {
|
||||||
|
Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
|
||||||
|
scalar_seek_offset(grads, grad_dtype, rounded_size),
|
||||||
|
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
|
||||||
|
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
|
||||||
|
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
|
||||||
|
exp_avg_sq_dtype, loss_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
|
||||||
|
void *_exp_avg_sq, size_t _param_size,
|
||||||
|
at::ScalarType param_dtype,
|
||||||
|
at::ScalarType grad_dtype,
|
||||||
|
at::ScalarType exp_avg_dtype,
|
||||||
|
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
|
||||||
|
size_t rounded_size = 0;
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
float betta1_minus1 = 1 - _betta1;
|
||||||
|
float betta2_minus1 = 1 - _betta2;
|
||||||
|
float step_size = -1 * _alpha / _bias_correction1;
|
||||||
|
float w_decay = -1 * _alpha * _weight_decay;
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
float32x4_t betta1_4 = simd_set(_betta1);
|
||||||
|
float32x4_t betta2_4 = simd_set(_betta2);
|
||||||
|
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
|
||||||
|
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
|
||||||
|
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
|
||||||
|
float32x4_t eps_4 = simd_set(_eps);
|
||||||
|
float32x4_t step_size_4 = simd_set(step_size);
|
||||||
|
float32x4_t weight_decay_4;
|
||||||
|
if (_weight_decay > 0) {
|
||||||
|
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t t = 0; t < rounded_size; t += TILE) {
|
||||||
|
size_t copy_size = TILE;
|
||||||
|
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
|
||||||
|
size_t offset = copy_size + t;
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
|
||||||
|
float32x4_t grad_4[8];
|
||||||
|
float32x4_t momentum_4[8];
|
||||||
|
float32x4_t variance_4[8];
|
||||||
|
float32x4_t param_4[8];
|
||||||
|
#pragma unroll 4
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
|
||||||
|
if (loss_scale > 0) {
|
||||||
|
float32x4_t loss_scale_vec = simd_set(loss_scale);
|
||||||
|
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
|
||||||
|
}
|
||||||
|
momentum_4[j] =
|
||||||
|
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
|
||||||
|
variance_4[j] =
|
||||||
|
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
|
||||||
|
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
|
||||||
|
if (_weight_decay > 0 && !_adamw_mode) {
|
||||||
|
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
|
||||||
|
}
|
||||||
|
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
|
||||||
|
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
|
||||||
|
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
|
||||||
|
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
|
||||||
|
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
|
||||||
|
grad_4[j] = vsqrtq_f32(variance_4[j]);
|
||||||
|
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
|
||||||
|
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
|
||||||
|
if (_weight_decay > 0 && _adamw_mode) {
|
||||||
|
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
|
||||||
|
}
|
||||||
|
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
|
||||||
|
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
|
||||||
|
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
|
||||||
|
i + SIMD_WIDTH * j);
|
||||||
|
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
|
||||||
|
i + SIMD_WIDTH * j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (_param_size > rounded_size) {
|
||||||
|
Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
|
||||||
|
scalar_seek_offset(grads, grad_dtype, rounded_size),
|
||||||
|
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
|
||||||
|
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
|
||||||
|
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
|
||||||
|
exp_avg_sq_dtype, loss_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
|
||||||
|
float epsilon, float weight_decay,
|
||||||
|
bool bias_correction, torch::Tensor ¶ms,
|
||||||
|
torch::Tensor &grads, torch::Tensor &exp_avg,
|
||||||
|
torch::Tensor &exp_avg_sq, float loss_scale) {
|
||||||
|
auto params_c = params.contiguous();
|
||||||
|
auto grads_c = grads.contiguous();
|
||||||
|
auto exp_avg_c = exp_avg.contiguous();
|
||||||
|
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||||
|
|
||||||
|
this->IncrementStep(step, beta1, beta2);
|
||||||
|
this->update_state(lr, epsilon, weight_decay, bias_correction);
|
||||||
|
this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
|
||||||
|
exp_avg_sq_c.data_ptr(), params_c.numel(),
|
||||||
|
params_c.scalar_type(), grads_c.scalar_type(),
|
||||||
|
exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
|
||||||
|
.def(py::init<float, float, float, float, float, bool>())
|
||||||
|
.def("step", &AdamOptimizer::step);
|
||||||
|
}
|
201
colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h
Normal file
201
colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
|
||||||
|
#define TILE (128 * 1024 * 1024)
|
||||||
|
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
#include <arm_neon.h>
|
||||||
|
#define SIMD_WIDTH 4
|
||||||
|
|
||||||
|
inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
|
||||||
|
size_t offset) {
|
||||||
|
switch (dtype) {
|
||||||
|
case at::ScalarType::Float: {
|
||||||
|
auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
|
||||||
|
return vld1q_f32(ptr_f + offset);
|
||||||
|
}
|
||||||
|
case at::ScalarType::Half: {
|
||||||
|
auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
|
||||||
|
return vcvt_f32_f16(vld1_f16(ptr_h + offset));
|
||||||
|
}
|
||||||
|
// case at::ScalarType::BFloat16: {
|
||||||
|
// auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
|
||||||
|
// return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
|
||||||
|
// }
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
|
||||||
|
return simd_load_offset(ptr, dtype, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
|
||||||
|
size_t offset) {
|
||||||
|
switch (dtype) {
|
||||||
|
case at::ScalarType::Float: {
|
||||||
|
auto ptr_f = reinterpret_cast<float32_t *>(ptr);
|
||||||
|
vst1q_f32(ptr_f + offset, data);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case at::ScalarType::Half: {
|
||||||
|
auto ptr_h = reinterpret_cast<float16_t *>(ptr);
|
||||||
|
vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// case at::ScalarType::BFloat16: {
|
||||||
|
// auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
|
||||||
|
// vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
|
||||||
|
return simd_store_offset(ptr, dtype, data, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline float32x4_t simd_set(float value) {
|
||||||
|
auto val = static_cast<float32_t>(value);
|
||||||
|
return vdupq_n_f32(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
|
||||||
|
size_t offset) {
|
||||||
|
switch (dtype) {
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
return *(reinterpret_cast<const float *>(ptr) + offset);
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
return static_cast<float>(
|
||||||
|
*(reinterpret_cast<const at::Half *>(ptr) + offset));
|
||||||
|
// case at::ScalarType::BFloat16:
|
||||||
|
// return static_cast<float>(
|
||||||
|
// *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
|
||||||
|
size_t offset) {
|
||||||
|
switch (dtype) {
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
*(reinterpret_cast<float *>(ptr) + offset) = data;
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
*(reinterpret_cast<at::Half *>(ptr) + offset) = data;
|
||||||
|
break;
|
||||||
|
// case at::ScalarType::BFloat16:
|
||||||
|
// *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
|
||||||
|
size_t offset) {
|
||||||
|
switch (dtype) {
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
return reinterpret_cast<float *>(ptr) + offset;
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
return reinterpret_cast<at::Half *>(ptr) + offset;
|
||||||
|
// case at::ScalarType::BFloat16:
|
||||||
|
// return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported dtype");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#define STEP(SPAN) \
|
||||||
|
void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
|
||||||
|
void *_exp_avg_sq, size_t _param_size, \
|
||||||
|
at::ScalarType param_dtype, at::ScalarType grad_dtype, \
|
||||||
|
at::ScalarType exp_avg_dtype, \
|
||||||
|
at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
|
||||||
|
|
||||||
|
class AdamOptimizer {
|
||||||
|
private:
|
||||||
|
float _alpha;
|
||||||
|
float _betta1;
|
||||||
|
float _betta2;
|
||||||
|
float _eps;
|
||||||
|
float _weight_decay;
|
||||||
|
|
||||||
|
float _betta1_t;
|
||||||
|
float _betta2_t;
|
||||||
|
size_t _step;
|
||||||
|
|
||||||
|
float _bias_correction1;
|
||||||
|
float _bias_correction2;
|
||||||
|
|
||||||
|
bool _adamw_mode;
|
||||||
|
|
||||||
|
public:
|
||||||
|
AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
|
||||||
|
float eps = 1e-8, float weight_decay = 0,
|
||||||
|
bool adamw_mode = true)
|
||||||
|
: _alpha(alpha),
|
||||||
|
_betta1(betta1),
|
||||||
|
_betta2(betta2),
|
||||||
|
_eps(eps),
|
||||||
|
_weight_decay(weight_decay),
|
||||||
|
_betta1_t(1.0),
|
||||||
|
_betta2_t(1.0),
|
||||||
|
_step(0),
|
||||||
|
_adamw_mode(adamw_mode) {}
|
||||||
|
~AdamOptimizer() {}
|
||||||
|
|
||||||
|
STEP(1)
|
||||||
|
STEP(4)
|
||||||
|
STEP(8)
|
||||||
|
inline void IncrementStep(size_t step, float beta1, float beta2) {
|
||||||
|
if (beta1 != _betta1 || beta2 != _betta2) {
|
||||||
|
_step = step;
|
||||||
|
_betta1 = beta1;
|
||||||
|
_betta2 = beta2;
|
||||||
|
_betta1_t = std::pow(_betta1, step);
|
||||||
|
_betta2_t = std::pow(_betta2, step);
|
||||||
|
} else {
|
||||||
|
_step++;
|
||||||
|
if (_step != step) {
|
||||||
|
_betta1_t = std::pow(_betta1, step);
|
||||||
|
_betta2_t = std::pow(_betta2, step);
|
||||||
|
_step = step;
|
||||||
|
} else {
|
||||||
|
_betta1_t *= _betta1;
|
||||||
|
_betta2_t *= _betta2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inline void update_state(float lr, float epsilon, float weight_decay,
|
||||||
|
bool bias_correction) {
|
||||||
|
_alpha = lr;
|
||||||
|
_eps = epsilon;
|
||||||
|
_weight_decay = weight_decay;
|
||||||
|
|
||||||
|
_bias_correction1 = 1.0f;
|
||||||
|
_bias_correction2 = 1.0f;
|
||||||
|
if (bias_correction == 1) {
|
||||||
|
_bias_correction1 = 1 - _betta1_t;
|
||||||
|
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
|
||||||
|
float weight_decay, bool bias_correction, torch::Tensor ¶ms,
|
||||||
|
torch::Tensor &grads, torch::Tensor &exp_avg,
|
||||||
|
torch::Tensor &exp_avg_sq, float loss_scale);
|
||||||
|
};
|
@ -5,7 +5,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
|
|
||||||
class Unpad(torch.autograd.Function):
|
class Unpad(torch.autograd.Function):
|
||||||
|
@ -12,7 +12,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode
|
|||||||
from colossalai.legacy.core import global_context as gpc
|
from colossalai.legacy.core import global_context as gpc
|
||||||
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
|
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ._base_schedule import BaseSchedule
|
from ._base_schedule import BaseSchedule
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ import colossalai.legacy.communication.p2p_v2 as comm
|
|||||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||||
from colossalai.legacy.core import global_context as gpc
|
from colossalai.legacy.core import global_context as gpc
|
||||||
from colossalai.legacy.engine import Engine
|
from colossalai.legacy.engine import Engine
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ._pipeline_schedule import PipelineSchedule
|
from ._pipeline_schedule import PipelineSchedule
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
|
|||||||
partition_tensor_parallel_state_dict,
|
partition_tensor_parallel_state_dict,
|
||||||
)
|
)
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ..base_layer import ParallelLayer
|
from ..base_layer import ParallelLayer
|
||||||
from ..colossalai_layer._utils import ColossalaiModule
|
from ..colossalai_layer._utils import ColossalaiModule
|
||||||
|
@ -18,7 +18,7 @@ from colossalai.legacy.utils.checkpointing import (
|
|||||||
partition_tensor_parallel_state_dict,
|
partition_tensor_parallel_state_dict,
|
||||||
)
|
)
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ..base_layer import ParallelLayer
|
from ..base_layer import ParallelLayer
|
||||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||||
|
@ -19,7 +19,7 @@ from colossalai.legacy.utils.checkpointing import (
|
|||||||
partition_tensor_parallel_state_dict,
|
partition_tensor_parallel_state_dict,
|
||||||
)
|
)
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ..base_layer import ParallelLayer
|
from ..base_layer import ParallelLayer
|
||||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||||
|
@ -27,7 +27,7 @@ from colossalai.legacy.utils.checkpointing import (
|
|||||||
partition_tensor_parallel_state_dict,
|
partition_tensor_parallel_state_dict,
|
||||||
)
|
)
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||||
from ._operation import (
|
from ._operation import (
|
||||||
|
@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
|
|||||||
from colossalai.legacy.context import seed
|
from colossalai.legacy.context import seed
|
||||||
from colossalai.legacy.registry import LAYERS
|
from colossalai.legacy.registry import LAYERS
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ..utils import to_2tuple
|
from ..utils import to_2tuple
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import types
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from .stateful_tensor import StatefulTensor, TensorState
|
from .stateful_tensor import StatefulTensor, TensorState
|
||||||
from .tensor_placement_policy import TensorPlacementPolicy
|
from .tensor_placement_policy import TensorPlacementPolicy
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import math
|
import math
|
||||||
|
import platform
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from colossalai.kernel.op_builder import CPUAdamBuilder
|
from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
|
||||||
|
|
||||||
from .nvme_optimizer import NVMeOptimizer
|
from .nvme_optimizer import NVMeOptimizer
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ class CPUAdam(NVMeOptimizer):
|
|||||||
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
||||||
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||||
self.adamw_mode = adamw_mode
|
self.adamw_mode = adamw_mode
|
||||||
cpu_adam = CPUAdamBuilder().load()
|
cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
|
||||||
# if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
|
# if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
|
||||||
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
||||||
|
|
||||||
|
@ -84,6 +84,7 @@ class HybridAdam(CPUAdam):
|
|||||||
nvme_offload_fraction,
|
nvme_offload_fraction,
|
||||||
nvme_offload_dir,
|
nvme_offload_dir,
|
||||||
)
|
)
|
||||||
|
if torch.cuda.is_available():
|
||||||
fused_optim = FusedOptimBuilder().load()
|
fused_optim = FusedOptimBuilder().load()
|
||||||
self.gpu_adam_op = fused_optim.multi_tensor_adam
|
self.gpu_adam_op = fused_optim.multi_tensor_adam
|
||||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
@ -118,11 +119,11 @@ class HybridAdam(CPUAdam):
|
|||||||
group_step = state["step"]
|
group_step = state["step"]
|
||||||
beta1, beta2 = group["betas"]
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
if target_device.type == "cpu":
|
if target_device.type == "cpu" or target_device.type == "npu":
|
||||||
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
|
assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
|
||||||
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
|
assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
|
||||||
self._pre_update(p, "exp_avg", "exp_avg_sq")
|
self._pre_update(p, "exp_avg", "exp_avg_sq")
|
||||||
if p.grad.dtype is torch.bfloat16:
|
if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
|
||||||
# cpu adam kernel does not support bf16 now
|
# cpu adam kernel does not support bf16 now
|
||||||
bias_correction1 = 1 - beta1 ** state["step"]
|
bias_correction1 = 1 - beta1 ** state["step"]
|
||||||
bias_correction2 = 1 - beta2 ** state["step"]
|
bias_correction2 = 1 - beta2 ** state["step"]
|
||||||
|
@ -10,7 +10,7 @@ from torch.utils._pytree import tree_map
|
|||||||
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
|
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
|
||||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
|
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
|
||||||
from .base import PipelineSchedule
|
from .base import PipelineSchedule
|
||||||
|
@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map
|
|||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||||
from .base import PipelineSchedule
|
from .base import PipelineSchedule
|
||||||
|
@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map
|
|||||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
from ._utils import (
|
from ._utils import (
|
||||||
detach,
|
detach,
|
||||||
|
@ -2,16 +2,19 @@
|
|||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import torch.nn as nn
|
|
||||||
from colossalai.lazy import LazyInitContext
|
|
||||||
from ._operation import hook_paramter_in_backward
|
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
|
|
||||||
|
from ._operation import hook_paramter_in_backward
|
||||||
from .utils import SeqParallelUtils
|
from .utils import SeqParallelUtils
|
||||||
|
|
||||||
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
|
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
|
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
|
||||||
|
|
||||||
EnableFastLayerNorm = True
|
EnableFastLayerNorm = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
EnableFastLayerNorm = False
|
EnableFastLayerNorm = False
|
||||||
@ -19,10 +22,27 @@ except ImportError:
|
|||||||
try:
|
try:
|
||||||
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
|
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
|
||||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||||
|
|
||||||
|
class FusedLayerNormWithHook(ApexFusedLayerNorm):
|
||||||
|
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||||
|
super().__init__(normalized_shape, eps, elementwise_affine)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
output = super().forward(input)
|
||||||
|
output = hook_paramter_in_backward(output, self.weight, self.bias)
|
||||||
|
return output
|
||||||
|
|
||||||
|
class FusedRMSNormWithHook(ApexFusedRMSNorm):
|
||||||
|
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||||
|
super().__init__(normalized_shape, eps, elementwise_affine)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
output = super().forward(input)
|
||||||
|
output = hook_paramter_in_backward(output, self.weight)
|
||||||
|
return output
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
warnings.warn(
|
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
|
||||||
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
|
|
||||||
)
|
|
||||||
|
|
||||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||||
1024,
|
1024,
|
||||||
@ -52,6 +72,7 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
if EnableFastLayerNorm:
|
if EnableFastLayerNorm:
|
||||||
|
|
||||||
class FastLayerNormWithHook(FastLayerNorm):
|
class FastLayerNormWithHook(FastLayerNorm):
|
||||||
def __init__(self, hidden_size, eps=0.00001):
|
def __init__(self, hidden_size, eps=0.00001):
|
||||||
super().__init__(hidden_size, eps)
|
super().__init__(hidden_size, eps)
|
||||||
@ -61,24 +82,6 @@ if EnableFastLayerNorm:
|
|||||||
output = hook_paramter_in_backward(output, self.weight, self.bias)
|
output = hook_paramter_in_backward(output, self.weight, self.bias)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
class FusedLayerNormWithHook(ApexFusedLayerNorm):
|
|
||||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
|
||||||
super().__init__(normalized_shape, eps, elementwise_affine)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
output = super().forward(input)
|
|
||||||
output = hook_paramter_in_backward(output, self.weight, self.bias)
|
|
||||||
return output
|
|
||||||
|
|
||||||
class FusedRMSNormWithHook(ApexFusedRMSNorm):
|
|
||||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
|
||||||
super().__init__(normalized_shape, eps, elementwise_affine)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
output = super().forward(input)
|
|
||||||
output = hook_paramter_in_backward(output, self.weight)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLayerNorm(ABC):
|
class BaseLayerNorm(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -244,6 +247,7 @@ class FusedRMSNorm(BaseLayerNorm):
|
|||||||
"""
|
"""
|
||||||
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
|
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"FusedRMSNorm is not implemented as a physical class. "
|
"FusedRMSNorm is not implemented as a physical class. "
|
||||||
@ -264,7 +268,7 @@ class FusedRMSNorm(BaseLayerNorm):
|
|||||||
nn.Module: FusedRMSNorm module.
|
nn.Module: FusedRMSNorm module.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
pass
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
|
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
|
||||||
@ -282,7 +286,9 @@ class FusedRMSNorm(BaseLayerNorm):
|
|||||||
eps = module.eps
|
eps = module.eps
|
||||||
elementwise_affine = module.elementwise_affine
|
elementwise_affine = module.elementwise_affine
|
||||||
|
|
||||||
rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
|
rmsnorm = FusedRMSNormWithHook(
|
||||||
|
normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
|
||||||
|
)
|
||||||
|
|
||||||
rmsnorm.weight = module.weight
|
rmsnorm.weight = module.weight
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from .common import (
|
|||||||
is_ddp_ignored,
|
is_ddp_ignored,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize
|
from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
|
||||||
from .multi_tensor_apply import multi_tensor_applier
|
from .multi_tensor_apply import multi_tensor_applier
|
||||||
from .tensor_detector import TensorDetector
|
from .tensor_detector import TensorDetector
|
||||||
from .timer import MultiTimer, Timer
|
from .timer import MultiTimer, Timer
|
||||||
@ -29,4 +29,5 @@ __all__ = [
|
|||||||
"set_seed",
|
"set_seed",
|
||||||
"is_ddp_ignored",
|
"is_ddp_ignored",
|
||||||
"set_device",
|
"set_device",
|
||||||
|
"IS_NPU_AVAILABLE",
|
||||||
]
|
]
|
||||||
|
@ -1,56 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
|
|
||||||
def set_to_cuda(models):
|
|
||||||
"""Send model to gpu.
|
|
||||||
|
|
||||||
:param models: nn.module or a list of module
|
|
||||||
"""
|
|
||||||
if isinstance(models, list) and len(models) > 1:
|
|
||||||
ret = []
|
|
||||||
for model in models:
|
|
||||||
ret.append(model.to(get_current_device()))
|
|
||||||
return ret
|
|
||||||
elif isinstance(models, list):
|
|
||||||
return models[0].to(get_current_device())
|
|
||||||
else:
|
|
||||||
return models.to(get_current_device())
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_device() -> torch.device:
|
|
||||||
"""
|
|
||||||
Returns currently selected device (gpu/cpu).
|
|
||||||
If cuda available, return gpu, otherwise return cpu.
|
|
||||||
"""
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
|
||||||
else:
|
|
||||||
return torch.device("cpu")
|
|
||||||
|
|
||||||
|
|
||||||
def synchronize():
|
|
||||||
"""Similar to cuda.synchronize().
|
|
||||||
Waits for all kernels in all streams on a CUDA device to complete.
|
|
||||||
"""
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
|
|
||||||
def empty_cache():
|
|
||||||
"""Similar to cuda.empty_cache()
|
|
||||||
Releases all unoccupied cached memory currently held by the caching allocator.
|
|
||||||
"""
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def set_device(index: Optional[int] = None) -> None:
|
|
||||||
if index is None:
|
|
||||||
index = dist.get_rank() % torch.cuda.device_count()
|
|
||||||
torch.cuda.set_device(index)
|
|
207
colossalai/utils/device.py
Normal file
207
colossalai/utils/device.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
IS_NPU_AVAILABLE: bool = False
|
||||||
|
try:
|
||||||
|
import torch_npu # noqa
|
||||||
|
|
||||||
|
IS_NPU_AVAILABLE = torch.npu.is_available()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def set_to_cuda(models):
|
||||||
|
"""Send model to gpu.
|
||||||
|
|
||||||
|
:param models: nn.module or a list of module
|
||||||
|
"""
|
||||||
|
if isinstance(models, list) and len(models) > 1:
|
||||||
|
ret = []
|
||||||
|
for model in models:
|
||||||
|
ret.append(model.to(get_current_device()))
|
||||||
|
return ret
|
||||||
|
elif isinstance(models, list):
|
||||||
|
return models[0].to(get_current_device())
|
||||||
|
else:
|
||||||
|
return models.to(get_current_device())
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_device() -> torch.device:
|
||||||
|
"""
|
||||||
|
Returns currently selected device (gpu/cpu).
|
||||||
|
If cuda available, return gpu, otherwise return cpu.
|
||||||
|
"""
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||||
|
elif IS_NPU_AVAILABLE:
|
||||||
|
return torch.device(f"npu:{torch.npu.current_device()}")
|
||||||
|
else:
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatch_device_func(fn_name: str, *args, **kwargs):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return getattr(torch.cuda, fn_name)(*args, **kwargs)
|
||||||
|
elif IS_NPU_AVAILABLE:
|
||||||
|
return getattr(torch.npu, fn_name)(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("No device available")
|
||||||
|
|
||||||
|
|
||||||
|
# device semantics
|
||||||
|
|
||||||
|
|
||||||
|
def can_device_access_peer(device, peer_device) -> bool:
|
||||||
|
return _dispatch_device_func("can_device_access_peer", device, peer_device)
|
||||||
|
|
||||||
|
|
||||||
|
def current_device() -> int:
|
||||||
|
return _dispatch_device_func("current_device")
|
||||||
|
|
||||||
|
|
||||||
|
def current_stream(device=None):
|
||||||
|
return _dispatch_device_func("current_stream", device)
|
||||||
|
|
||||||
|
|
||||||
|
def default_stream(device=None):
|
||||||
|
return _dispatch_device_func("default_stream", device)
|
||||||
|
|
||||||
|
|
||||||
|
def device_count() -> int:
|
||||||
|
return _dispatch_device_func("device_count")
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_capability(device=None) -> Tuple[int, int]:
|
||||||
|
return _dispatch_device_func("get_device_capability", device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_name(device=None) -> str:
|
||||||
|
return _dispatch_device_func("get_device_name", device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_properties(device):
|
||||||
|
return _dispatch_device_func("get_device_properties", device)
|
||||||
|
|
||||||
|
|
||||||
|
def set_device(index: Optional[int] = None) -> None:
|
||||||
|
if index is None:
|
||||||
|
index = dist.get_rank() % device_count()
|
||||||
|
_dispatch_device_func("set_device", index)
|
||||||
|
|
||||||
|
|
||||||
|
def set_stream(stream_):
|
||||||
|
return _dispatch_device_func("set_stream", stream_)
|
||||||
|
|
||||||
|
|
||||||
|
def stream(stream_):
|
||||||
|
return _dispatch_device_func("stream", stream_)
|
||||||
|
|
||||||
|
|
||||||
|
def synchronize():
|
||||||
|
return _dispatch_device_func("synchronize")
|
||||||
|
|
||||||
|
|
||||||
|
def utilization(device=None) -> int:
|
||||||
|
return _dispatch_device_func("utilization", device)
|
||||||
|
|
||||||
|
|
||||||
|
# random number generator
|
||||||
|
|
||||||
|
|
||||||
|
def get_rng_state(device="cuda") -> torch.Tensor:
|
||||||
|
return _dispatch_device_func("get_rng_state", device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rng_state_all() -> List[torch.Tensor]:
|
||||||
|
return _dispatch_device_func("get_rng_state_all")
|
||||||
|
|
||||||
|
|
||||||
|
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
|
||||||
|
return _dispatch_device_func("set_rng_state", new_state, device)
|
||||||
|
|
||||||
|
|
||||||
|
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
|
||||||
|
return _dispatch_device_func("set_rng_state_all", new_states)
|
||||||
|
|
||||||
|
|
||||||
|
def manual_seed(seed: int) -> None:
|
||||||
|
return _dispatch_device_func("manual_seed", seed)
|
||||||
|
|
||||||
|
|
||||||
|
def manual_seed_all(seed: int) -> None:
|
||||||
|
return _dispatch_device_func("manual_seed_all", seed)
|
||||||
|
|
||||||
|
|
||||||
|
def seed() -> None:
|
||||||
|
return _dispatch_device_func("seed")
|
||||||
|
|
||||||
|
|
||||||
|
def seed_all() -> None:
|
||||||
|
return _dispatch_device_func("seed_all")
|
||||||
|
|
||||||
|
|
||||||
|
def initial_seed() -> int:
|
||||||
|
return _dispatch_device_func("initial_seed")
|
||||||
|
|
||||||
|
|
||||||
|
# streams and events
|
||||||
|
|
||||||
|
|
||||||
|
def Stream(device=None, priority=0, **kwargs):
|
||||||
|
return _dispatch_device_func("Stream", device, priority, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||||
|
return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
|
||||||
|
|
||||||
|
|
||||||
|
# memory management
|
||||||
|
|
||||||
|
|
||||||
|
def empty_cache() -> None:
|
||||||
|
return _dispatch_device_func("empty_cache")
|
||||||
|
|
||||||
|
|
||||||
|
def memory_stats(device=None) -> Dict[str, Any]:
|
||||||
|
return _dispatch_device_func("memory_stats", device)
|
||||||
|
|
||||||
|
|
||||||
|
def memory_summary(device=None, abbreviated=False) -> str:
|
||||||
|
return _dispatch_device_func("memory_summary", device, abbreviated)
|
||||||
|
|
||||||
|
|
||||||
|
def memory_snapshot():
|
||||||
|
return _dispatch_device_func("memory_snapshot")
|
||||||
|
|
||||||
|
|
||||||
|
def memory_allocated(device=None) -> int:
|
||||||
|
return _dispatch_device_func("memory_allocated", device)
|
||||||
|
|
||||||
|
|
||||||
|
def max_memory_allocated(device=None) -> int:
|
||||||
|
return _dispatch_device_func("max_memory_allocated", device)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_max_memory_allocated(device=None) -> None:
|
||||||
|
return _dispatch_device_func("reset_max_memory_allocated", device)
|
||||||
|
|
||||||
|
|
||||||
|
def memory_reserved(device=None) -> int:
|
||||||
|
return _dispatch_device_func("memory_reserved", device)
|
||||||
|
|
||||||
|
|
||||||
|
def max_memory_reserved(device=None) -> int:
|
||||||
|
return _dispatch_device_func("max_memory_reserved", device)
|
||||||
|
|
||||||
|
|
||||||
|
def set_per_process_memory_fraction(fraction: float, device=None) -> None:
|
||||||
|
return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_peak_memory_stats(device=None) -> None:
|
||||||
|
return _dispatch_device_func("reset_peak_memory_stats", device)
|
@ -3,7 +3,7 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from .cuda import synchronize
|
from .device import synchronize
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
|
@ -7,6 +7,7 @@ import torch.distributed as dist
|
|||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
from colossalai.utils.device import IS_NPU_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
class TensorState(Enum):
|
class TensorState(Enum):
|
||||||
@ -172,7 +173,7 @@ class Chunk:
|
|||||||
|
|
||||||
if self.chunk_temp is not None:
|
if self.chunk_temp is not None:
|
||||||
# this chunk is not closed
|
# this chunk is not closed
|
||||||
if self.chunk_temp.device.type == "cuda":
|
if self.chunk_temp.device.type == "cuda" or self.chunk_temp.device.type == "npu":
|
||||||
cuda_memory += self.chunk_mem
|
cuda_memory += self.chunk_mem
|
||||||
else:
|
else:
|
||||||
cpu_memory += self.chunk_mem
|
cpu_memory += self.chunk_mem
|
||||||
@ -191,10 +192,8 @@ class Chunk:
|
|||||||
if self.chunk_temp is not None:
|
if self.chunk_temp is not None:
|
||||||
return self.chunk_temp.device.type
|
return self.chunk_temp.device.type
|
||||||
else:
|
else:
|
||||||
if self.is_gathered:
|
if self.is_gathered or self.cuda_shard is not None:
|
||||||
return "cuda"
|
return "npu" if IS_NPU_AVAILABLE else "cuda"
|
||||||
elif self.cuda_shard is not None:
|
|
||||||
return "cuda"
|
|
||||||
else:
|
else:
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
@ -329,12 +328,12 @@ class Chunk:
|
|||||||
# when the current chunk is not synchronized with the optimizer
|
# when the current chunk is not synchronized with the optimizer
|
||||||
# just use another way for the movement
|
# just use another way for the movement
|
||||||
if not self.optim_sync_flag:
|
if not self.optim_sync_flag:
|
||||||
assert device.type == "cuda", "each chunk should first be moved to CUDA"
|
assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
|
||||||
self.__paired_shard_move()
|
self.__paired_shard_move()
|
||||||
self.optim_sync_flag = True
|
self.optim_sync_flag = True
|
||||||
return
|
return
|
||||||
|
|
||||||
if device.type == "cuda":
|
if device.type == "cuda" or device.type == "npu":
|
||||||
assert device == get_current_device(), "can't move chunk to another device"
|
assert device == get_current_device(), "can't move chunk to another device"
|
||||||
|
|
||||||
if self.cuda_shard:
|
if self.cuda_shard:
|
||||||
@ -484,7 +483,7 @@ class Chunk:
|
|||||||
assert friend_chunk.is_gathered is True
|
assert friend_chunk.is_gathered is True
|
||||||
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
|
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
|
||||||
self.optim_sync_flag = True
|
self.optim_sync_flag = True
|
||||||
elif friend_chunk.device_type == "cuda" and self.device_type == "cuda":
|
elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"):
|
||||||
self.cuda_shard.copy_(friend_chunk.cuda_shard)
|
self.cuda_shard.copy_(friend_chunk.cuda_shard)
|
||||||
self.optim_sync_flag = True
|
self.optim_sync_flag = True
|
||||||
self.cpu_vis_flag = False
|
self.cpu_vis_flag = False
|
||||||
|
@ -206,7 +206,10 @@ class ChunkManager:
|
|||||||
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
|
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
|
||||||
"""
|
"""
|
||||||
assert tensor not in self.tensor_chunk_map
|
assert tensor not in self.tensor_chunk_map
|
||||||
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
|
device_type = tensor.device.type
|
||||||
|
if device_type == "npu":
|
||||||
|
device_type = "cuda"
|
||||||
|
self.total_mem[device_type] += tensor.numel() * tensor.element_size()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
msg = [
|
msg = [
|
||||||
|
@ -10,14 +10,24 @@ import torch.nn as nn
|
|||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.distributed.distributed_c10d import _get_default_group
|
from torch.distributed.distributed_c10d import _get_default_group
|
||||||
|
|
||||||
from colossalai.checkpoint_io.utils import StateDictSharder
|
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
|
||||||
from colossalai.interface import ModelWrapper
|
from colossalai.interface import ModelWrapper
|
||||||
from colossalai.lazy import LazyTensor
|
from colossalai.lazy import LazyTensor
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.tensor.colo_parameter import ColoParameter
|
from colossalai.tensor.colo_parameter import ColoParameter
|
||||||
|
from colossalai.tensor.d_tensor import (
|
||||||
|
distribute_tensor,
|
||||||
|
distribute_tensor_with_customization,
|
||||||
|
get_device_mesh,
|
||||||
|
get_global_shape,
|
||||||
|
get_sharding_spec,
|
||||||
|
init_as_dtensor,
|
||||||
|
init_tensor_as_customization_distributed,
|
||||||
|
is_customized_distributed_tensor,
|
||||||
|
is_distributed_tensor,
|
||||||
|
)
|
||||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||||
from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
|
from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
|
||||||
from colossalai.checkpoint_io.utils import gather_distributed_param
|
|
||||||
|
|
||||||
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
|
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
|
||||||
from .gemini_hook import GeminiZeROHook
|
from .gemini_hook import GeminiZeROHook
|
||||||
@ -25,18 +35,6 @@ from .gemini_mgr import GeminiManager
|
|||||||
from .memory_tracer import MemStats, OrderedParamGenerator
|
from .memory_tracer import MemStats, OrderedParamGenerator
|
||||||
from .utils import get_temp_total_chunk_on_cuda
|
from .utils import get_temp_total_chunk_on_cuda
|
||||||
|
|
||||||
from colossalai.tensor.d_tensor import (
|
|
||||||
distribute_tensor,
|
|
||||||
distribute_tensor_with_customization,
|
|
||||||
init_tensor_as_customization_distributed,
|
|
||||||
get_device_mesh,
|
|
||||||
get_sharding_spec,
|
|
||||||
is_customized_distributed_tensor,
|
|
||||||
is_distributed_tensor,
|
|
||||||
get_global_shape,
|
|
||||||
init_as_dtensor
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -162,7 +160,7 @@ class GeminiDDP(ModelWrapper):
|
|||||||
self._init_chunks(
|
self._init_chunks(
|
||||||
param_order=param_order,
|
param_order=param_order,
|
||||||
strict_ddp_mode=strict_ddp_mode,
|
strict_ddp_mode=strict_ddp_mode,
|
||||||
cpu_offload=self.gemini_manager.policy_name != "cuda",
|
cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0),
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
super().__init__(module)
|
super().__init__(module)
|
||||||
@ -453,12 +451,13 @@ class GeminiDDP(ModelWrapper):
|
|||||||
global_shape = get_global_shape(tensor)
|
global_shape = get_global_shape(tensor)
|
||||||
device_mesh = get_device_mesh(tensor)
|
device_mesh = get_device_mesh(tensor)
|
||||||
shard_spec = get_sharding_spec(tensor)
|
shard_spec = get_sharding_spec(tensor)
|
||||||
record_tensor = init_as_dtensor(record_tensor,
|
record_tensor = init_as_dtensor(
|
||||||
device_mesh=device_mesh,
|
record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape
|
||||||
sharding_spec=shard_spec,
|
)
|
||||||
global_shape = global_shape)
|
|
||||||
elif is_customized_distributed_tensor(tensor):
|
elif is_customized_distributed_tensor(tensor):
|
||||||
init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn)
|
init_tensor_as_customization_distributed(
|
||||||
|
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
|
||||||
|
)
|
||||||
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
|
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
|
||||||
|
|
||||||
assert tensor not in chunk_to_save_data
|
assert tensor not in chunk_to_save_data
|
||||||
@ -634,7 +633,15 @@ class GeminiDDP(ModelWrapper):
|
|||||||
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
|
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
|
||||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||||
|
|
||||||
def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None):
|
def load(
|
||||||
|
param_name,
|
||||||
|
dest_tensor,
|
||||||
|
copy_func,
|
||||||
|
source_device_mesh=None,
|
||||||
|
source_sharding_spec=None,
|
||||||
|
shard_fn=None,
|
||||||
|
gather_fn=None,
|
||||||
|
):
|
||||||
state_key = prefix + param_name
|
state_key = prefix + param_name
|
||||||
if state_key in state_dict:
|
if state_key in state_dict:
|
||||||
input_param = state_dict[state_key]
|
input_param = state_dict[state_key]
|
||||||
@ -642,7 +649,9 @@ class GeminiDDP(ModelWrapper):
|
|||||||
if source_device_mesh is not None and source_sharding_spec is not None:
|
if source_device_mesh is not None and source_sharding_spec is not None:
|
||||||
input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
|
input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
|
||||||
elif shard_fn is not None and gather_fn is not None:
|
elif shard_fn is not None and gather_fn is not None:
|
||||||
input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn)
|
input_param = distribute_tensor_with_customization(
|
||||||
|
input_param, shard_fn=shard_fn, gather_fn=gather_fn
|
||||||
|
)
|
||||||
|
|
||||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||||
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
|
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
|
||||||
@ -687,7 +696,6 @@ class GeminiDDP(ModelWrapper):
|
|||||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
|
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
|
||||||
|
|
||||||
for tensor, tensor_info in chunk.tensors_info.items():
|
for tensor, tensor_info in chunk.tensors_info.items():
|
||||||
|
|
||||||
source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
|
source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
|
||||||
if is_distributed_tensor(tensor):
|
if is_distributed_tensor(tensor):
|
||||||
# shard the input param
|
# shard the input param
|
||||||
@ -699,7 +707,15 @@ class GeminiDDP(ModelWrapper):
|
|||||||
|
|
||||||
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
|
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
|
||||||
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
|
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
|
||||||
load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn)
|
load(
|
||||||
|
parameter_name,
|
||||||
|
tensor,
|
||||||
|
partial(load_parameter, parameter_slice),
|
||||||
|
source_device_mesh,
|
||||||
|
source_sharding_spec,
|
||||||
|
shard_fn,
|
||||||
|
gather_fn,
|
||||||
|
)
|
||||||
|
|
||||||
if chunk.is_gathered:
|
if chunk.is_gathered:
|
||||||
chunk.cuda_global_chunk.copy_(temp_chunk)
|
chunk.cuda_global_chunk.copy_(temp_chunk)
|
||||||
@ -799,7 +815,7 @@ class GeminiDDP(ModelWrapper):
|
|||||||
for buffer in self.module.buffers():
|
for buffer in self.module.buffers():
|
||||||
if isinstance(buffer, LazyTensor):
|
if isinstance(buffer, LazyTensor):
|
||||||
buffer.materialize()
|
buffer.materialize()
|
||||||
buffer.data = buffer.cuda()
|
buffer.data = buffer.to(get_current_device())
|
||||||
if torch.is_floating_point(buffer):
|
if torch.is_floating_point(buffer):
|
||||||
buffer.data = buffer.to(self.mixed_precision)
|
buffer.data = buffer.to(self.mixed_precision)
|
||||||
|
|
||||||
|
@ -17,9 +17,7 @@ class GeminiManager:
|
|||||||
https://arxiv.org/abs/2108.05818
|
https://arxiv.org/abs/2108.05818
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'.
|
placement_policy (str): Which device to place *held* tensors. It can be 'static' and 'auto'.
|
||||||
If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
|
|
||||||
If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
|
|
||||||
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
|
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
|
||||||
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
|
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
|
||||||
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
|
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
|
||||||
@ -121,7 +119,7 @@ class GeminiManager:
|
|||||||
start = time()
|
start = time()
|
||||||
cuda_demand = 0
|
cuda_demand = 0
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
if chunk.device_type == "cuda":
|
if chunk.device_type == "cuda" or chunk.device_type == "npu":
|
||||||
if chunk.is_gathered:
|
if chunk.is_gathered:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
@ -7,31 +7,29 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.distributed import ProcessGroup
|
|
||||||
|
|
||||||
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
|
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
|
||||||
from colossalai.checkpoint_io.utils import StateDictSharder
|
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
|
||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
|
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
|
||||||
|
from colossalai.tensor.d_tensor import (
|
||||||
|
distribute_tensor,
|
||||||
|
distribute_tensor_with_customization,
|
||||||
|
get_device_mesh,
|
||||||
|
get_sharding_spec,
|
||||||
|
init_as_dtensor,
|
||||||
|
init_tensor_as_customization_distributed,
|
||||||
|
is_customized_distributed_tensor,
|
||||||
|
is_distributed_tensor,
|
||||||
|
)
|
||||||
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
||||||
|
|
||||||
from .chunk import Chunk, ChunkManager
|
from .chunk import Chunk, ChunkManager
|
||||||
from .gemini_ddp import GeminiDDP
|
from .gemini_ddp import GeminiDDP
|
||||||
from colossalai.checkpoint_io.utils import gather_distributed_param
|
|
||||||
from colossalai.tensor.d_tensor import (
|
|
||||||
distribute_tensor,
|
|
||||||
distribute_tensor_with_customization,
|
|
||||||
init_tensor_as_customization_distributed,
|
|
||||||
get_device_mesh,
|
|
||||||
get_sharding_spec,
|
|
||||||
is_customized_distributed_tensor,
|
|
||||||
is_distributed_tensor,
|
|
||||||
get_global_shape,
|
|
||||||
init_as_dtensor
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
|
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
|
||||||
|
|
||||||
@ -312,7 +310,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
chunk16 = self.param_to_chunk16[fake_param]
|
chunk16 = self.param_to_chunk16[fake_param]
|
||||||
chunk32 = chunk16.paired_chunk
|
chunk32 = chunk16.paired_chunk
|
||||||
|
|
||||||
if chunk32.device_type == "cuda":
|
if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
|
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
|
||||||
@ -326,7 +324,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
for fake_param in group["params"]:
|
for fake_param in group["params"]:
|
||||||
chunk16 = self.param_to_chunk16[fake_param]
|
chunk16 = self.param_to_chunk16[fake_param]
|
||||||
chunk32 = chunk16.paired_chunk
|
chunk32 = chunk16.paired_chunk
|
||||||
if chunk32.device_type == "cuda":
|
if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
|
||||||
state = self.optim.state[fake_param]
|
state = self.optim.state[fake_param]
|
||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
@ -479,13 +477,17 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
|
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
|
||||||
if is_dtensor:
|
if is_dtensor:
|
||||||
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
|
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
|
||||||
state_tensor = init_as_dtensor(state_tensor,
|
state_tensor = init_as_dtensor(
|
||||||
|
state_tensor,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
sharding_spec=shard_spec,
|
sharding_spec=shard_spec,
|
||||||
global_shape = global_shape)
|
global_shape=global_shape,
|
||||||
|
)
|
||||||
elif is_customized_distributed:
|
elif is_customized_distributed:
|
||||||
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
|
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
|
||||||
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
|
init_tensor_as_customization_distributed(
|
||||||
|
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
|
||||||
|
)
|
||||||
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
||||||
|
|
||||||
collected_states[state_name] = state_tensor.reshape(global_shape)
|
collected_states[state_name] = state_tensor.reshape(global_shape)
|
||||||
@ -533,13 +535,14 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
|
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
|
||||||
if is_dtensor:
|
if is_dtensor:
|
||||||
state_tensor = state_tensor.to(param.device)
|
state_tensor = state_tensor.to(param.device)
|
||||||
state_tensor = init_as_dtensor(state_tensor,
|
state_tensor = init_as_dtensor(
|
||||||
sharding_spec=shard_spec,
|
state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
|
||||||
device_mesh=device_mesh,
|
)
|
||||||
global_shape=global_shape)
|
|
||||||
elif is_customized_distributed:
|
elif is_customized_distributed:
|
||||||
state_tensor = state_tensor.to(param.device)
|
state_tensor = state_tensor.to(param.device)
|
||||||
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
|
init_tensor_as_customization_distributed(
|
||||||
|
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
|
||||||
|
)
|
||||||
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
||||||
|
|
||||||
return collected_states
|
return collected_states
|
||||||
@ -548,7 +551,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
self,
|
self,
|
||||||
param_id: int,
|
param_id: int,
|
||||||
state_names: list,
|
state_names: list,
|
||||||
device: torch.device = torch.device("cuda"),
|
device: torch.device = get_current_device(),
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -12,6 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
|||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
import colossalai.utils.device as device_utils
|
||||||
from colossalai.amp.naive_amp.mixed_precision_mixin import (
|
from colossalai.amp.naive_amp.mixed_precision_mixin import (
|
||||||
BF16MixedPrecisionMixin,
|
BF16MixedPrecisionMixin,
|
||||||
FP16MixedPrecisionMixin,
|
FP16MixedPrecisionMixin,
|
||||||
@ -22,7 +23,7 @@ from colossalai.logging import get_dist_logger
|
|||||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||||
|
|
||||||
# from colossalai.tensor import ColoParameter, ProcessGroup
|
# from colossalai.tensor import ColoParameter, ProcessGroup
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
|
||||||
|
|
||||||
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
|
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
|
||||||
from .bookkeeping import BucketStore, GradientStore, ParameterStore
|
from .bookkeeping import BucketStore, GradientStore, ParameterStore
|
||||||
@ -182,7 +183,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
# intialize communication stream for
|
# intialize communication stream for
|
||||||
# communication-compuation overlapping
|
# communication-compuation overlapping
|
||||||
if self._overlap_communication:
|
if self._overlap_communication:
|
||||||
self._comm_stream = torch.cuda.Stream()
|
self._comm_stream = device_utils.Stream()
|
||||||
|
|
||||||
# reduction hook is only used if overlapping communication
|
# reduction hook is only used if overlapping communication
|
||||||
# or stage 2 is used
|
# or stage 2 is used
|
||||||
@ -216,7 +217,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
return len(self._working_param_groups)
|
return len(self._working_param_groups)
|
||||||
|
|
||||||
def _sanity_checks(self):
|
def _sanity_checks(self):
|
||||||
assert torch.cuda.is_available(), "CUDA is required"
|
assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
|
||||||
for param_group in self.optim.param_groups:
|
for param_group in self.optim.param_groups:
|
||||||
group_params = param_group["params"]
|
group_params = param_group["params"]
|
||||||
for param in group_params:
|
for param in group_params:
|
||||||
@ -339,11 +340,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
if len(moe_grad_list) > 0:
|
if len(moe_grad_list) > 0:
|
||||||
moe_flat_grads.record_stream(stream)
|
moe_flat_grads.record_stream(stream)
|
||||||
# waiting for ops in the default stream finishing
|
# waiting for ops in the default stream finishing
|
||||||
stream.wait_stream(torch.cuda.current_stream())
|
stream.wait_stream(device_utils.current_stream())
|
||||||
else:
|
else:
|
||||||
stream = torch.cuda.current_stream()
|
stream = device_utils.current_stream()
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
with device_utils.stream(stream):
|
||||||
group_id = self._bucket_store.current_group_id
|
group_id = self._bucket_store.current_group_id
|
||||||
|
|
||||||
if self.moe_extra_dp_pg is None:
|
if self.moe_extra_dp_pg is None:
|
||||||
@ -485,7 +486,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
|
|
||||||
# clear reduced grads
|
# clear reduced grads
|
||||||
if self._overlap_communication:
|
if self._overlap_communication:
|
||||||
torch.cuda.synchronize()
|
device_utils.synchronize()
|
||||||
|
|
||||||
self.zero_grad()
|
self.zero_grad()
|
||||||
|
|
||||||
@ -504,7 +505,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
|
|
||||||
# clear reduced grads
|
# clear reduced grads
|
||||||
if self._overlap_communication:
|
if self._overlap_communication:
|
||||||
torch.cuda.synchronize()
|
device_utils.synchronize()
|
||||||
|
|
||||||
self.zero_grad()
|
self.zero_grad()
|
||||||
|
|
||||||
@ -620,22 +621,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
release_param_grad(self._master_param_groups_of_current_rank[group_id])
|
release_param_grad(self._master_param_groups_of_current_rank[group_id])
|
||||||
|
|
||||||
# update working partition updated by the current rank
|
# update working partition updated by the current rank
|
||||||
|
device = get_current_device()
|
||||||
for group_id in range(self.num_param_groups):
|
for group_id in range(self.num_param_groups):
|
||||||
master_working_param = self.optim.param_groups[group_id]["params"]
|
master_working_param = self.optim.param_groups[group_id]["params"]
|
||||||
for idx, splited_param in enumerate(master_working_param):
|
for idx, splited_param in enumerate(master_working_param):
|
||||||
working_param = real_working_params[group_id][idx]
|
working_param = real_working_params[group_id][idx]
|
||||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
|
||||||
all_splited_param = [
|
all_splited_param = [
|
||||||
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
|
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
|
||||||
for _ in range(self.moe_extra_dp_pg_size)
|
for _ in range(self.moe_extra_dp_pg_size)
|
||||||
]
|
]
|
||||||
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg)
|
dist.all_gather(
|
||||||
|
all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
all_splited_param = [
|
all_splited_param = [
|
||||||
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
|
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
|
||||||
for _ in range(self._world_size)
|
for _ in range(self._world_size)
|
||||||
]
|
]
|
||||||
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
|
dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
|
||||||
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
|
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
|
||||||
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
|
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
|
||||||
|
|
||||||
@ -657,7 +661,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
norm_type = float(norm_type)
|
norm_type = float(norm_type)
|
||||||
if norm_type == inf:
|
if norm_type == inf:
|
||||||
total_norm = max(grad.data.abs().max() for grad in gradients)
|
total_norm = max(grad.data.abs().max() for grad in gradients)
|
||||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
|
||||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
|
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
|
||||||
total_norm = total_norm_cuda.item()
|
total_norm = total_norm_cuda.item()
|
||||||
|
|
||||||
@ -668,7 +672,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
total_norm_exponentiated += grad_norm_exponentiated
|
total_norm_exponentiated += grad_norm_exponentiated
|
||||||
|
|
||||||
# Sum across all model parallel GPUs.
|
# Sum across all model parallel GPUs.
|
||||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
total_norm_exponentiated_cuda = torch.tensor(
|
||||||
|
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
|
||||||
|
)
|
||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
|
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
|
||||||
)
|
)
|
||||||
@ -759,6 +765,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
Dict: the pytorch form state_dict
|
Dict: the pytorch form state_dict
|
||||||
"""
|
"""
|
||||||
zero_state = dict()
|
zero_state = dict()
|
||||||
|
device = get_current_device()
|
||||||
for param, state in self.optim.state.items():
|
for param, state in self.optim.state.items():
|
||||||
zero_state[param] = copy.deepcopy(state)
|
zero_state[param] = copy.deepcopy(state)
|
||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
@ -766,14 +773,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
working_param = self._param_store.master_to_working_param[id(param)]
|
working_param = self._param_store.master_to_working_param[id(param)]
|
||||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
||||||
gather_tensor = [
|
gather_tensor = [
|
||||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
|
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
|
||||||
]
|
]
|
||||||
dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg)
|
dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
|
||||||
else:
|
else:
|
||||||
gather_tensor = [
|
gather_tensor = [
|
||||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
|
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
|
||||||
]
|
]
|
||||||
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
|
dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
|
||||||
param_state = (
|
param_state = (
|
||||||
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
||||||
)
|
)
|
||||||
@ -820,6 +827,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
ret_block = dict()
|
ret_block = dict()
|
||||||
ret_block_size = 0
|
ret_block_size = 0
|
||||||
|
|
||||||
|
device = get_current_device()
|
||||||
local_states = self.optim.state_dict()["state"]
|
local_states = self.optim.state_dict()["state"]
|
||||||
for param_idx, states in local_states.items():
|
for param_idx, states in local_states.items():
|
||||||
current_block_size = 0
|
current_block_size = 0
|
||||||
@ -836,14 +844,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||||||
if isinstance(v, torch.Tensor) and k != "step":
|
if isinstance(v, torch.Tensor) and k != "step":
|
||||||
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
|
||||||
state_tensor = [
|
state_tensor = [
|
||||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
|
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
|
||||||
]
|
]
|
||||||
dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg)
|
dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
|
||||||
else:
|
else:
|
||||||
state_tensor = [
|
state_tensor = [
|
||||||
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
|
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
|
||||||
]
|
]
|
||||||
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
|
dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
|
||||||
state_tensor = (
|
state_tensor = (
|
||||||
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
|
||||||
)
|
)
|
||||||
|
@ -13,6 +13,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig
|
|||||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
import colossalai.utils.device as device_utils
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
|
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
@ -194,7 +195,7 @@ def main():
|
|||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
|
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
|
||||||
torch.set_default_dtype(torch.float)
|
torch.set_default_dtype(torch.float)
|
||||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
|
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
|
||||||
)
|
)
|
||||||
@ -220,7 +221,7 @@ def main():
|
|||||||
performance_evaluator.on_step_end(**batch)
|
performance_evaluator.on_step_end(**batch)
|
||||||
|
|
||||||
performance_evaluator.on_fit_end()
|
performance_evaluator.on_fit_end()
|
||||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -5,7 +5,9 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
import colossalai.utils.device as device_utils
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
|
|
||||||
def divide(x: float, y: float) -> float:
|
def divide(x: float, y: float) -> float:
|
||||||
@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float:
|
|||||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
return x
|
return x
|
||||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
tensor = torch.tensor([x], device=get_current_device())
|
||||||
dist.all_reduce(tensor)
|
dist.all_reduce(tensor)
|
||||||
tensor = tensor / world_size
|
tensor = tensor / world_size
|
||||||
return tensor.item()
|
return tensor.item()
|
||||||
@ -84,13 +86,13 @@ class PerformanceEvaluator:
|
|||||||
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
|
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return
|
return
|
||||||
torch.cuda.synchronize()
|
device_utils.synchronize()
|
||||||
self.timer.start()
|
self.timer.start()
|
||||||
|
|
||||||
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
|
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return
|
return
|
||||||
torch.cuda.synchronize()
|
device_utils.synchronize()
|
||||||
self.timer.end()
|
self.timer.end()
|
||||||
|
|
||||||
batch_size, seq_len = input_ids.shape
|
batch_size, seq_len = input_ids.shape
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from .arm_cpu_adam import ArmCPUAdamBuilder
|
||||||
from .cpu_adam import CPUAdamBuilder
|
from .cpu_adam import CPUAdamBuilder
|
||||||
from .fused_optim import FusedOptimBuilder
|
from .fused_optim import FusedOptimBuilder
|
||||||
from .layernorm import LayerNormBuilder
|
from .layernorm import LayerNormBuilder
|
||||||
@ -29,4 +30,5 @@ __all__ = [
|
|||||||
"MultiTensorLambBuilder",
|
"MultiTensorLambBuilder",
|
||||||
"MultiTensorScaleBuilder",
|
"MultiTensorScaleBuilder",
|
||||||
"MultiTensorL2NormBuilder",
|
"MultiTensorL2NormBuilder",
|
||||||
|
"ArmCPUAdamBuilder",
|
||||||
]
|
]
|
||||||
|
34
op_builder/arm_cpu_adam.py
Normal file
34
op_builder/arm_cpu_adam.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from .builder import Builder
|
||||||
|
|
||||||
|
|
||||||
|
class ArmCPUAdamBuilder(Builder):
|
||||||
|
NAME = "arm_cpu_adam"
|
||||||
|
PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
|
||||||
|
ext_type = "cpu"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
|
||||||
|
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
||||||
|
|
||||||
|
# necessary 4 functions
|
||||||
|
def sources_files(self):
|
||||||
|
ret = [
|
||||||
|
self.csrc_abs_path("cpu_adam_arm.cpp"),
|
||||||
|
]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def include_dirs(self):
|
||||||
|
return [self.csrc_abs_path("includes")]
|
||||||
|
|
||||||
|
def cxx_flags(self):
|
||||||
|
extra_cxx_flags = [
|
||||||
|
"-std=c++14",
|
||||||
|
"-std=c++17",
|
||||||
|
"-g",
|
||||||
|
"-Wno-reorder",
|
||||||
|
"-fopenmp",
|
||||||
|
]
|
||||||
|
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
|
||||||
|
|
||||||
|
def nvcc_flags(self):
|
||||||
|
return []
|
@ -7,7 +7,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
|
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
|
||||||
|
|
||||||
@ -21,6 +21,8 @@ class Builder(ABC):
|
|||||||
prebuilt_import_path (str): the path where the extension is installed during pip install
|
prebuilt_import_path (str): the path where the extension is installed during pip install
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ext_type: str = "cuda"
|
||||||
|
|
||||||
def __init__(self, name: str, prebuilt_import_path: str):
|
def __init__(self, name: str, prebuilt_import_path: str):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.prebuilt_import_path = prebuilt_import_path
|
self.prebuilt_import_path = prebuilt_import_path
|
||||||
@ -165,6 +167,7 @@ class Builder(ABC):
|
|||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# check environment
|
# check environment
|
||||||
|
if self.ext_type == "cuda":
|
||||||
self.check_runtime_build_environment()
|
self.check_runtime_build_environment()
|
||||||
|
|
||||||
# time the kernel compilation
|
# time the kernel compilation
|
||||||
@ -208,11 +211,19 @@ class Builder(ABC):
|
|||||||
|
|
||||||
return op_module
|
return op_module
|
||||||
|
|
||||||
def builder(self) -> "CUDAExtension":
|
def builder(self) -> Union["CUDAExtension", "CppExtension"]:
|
||||||
"""
|
"""
|
||||||
get a CUDAExtension instance used for setup.py
|
get a CUDAExtension instance used for setup.py
|
||||||
"""
|
"""
|
||||||
from torch.utils.cpp_extension import CUDAExtension
|
from torch.utils.cpp_extension import CppExtension, CUDAExtension
|
||||||
|
|
||||||
|
if self.ext_type == "cpp":
|
||||||
|
return CppExtension(
|
||||||
|
name=self.prebuilt_import_path,
|
||||||
|
sources=self.strip_empty_entries(self.sources_files()),
|
||||||
|
include_dirs=self.strip_empty_entries(self.include_dirs()),
|
||||||
|
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
|
||||||
|
)
|
||||||
|
|
||||||
return CUDAExtension(
|
return CUDAExtension(
|
||||||
name=self.prebuilt_import_path,
|
name=self.prebuilt_import_path,
|
||||||
|
@ -2,11 +2,14 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from torch.optim import Adam
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
import colossalai.utils.device as device_utils
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
|
# from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
@ -19,16 +22,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
|
|||||||
|
|
||||||
|
|
||||||
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
||||||
|
device = device_utils.get_current_device()
|
||||||
try:
|
try:
|
||||||
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
|
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
|
||||||
booster = Booster(plugin=plugin)
|
booster = Booster(plugin=plugin)
|
||||||
model = model_fn()
|
model = model_fn()
|
||||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
optimizer = Adam(model.parameters(), lr=1e-3)
|
||||||
criterion = lambda x: x.mean()
|
criterion = lambda x: x.mean()
|
||||||
data = data_gen_fn()
|
data = data_gen_fn()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
|
k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||||
@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
|
|||||||
continue
|
continue
|
||||||
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
|
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
device_utils.empty_cache()
|
||||||
|
|
||||||
if err is None:
|
if err is None:
|
||||||
passed_models.append(name)
|
passed_models.append(name)
|
||||||
@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
|
|||||||
|
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_low_level_zero_plugin(early_stop: bool = True):
|
def test_low_level_zero_plugin(early_stop: bool = True):
|
||||||
spawn(run_dist, 4, early_stop=early_stop)
|
spawn(run_dist, 2, early_stop=early_stop)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -3,7 +3,7 @@ import pytest
|
|||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import spawn
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
|
|
||||||
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
|
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
|
||||||
|
@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
|
|||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.utils import set_seed
|
from colossalai.utils import set_seed
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||||
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
|
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
|
||||||
|
@ -9,7 +9,7 @@ import colossalai
|
|||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.utils import set_seed
|
from colossalai.utils import set_seed
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||||
from tests.kit.model_zoo import model_zoo, run_fwd
|
from tests.kit.model_zoo import model_zoo, run_fwd
|
||||||
|
@ -11,7 +11,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
|
|||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.utils import set_seed
|
from colossalai.utils import set_seed
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||||
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
|
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
|
||||||
|
@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
|
|||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.utils import set_seed
|
from colossalai.utils import set_seed
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.device import get_current_device
|
||||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||||
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
|
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
|
||||||
|
@ -9,7 +9,7 @@ from torch.testing import assert_close
|
|||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import spawn
|
||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
from colossalai.utils import conditional_context
|
from colossalai.utils import conditional_context, get_current_device
|
||||||
from colossalai.zero import LowLevelZeroOptimizer
|
from colossalai.zero import LowLevelZeroOptimizer
|
||||||
|
|
||||||
|
|
||||||
@ -28,9 +28,9 @@ class MlpModel(nn.Module):
|
|||||||
def exam_zero_1_2_grad_acc():
|
def exam_zero_1_2_grad_acc():
|
||||||
local_rank = torch.distributed.get_rank()
|
local_rank = torch.distributed.get_rank()
|
||||||
seed_all(2009)
|
seed_all(2009)
|
||||||
|
device = get_current_device()
|
||||||
# create model
|
# create model
|
||||||
zero1_model = MlpModel().cuda()
|
zero1_model = MlpModel().to(device)
|
||||||
zero2_model = copy.deepcopy(zero1_model)
|
zero2_model = copy.deepcopy(zero1_model)
|
||||||
# create optimizer
|
# create optimizer
|
||||||
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
|
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
|
||||||
@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc():
|
|||||||
)
|
)
|
||||||
# create data
|
# create data
|
||||||
seed_all(2021 + local_rank)
|
seed_all(2021 + local_rank)
|
||||||
input_data1 = torch.randn(32, 128).cuda()
|
input_data1 = torch.randn(32, 128, device=device)
|
||||||
input_data2 = torch.randn(32, 128).cuda()
|
input_data2 = torch.randn(32, 128, device=device)
|
||||||
|
|
||||||
def fwd_bwd_func(number, cur_data, check_flag):
|
def fwd_bwd_func(number, cur_data, check_flag):
|
||||||
# zero-dp forward
|
# zero-dp forward
|
||||||
@ -71,14 +71,15 @@ def exam_zero_1_2_grad_acc():
|
|||||||
def exam_zero_1_grad_acc(sync):
|
def exam_zero_1_grad_acc(sync):
|
||||||
local_rank = torch.distributed.get_rank()
|
local_rank = torch.distributed.get_rank()
|
||||||
seed_all(2008)
|
seed_all(2008)
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
# create models
|
# create models
|
||||||
zero_model = MlpModel()
|
zero_model = MlpModel()
|
||||||
torch_model = copy.deepcopy(zero_model)
|
torch_model = copy.deepcopy(zero_model)
|
||||||
|
|
||||||
seed_all(2008)
|
seed_all(2008)
|
||||||
zero_model = zero_model.cuda()
|
zero_model = zero_model.to(device)
|
||||||
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
|
torch_model = DDP(torch_model.to(device), bucket_cap_mb=0)
|
||||||
|
|
||||||
# create optimizer
|
# create optimizer
|
||||||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
|
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
|
||||||
@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync):
|
|||||||
|
|
||||||
# create data
|
# create data
|
||||||
seed_all(2022 + local_rank)
|
seed_all(2022 + local_rank)
|
||||||
input_data1 = torch.randn(32, 128).cuda()
|
input_data1 = torch.randn(32, 128, device=device)
|
||||||
input_data2 = torch.randn(32, 128).cuda()
|
input_data2 = torch.randn(32, 128, device=device)
|
||||||
|
|
||||||
def fwd_bwd_func(no_sync, cur_data, check_flag):
|
def fwd_bwd_func(no_sync, cur_data, check_flag):
|
||||||
# zero1 fwd and bwd
|
# zero1 fwd and bwd
|
||||||
|
Loading…
Reference in New Issue
Block a user