From bc9063adf1598c3be32fc2d12577d76b9daa79bf Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Wed, 8 May 2024 10:36:42 +0000
Subject: [PATCH] resolve rebase conflicts on Branch feat/online-serving

---
 colossalai/inference/core/engine.py          | 13 +++------
 colossalai/inference/server/README.md        | 27 +++++++++++++++++++
 .../kernel/triton/no_pad_rotary_embedding.py |  2 --
 tests/test_infer/test_continuous_batching.py |  2 +-
 4 files changed, 31 insertions(+), 13 deletions(-)
 create mode 100644 colossalai/inference/server/README.md

diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 02a8c92a2..1ced54dd7 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -527,16 +527,9 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
-<<<<<<< HEAD
-            if isinstance(prompts, str) and isinstance(request_ids, int):
-                prompts = [prompts]
-                request_ids = [request_ids]
-=======
-            if prompts is not None or prompts_token_ids is not None:
-                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
->>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598)
-
+            prompts = [prompts]
+            request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
@@ -545,7 +538,7 @@ class InferenceEngine:
                     prompts_token_ids=prompts_token_ids,
                     **gen_config_dict,
                 )
-
+
                 output_seqs_list = []
                 total_tokens_list = []

diff --git a/colossalai/inference/server/README.md b/colossalai/inference/server/README.md
new file mode 100644
index 000000000..8b5f29fc0
--- /dev/null
+++ b/colossalai/inference/server/README.md
@@ -0,0 +1,27 @@
+# Online Service
+Colossal-Inference supports a FastAPI-based online service. Both simple completion and chat are supported. Follow the commands below to
+construct a server with both completion and chat functionalities. For now only `Llama` models are supported; support for more models
+will be added soon.
+
+# Usage
+```bash
+# First, launch an API server locally.
+python3 -m colossalai.inference.server.api_server --model <path_to_your_llama2_model> --chat_template "{% for message in messages %}
+{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
+
+
+# Second, open `http://127.0.0.1:8000/docs` in a browser to check the API documentation.
+
+# For the completion service, invoke it as follows:
+curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'
+
+# For the chat service, invoke it as follows:
+curl -X POST http://127.0.0.1:8000/chat -H 'Content-Type: application/json' -d '{"conversation":
+[{"role": "system", "content": "you are a helpful assistant"},
+{"role": "user", "content": "what is 1+1?"}],
+"stream": "False"}'
+# If you just want to test simple generation, use the generate API:
+curl -X POST http://127.0.0.1:8000/generate -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'
+
+```
+We also support streaming output: simply set `stream` to `True` in the request body.
diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py
index 3a1de6d6a..e0da816bd 100644
--- a/colossalai/kernel/triton/no_pad_rotary_embedding.py
+++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py
@@ -598,8 +598,6 @@ def decoding_fused_rotary_embedding(
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0) == v.size(0)
-    assert k.size(1) == v.size(1)
-    assert k_cache.size(-1) == v_cache.size(-1)

     if head_dim >= 512:
         num_warps = 16
diff --git a/tests/test_infer/test_continuous_batching.py b/tests/test_infer/test_continuous_batching.py
index 350ed473e..a88798619 100644
--- a/tests/test_infer/test_continuous_batching.py
+++ b/tests/test_infer/test_continuous_batching.py
@@ -89,7 +89,7 @@ def check_continuous_batching(prompt_template):


 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_continuous_batching()
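
For readers who prefer a programmatic client over `curl`, the sketch below shows how the endpoints documented in the new README might be called from Python. It is an illustrative sketch only: the endpoint path and request fields are taken from the README above, while the response handling (plain-text body, line-delimited streaming) is an assumption and may need to be adapted to the actual `api_server` implementation.

```python
# Minimal client sketch for the API server described in the README above.
# Assumptions (not verified against the server code): the /generate endpoint
# accepts {"prompt": ..., "stream": ...} and returns text, streaming it in
# chunks when "stream" is "True".
import requests

BASE_URL = "http://127.0.0.1:8000"

# Non-streaming request, mirroring the curl example in the README.
resp = requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "hello, who are you? ", "stream": "False"},
    timeout=60,
)
resp.raise_for_status()
print(resp.text)

# Streaming request: iterate over chunks as they arrive.
with requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "hello, who are you? ", "stream": "True"},
    stream=True,
    timeout=60,
) as stream_resp:
    stream_resp.raise_for_status()
    for line in stream_resp.iter_lines(decode_unicode=True):
        if line:
            print(line)
```

The same pattern applies to the `/completion` and chat endpoints; only the request body changes, as shown in the README's curl examples.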