mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[Inference] Fix bugs and docs for feat/online-server (#5598)
* fix test bugs
* add do sample test
* del useless lines
* fix comments
* fix tests
* delete version tag
* delete version tag
* add
* del test sever
* fix test
* fix
* Revert "add"
This reverts commit b9305fb024
.
This commit is contained in:
@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
|
||||
|
||||
|
||||
@parameterize(
|
||||
"max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"max_batch_size": 8,
|
||||
"max_output_len": 512,
|
||||
"max_input_len": 64,
|
||||
"do_sample": False,
|
||||
}
|
||||
],
|
||||
)
|
||||
def check_inference_engine(use_engine=False, prompt_template=None):
|
||||
def check_inference_engine(test_config, use_engine=False, prompt_template=None):
|
||||
setup_seed(20)
|
||||
max_batch_size = test_config["max_batch_size"]
|
||||
max_input_len = test_config["max_input_len"]
|
||||
max_output_len = test_config["max_output_len"]
|
||||
do_sample = test_config["do_sample"]
|
||||
top_p = 0.5
|
||||
top_k = 50
|
||||
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
|
||||
model = model.eval()
|
||||
|
Reference in New Issue
Block a user