[Fix] Fix & Update Inference Tests (compatibility w/ main)

This commit is contained in:
Yuanheng Zhao
2024-05-05 16:28:56 +00:00
parent 56ed09aba5
commit 8754abae24
30 changed files with 32 additions and 30 deletions

View File

@@ -270,7 +270,7 @@ def llama_rmsnorm_forward(
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
class NopadLlamaMLP(ParallelModule, LlamaMLP):
class NopadLlamaMLP(LlamaMLP, ParallelModule):
def __init__(
self,
config: LlamaConfig,
@@ -392,7 +392,7 @@ class NopadLlamaMLP(ParallelModule, LlamaMLP):
return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"
class NopadLlamaAttention(ParallelModule, LlamaAttention):
class NopadLlamaAttention(LlamaAttention, ParallelModule):
def __init__(
self,
config: LlamaConfig,