[cuda] modify the fused adam, support hybrid of fp16 and fp32 (#497)

Author: LuGY
Date: 2022-03-25 14:15:53 +08:00
Committed by: GitHub
Parent: 920c5889a7
Commit: 6a3f9fda83
6 changed files with 253 additions and 143 deletions

First changed file: the fused Adam kernel.

@@ -22,7 +22,7 @@ typedef enum
 using MATH_T = float;

-template <typename T>
+template <typename T_g, typename T_p>
 struct AdamFunctor
 {
   __device__ __forceinline__ void operator()(
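Note: previously a single type T covered gradients, parameters, and both moment buffers; splitting it into T_g (gradient storage type) and T_p (parameter/optimizer-state storage type) is what enables the hybrid case in the commit title, e.g. fp16 gradients updating fp32 master parameters. Below is a minimal stand-alone sketch of that idea, not the actual AdamFunctor (the kernel name adam_step, the launch shape, and the bias-correction values are made up for illustration); as in the file above, all arithmetic happens in MATH_T = float regardless of the storage types.

#include <cuda_fp16.h>
#include <cstdio>

using MATH_T = float;

// Hypothetical kernel: one Adam step over n elements, with gradient storage
// type T_g and parameter/moment storage type T_p. Every operand is widened
// to MATH_T before the math and narrowed back only when stored.
template <typename T_g, typename T_p>
__global__ void adam_step(const T_g *g, T_p *p, T_p *m, T_p *v, int n,
                          MATH_T lr, MATH_T beta1, MATH_T beta2, MATH_T eps,
                          MATH_T bias_correction1, MATH_T bias_correction2)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;

  MATH_T gi = static_cast<MATH_T>(g[i]);
  MATH_T mi = static_cast<MATH_T>(m[i]);
  MATH_T vi = static_cast<MATH_T>(v[i]);
  MATH_T pi = static_cast<MATH_T>(p[i]);

  mi = beta1 * mi + (1 - beta1) * gi;
  vi = beta2 * vi + (1 - beta2) * gi * gi;
  MATH_T update = (mi / bias_correction1) /
                  (sqrtf(vi / bias_correction2) + eps);
  pi -= lr * update;

  m[i] = static_cast<T_p>(mi);
  v[i] = static_cast<T_p>(vi);
  p[i] = static_cast<T_p>(pi);
}

int main()
{
  const int n = 4;
  __half h_g[n];
  float h_p[n], zeros[n] = {0.f, 0.f, 0.f, 0.f};
  for (int i = 0; i < n; ++i) { h_g[i] = __float2half(0.5f); h_p[i] = 1.f; }

  __half *d_g; float *d_p, *d_m, *d_v;
  cudaMalloc(&d_g, n * sizeof(__half));
  cudaMalloc(&d_p, n * sizeof(float));
  cudaMalloc(&d_m, n * sizeof(float));
  cudaMalloc(&d_v, n * sizeof(float));
  cudaMemcpy(d_g, h_g, n * sizeof(__half), cudaMemcpyHostToDevice);
  cudaMemcpy(d_p, h_p, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_m, zeros, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_v, zeros, n * sizeof(float), cudaMemcpyHostToDevice);

  // Hybrid case: fp16 gradients, fp32 parameters and moments.
  // Bias corrections for step 1: 1 - 0.9^1 = 0.1, 1 - 0.999^1 = 0.001.
  adam_step<__half, float><<<1, 32>>>(d_g, d_p, d_m, d_v, n,
                                      1e-3f, 0.9f, 0.999f, 1e-8f,
                                      0.1f, 0.001f);
  cudaMemcpy(h_p, d_p, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("p[0] after one step: %f\n", h_p[0]);  // ~0.999
  return 0;
}

Compiled with nvcc, the adam_step<__half, float> instantiation exercises exactly the fp16-gradient / fp32-parameter combination that the two-parameter functor makes possible.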
@@ -50,16 +50,16 @@ struct AdamFunctor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];

-    T *g = (T *)tl.addresses[0][tensor_loc];
+    T_g *g = (T_g *)tl.addresses[0][tensor_loc];
     g += chunk_idx * chunk_size;

-    T *p = (T *)tl.addresses[1][tensor_loc];
+    T_p *p = (T_p *)tl.addresses[1][tensor_loc];
     p += chunk_idx * chunk_size;

-    T *m = (T *)tl.addresses[2][tensor_loc];
+    T_p *m = (T_p *)tl.addresses[2][tensor_loc];
     m += chunk_idx * chunk_size;

-    T *v = (T *)tl.addresses[3][tensor_loc];
+    T_p *v = (T_p *)tl.addresses[3][tensor_loc];
     v += chunk_idx * chunk_size;

     n -= chunk_idx * chunk_size;
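For readers unfamiliar with the apex-style multi_tensor_apply scheme these addresses come from: each CUDA block is assigned one (tensor, chunk) pair via tl.block_to_tensor / tl.block_to_chunk, offsets the four base pointers by chunk_idx * chunk_size, and the functor's loop bounds then cap the remaining n at chunk_size. Below is a host-only sketch of that bookkeeping, with made-up tensor sizes and a hand-built mapping (the real metadata struct is populated by multi_tensor_apply):

#include <algorithm>
#include <cstdio>

int main()
{
  // Two hypothetical tensors of 5 and 3 elements, chunk_size = 4, giving
  // three blocks: (tensor 0, chunk 0), (tensor 0, chunk 1), (tensor 1, chunk 0).
  const int chunk_size = 4;
  const int num_blocks = 3;
  int sizes[] = {5, 3};
  int block_to_tensor[] = {0, 0, 1};
  int block_to_chunk[] = {0, 1, 0};

  for (int block = 0; block < num_blocks; ++block)
  {
    int tensor_loc = block_to_tensor[block];
    int chunk_idx = block_to_chunk[block];
    // Same arithmetic as the functor: offset into the tensor by
    // chunk_idx * chunk_size, then work on at most chunk_size of the
    // remaining n elements.
    int offset = chunk_idx * chunk_size;
    int n = sizes[tensor_loc] - offset;
    n = std::min(n, chunk_size);
    printf("block %d -> tensor %d, elements [%d, %d)\n",
           block, tensor_loc, offset, offset + n);
  }
  return 0;
}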
@@ -155,15 +155,15 @@ void multi_tensor_adam_cuda(
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
-      tensor_lists[0][0].scalar_type(), 0, "adam",
+  DISPATCH_FLOAT_AND_HALF_FOR_G_P(
+      tensor_lists[0][0].scalar_type(),
+      tensor_lists[1][0].scalar_type(), 0, "adam",
       multi_tensor_apply<4>(
           BLOCK_SIZE,
           chunk_size,
           noop_flag,
           tensor_lists,
-          AdamFunctor<scalar_t_0>(),
+          AdamFunctor<g_scalar_t_0, p_scalar_t_0>(),
           beta1,
           beta2,
           bias_correction1,

Second changed file, which defines the new dispatch macro.

@@ -173,6 +173,36 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }

+#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...)     \
+  if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float)     \
+  {                                                                         \
+    using g_scalar_t_##LEVEL = float;                                       \
+    using p_scalar_t_##LEVEL = float;                                       \
+    __VA_ARGS__;                                                            \
+  }                                                                         \
+  else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
+  {                                                                         \
+    using g_scalar_t_##LEVEL = float;                                       \
+    using p_scalar_t_##LEVEL = at::Half;                                    \
+    __VA_ARGS__;                                                            \
+  }                                                                         \
+  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
+  {                                                                         \
+    using g_scalar_t_##LEVEL = at::Half;                                    \
+    using p_scalar_t_##LEVEL = float;                                       \
+    __VA_ARGS__;                                                            \
+  }                                                                         \
+  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half)  \
+  {                                                                         \
+    using g_scalar_t_##LEVEL = at::Half;                                    \
+    using p_scalar_t_##LEVEL = at::Half;                                    \
+    __VA_ARGS__;                                                            \
+  }                                                                         \
+  else                                                                      \
+  {                                                                         \
+    AT_ERROR(#NAME, " not implemented for '", toString(GTYPE), "', '", toString(PTYPE), "'"); \
+  }
+
 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes(T *x,
                                                      T val,
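The new macro enumerates all four (gradient dtype, parameter dtype) combinations of float and half, binding each pair to the aliases g_scalar_t_##LEVEL / p_scalar_t_##LEVEL that the call site's __VA_ARGS__ references; at the Adam call site above, GTYPE comes from tensor_lists[0] (the gradients, addresses[0] in the functor) and PTYPE from tensor_lists[1] (the parameters). Below is a stripped-down, compilable illustration of the same pattern, using toy stand-ins for at::ScalarType and at::Half and eliding two of the four branches:

#include <cstdio>

// Toy stand-ins for at::ScalarType and at::Half, only for this sketch.
enum class ScalarType { Float, Half };
struct Half { unsigned short bits; };

// Same shape as DISPATCH_FLOAT_AND_HALF_FOR_G_P above, with the
// (Float, Half) and (Half, Half) branches elided for brevity.
#define DISPATCH_G_P_SKETCH(GTYPE, PTYPE, LEVEL, NAME, ...)         \
  if (GTYPE == ScalarType::Float && PTYPE == ScalarType::Float)     \
  {                                                                 \
    using g_scalar_t_##LEVEL = float;                               \
    using p_scalar_t_##LEVEL = float;                               \
    __VA_ARGS__;                                                    \
  }                                                                 \
  else if (GTYPE == ScalarType::Half && PTYPE == ScalarType::Float) \
  {                                                                 \
    using g_scalar_t_##LEVEL = Half;                                \
    using p_scalar_t_##LEVEL = float;                               \
    __VA_ARGS__;                                                    \
  }                                                                 \
  else                                                              \
  {                                                                 \
    printf("%s not implemented for this dtype pair\n", NAME);       \
  }

// Stand-in for the multi_tensor_apply launch: just reports which types
// the dispatch selected at compile time.
template <typename T_g, typename T_p>
void launch()
{
  printf("dispatched with sizeof(T_g)=%zu, sizeof(T_p)=%zu\n",
         sizeof(T_g), sizeof(T_p));
}

int main()
{
  // The hybrid case this commit enables: fp16 gradients, fp32 parameters.
  ScalarType gtype = ScalarType::Half;
  ScalarType ptype = ScalarType::Float;
  DISPATCH_G_P_SKETCH(gtype, ptype, 0, "adam",
                      launch<g_scalar_t_0, p_scalar_t_0>());
  return 0;
}

Routing the launch expression through __VA_ARGS__ rather than a single named macro parameter is what lets the call site contain commas, e.g. the AdamFunctor<g_scalar_t_0, p_scalar_t_0>() template argument list: the preprocessor splits on those commas but pastes the pieces back together when __VA_ARGS__ expands.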