[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)

This commit is contained in:
Steve Luo
2024-05-06 10:55:34 +08:00
committed by GitHub
parent 537a3cbc4d
commit 725fbd2ed0
7 changed files with 112 additions and 147 deletions

View File

@@ -283,11 +283,14 @@ void rms_layernorm(
case 4:
RMSNORM_LAUNCHER(4, block);
break;
case 5:
RMSNORM_LAUNCHER(5, block);
break;
case 8:
RMSNORM_LAUNCHER(8, block);
break;
default:
AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8");
AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8");
}
}
}
@@ -330,11 +333,14 @@ void fused_add_rms_layernorm(
case 4:
FUSED_ADD_RMSNORM_LAUNCHER(4, block);
break;
case 5:
FUSED_ADD_RMSNORM_LAUNCHER(5, block);
break;
case 8:
FUSED_ADD_RMSNORM_LAUNCHER(8, block);
break;
default:
AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8");
AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8");
}
}
}