[feat] cuda graph support and refactor non-functional api

Runyu Lu
2024-03-08 14:19:35 +08:00
parent 593a72e4d5
commit cefaeb5fdd
5 changed files with 281 additions and 43 deletions


@@ -1,5 +1,3 @@
import torch
try:
    import triton
    import triton.language as tl
@@ -94,7 +92,10 @@ if HAS_TRITON:
     def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
         # allocate output
-        y = torch.empty_like(x) if norm_output is None else norm_output
+        # y = torch.empty_like(x) if norm_output is None else norm_output
+        y = (
+            x * 0 if norm_output is None else norm_output
+        )  # to make the operation non-functional, store y as the intermediate activation
         M, N = x.shape
         # Less than 64KB per feature: enqueue fused kernel
         MAX_FUSED_SIZE = 65536 // x.element_size()
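
The change above ties into the CUDA graph support named in the commit title: graph replay re-issues the captured kernels against fixed memory addresses, so an op that allocates a fresh output tensor on every call cannot be captured and replayed safely. Accepting a preallocated norm_output buffer (the "non-functional", out-parameter style) keeps the output address stable across replays. Below is a minimal sketch of the standard torch.cuda.CUDAGraph capture/replay pattern; rms_norm is a hypothetical stand-in for the Triton kernel in this diff, not the commit's actual integration.

import torch

def rms_norm(x, weight, eps=1e-5, norm_output=None):
    # hypothetical stand-in for the Triton rms_layernorm above;
    # writes into norm_output when given, so the output address stays fixed
    var = x.pow(2).mean(-1, keepdim=True)
    out = x * torch.rsqrt(var + eps) * weight
    if norm_output is None:
        return out
    norm_output.copy_(out)
    return norm_output

x = torch.randn(8, 512, device="cuda")
w = torch.ones(512, device="cuda")
static_x = torch.zeros_like(x)  # fixed input buffer reused across replays
static_y = torch.zeros_like(x)  # preallocated output; its address is baked into the graph

# warm up on a side stream before capture, as torch.cuda.graph requires
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    rms_norm(static_x, w, norm_output=static_y)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    rms_norm(static_x, w, norm_output=static_y)

static_x.copy_(x)  # feed new data through the same fixed addresses
g.replay()         # replays the captured kernels; the result lands in static_y

Note how each replay only copies new data into static_x and reads the result from static_y; nothing inside the captured region may depend on freshly allocated caller-visible tensors, which is exactly why the kernel stores y into the caller-provided buffer.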