fix rmsnorm template function invocation problem (partial specialization of a function template is not allowed in C++) and pass e2e precision test (#5454)

Steve Luo
2024-03-13 16:00:55 +08:00
committed by GitHub
parent 6fd355a5a6
commit ed431de4e4
2 changed files with 79 additions and 35 deletions
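
For context on the title: C++ permits partial specialization of class templates but not of function templates, so a kernel template cannot be specialized on only some of its template parameters directly. Below is a minimal sketch of the rule and the usual class-template workaround; the names (rms_norm, RmsNormDispatch, UNROLL) are illustrative and not the actual kernel code touched by this commit.

#include <cstdio>

// A function template can be fully specialized but NOT partially
// specialized in C++:
//
//     template <typename T, int UNROLL>
//     void rms_norm(T* x);        // primary template: fine
//
//     template <typename T>
//     void rms_norm<T, 8>(T* x);  // error: function templates cannot
//                                 // be partially specialized
//
// The standard workaround is to delegate to a class template, which
// may be partially specialized.

template <typename T, int UNROLL>
struct RmsNormDispatch {
    static void run(T*) { std::puts("generic path"); }
};

template <typename T>
struct RmsNormDispatch<T, 8> {      // partial specialization: legal on classes
    static void run(T*) { std::puts("UNROLL == 8 fast path"); }
};

template <typename T, int UNROLL>
void rms_norm(T* x) {
    RmsNormDispatch<T, UNROLL>::run(x);
}

int main() {
    float buf[8] = {};
    rms_norm<float, 4>(buf);        // prints "generic path"
    rms_norm<float, 8>(buf);        // prints "UNROLL == 8 fast path"
}

Full specializations of function templates, plain overloads, or if constexpr dispatch are other common workarounds when only a fixed set of parameter combinations is needed.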


@@ -22,11 +22,15 @@ def setup_seed(seed):
 def check_inference_engine(use_engine=False, prompt_template=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = LlamaForCausalLM(
-        LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+    model = (
+        LlamaForCausalLM(
+            LlamaConfig(
+                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            )
         )
-    ).cuda()
+        .cuda()
+        .half()
+    )
     model = model.eval()
     inputs = [
@@ -40,7 +44,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     top_k = 50
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
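
Note that the two hunks work together: the test model is now cast to half precision via .half(), and the explicit dtype="fp32" override is removed from InferenceConfig, presumably so the engine's dtype matches the fp16 model and the e2e precision test exercises the half-precision rmsnorm path this commit fixes.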