mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[cuda] modify the fused adam, support hybrid of fp16 and fp32 (#497)
This commit is contained in:
@@ -22,7 +22,7 @@ typedef enum
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T>
|
||||
template <typename T_g, typename T_p>
|
||||
struct AdamFunctor
|
||||
{
|
||||
__device__ __forceinline__ void operator()(
|
||||
@@ -50,16 +50,16 @@ struct AdamFunctor
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
T *g = (T *)tl.addresses[0][tensor_loc];
|
||||
T_g *g = (T_g *)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
||||
T *p = (T *)tl.addresses[1][tensor_loc];
|
||||
T_p *p = (T_p *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
T *m = (T *)tl.addresses[2][tensor_loc];
|
||||
T_p *m = (T_p *)tl.addresses[2][tensor_loc];
|
||||
m += chunk_idx * chunk_size;
|
||||
|
||||
T *v = (T *)tl.addresses[3][tensor_loc];
|
||||
T_p *v = (T_p *)tl.addresses[3][tensor_loc];
|
||||
v += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
@@ -155,15 +155,15 @@ void multi_tensor_adam_cuda(
|
||||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
// Assume single type across p,g,m1,m2 now
|
||||
DISPATCH_DOUBLE_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "adam",
|
||||
DISPATCH_FLOAT_AND_HALF_FOR_G_P(
|
||||
tensor_lists[0][0].scalar_type(),
|
||||
tensor_lists[1][0].scalar_type(), 0, "adam",
|
||||
multi_tensor_apply<4>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
AdamFunctor<scalar_t_0>(),
|
||||
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(),
|
||||
beta1,
|
||||
beta2,
|
||||
bias_correction1,
|
||||
|
@@ -173,6 +173,36 @@
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
|
||||
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \
|
||||
{ \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
|
||||
{ \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
|
||||
{ \
|
||||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \
|
||||
{ \
|
||||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \
|
||||
} \
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
|
||||
T val,
|
||||
|
Reference in New Issue
Block a user