[fix] multi graphs capture error

This commit is contained in:
Runyu Lu
2024-03-11 10:49:31 +08:00
parent cefaeb5fdd
commit b2c0d9ff2b
4 changed files with 27 additions and 30 deletions

View File

@@ -92,7 +92,6 @@ if HAS_TRITON:
def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
# allocate output
# y = torch.empty_like(x) if norm_output is None else norm_output
y = (
x * 0 if norm_output is None else norm_output
) # to make the operation non-functional, store y as the intermediate activation