mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[Inference/opt] Optimize the mid tensor of RMS Norm (#5350)
* opt rms_norm
* fix bugs in rms_layernorm
This commit is contained in:
@@ -50,12 +50,10 @@ if HAS_TRITON:
|
||||
tl.store(Y + cols, y.to(tl.float16), mask=mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def rms_layernorm(x, weight, eps):
|
||||
def rms_layernorm(x, weight, eps, norm_output=None):
|
||||
# allocate output
|
||||
y = torch.empty_like(x)
|
||||
# reshape input data into 2D tensor, (total token, hidden_size)
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
y = torch.empty_like(x) if norm_output is None else norm_output
|
||||
M, N = x.shape
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
|
||||
@@ -67,5 +65,5 @@ if HAS_TRITON:
|
||||
num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)
|
||||
|
||||
# enqueue kernel
|
||||
_rmsnorm_kernel[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
|
||||
_rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
|
||||
return y
|
||||
|
Reference in New Issue
Block a user