Mirror of https://github.com/hpcaitech/ColossalAI.git
fix rmsnorm template function invocation problem (template function partial specialization is not allowed in C++) and luckily pass the e2e precision test (#5454)
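As context for the title: C++ permits partial specialization of class templates but not of function templates, which is why a per-dtype kernel cannot be selected by partially specializing the launcher function directly. The sketch below is illustrative only and does not reproduce the actual ColossalAI kernel code; the names RmsNormImpl and rms_norm, the UNROLL parameter, and the plain CPU-style loop bodies are assumptions, used solely to show the usual workaround of forwarding from a generic function template to a class template that can be partially specialized.

#include <cmath>
#include <cstddef>

// Generic fallback: a class template, which (unlike a function template)
// may be partially specialized.
template <typename T, int UNROLL>
struct RmsNormImpl {
    static void run(T* out, const T* in, std::size_t n, float eps) {
        float sum = 0.0f;
        for (std::size_t i = 0; i < n; ++i) {
            float v = static_cast<float>(in[i]);
            sum += v * v;
        }
        float scale = 1.0f / std::sqrt(sum / static_cast<float>(n) + eps);
        for (std::size_t i = 0; i < n; ++i) {
            out[i] = static_cast<T>(static_cast<float>(in[i]) * scale);
        }
    }
};

// Legal partial specialization of the class template for float; a
// vectorized or UNROLL-aware fast path would live here.
template <int UNROLL>
struct RmsNormImpl<float, UNROLL> {
    static void run(float* out, const float* in, std::size_t n, float eps) {
        float sum = 0.0f;
        for (std::size_t i = 0; i < n; ++i) sum += in[i] * in[i];
        float scale = 1.0f / std::sqrt(sum / static_cast<float>(n) + eps);
        for (std::size_t i = 0; i < n; ++i) out[i] = in[i] * scale;
    }
};

// The function template stays fully generic and only forwards the call.
// Writing "template <int UNROLL> void rms_norm<float, UNROLL>(...)" instead
// would be a partial specialization of a function template, which the
// compiler rejects.
template <typename T, int UNROLL>
void rms_norm(T* out, const T* in, std::size_t n, float eps = 1e-6f) {
    RmsNormImpl<T, UNROLL>::run(out, in, n, eps);
}

Overloading or if-constexpr dispatch would achieve the same effect; the class-template indirection keeps a single call site, rms_norm<T, UNROLL>(...), while letting each dtype carry its own implementation.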
@@ -22,11 +22,15 @@ def setup_seed(seed):
 def check_inference_engine(use_engine=False, prompt_template=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = LlamaForCausalLM(
-        LlamaConfig(
-            vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+    model = (
+        LlamaForCausalLM(
+            LlamaConfig(
+                vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
+            )
         )
-    ).cuda()
+        .cuda()
+        .half()
+    )
     model = model.eval()
 
     inputs = [
@@ -40,7 +44,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     top_k = 50
 
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)