[Inference/opt]Optimize the mid tensor of RMS Norm (#5350)

* opt rms_norm

* fix bugs in rms_layernorm
This commit is contained in:
yuehuayingxueluo
2024-02-02 15:06:01 +08:00
committed by GitHub
parent 027aa1043f
commit 21ad4a27f9
7 changed files with 34 additions and 35 deletions

View File

@@ -29,8 +29,8 @@ except:
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)
return _triton_rmsnorm_forward
else: