[Inference/opt] Optimize the mid tensor of RMS Norm (#5350)

* opt rms_norm

* fix bugs in rms_layernorm
Author: yuehuayingxueluo
Committed: 2024-02-02 15:06:01 +08:00 (via GitHub)
Parent: 027aa1043f
Commit: 21ad4a27f9
7 changed files with 34 additions and 35 deletions


@@ -50,12 +50,10 @@ if HAS_TRITON:
         tl.store(Y + cols, y.to(tl.float16), mask=mask)

     @torch.no_grad()
-    def rms_layernorm(x, weight, eps):
+    def rms_layernorm(x, weight, eps, norm_output=None):
         # allocate output
-        y = torch.empty_like(x)
-        # reshape input data into 2D tensor, (total token, hidden_size)
-        x_arg = x.reshape(-1, x.shape[-1])
-        M, N = x_arg.shape
+        y = torch.empty_like(x) if norm_output is None else norm_output
+        M, N = x.shape
         # Less than 64KB per feature: enqueue fused kernel
         MAX_FUSED_SIZE = 65536 // x.element_size()
@@ -67,5 +65,5 @@ if HAS_TRITON:
         num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)
         # enqueue kernel
-        _rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+        _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
         return y
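
For context, a minimal usage sketch of the updated signature (the import path, shapes, and eps value below are illustrative assumptions, not part of this commit): passing a preallocated norm_output lets repeated calls reuse the same mid tensor instead of allocating a fresh one per call, and since the internal reshape was removed, the input is expected to already be a 2D (total_tokens, hidden_size) tensor.

    # Usage sketch for the new rms_layernorm signature.
    # NOTE: import path, shapes, and eps are assumptions for illustration.
    import torch
    from colossalai.kernel.triton import rms_layernorm  # assumed import path

    total_tokens, hidden_size = 16, 4096
    x = torch.randn(total_tokens, hidden_size, device="cuda", dtype=torch.float16)
    weight = torch.ones(hidden_size, device="cuda", dtype=torch.float16)

    # Before this change: every call allocated a new output tensor internally.
    y = rms_layernorm(x, weight, eps=1e-6)

    # After this change: a preallocated buffer can be passed in and reused
    # across calls, avoiding a per-call allocation of the mid tensor.
    norm_output = torch.empty_like(x)
    y = rms_layernorm(x, weight, eps=1e-6, norm_output=norm_output)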