Reduced the time gap between CUDA kernel executions caused by view and memcopy operations (#5390)

* opt_view_and_memcopy

* fix bugs in ci

* fix ci bugs

* update benchmark scripts

* fix ci bugs
This commit is contained in:
yuehuayingxueluo
2024-02-21 13:23:57 +08:00
committed by GitHub
parent 730103819d
commit 2a718c8be8
8 changed files with 141 additions and 55 deletions

View File

@@ -29,8 +29,10 @@ except:
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output)
def _triton_rmsnorm_forward(
self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
return _triton_rmsnorm_forward
else: