diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py
index 6f7ab15d8..03f7f13f2 100644
--- a/colossalai/inference/core/async_engine.py
+++ b/colossalai/inference/core/async_engine.py
@@ -4,6 +4,7 @@ from functools import partial
 from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
 
 from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.sampler import search_tokens
 
 # CLI logger
 logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
@@ -168,26 +169,44 @@ class _AsyncInferenceEngine(InferenceEngine):
             generated results.
         """
         batch = self.request_handler.schedule()
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
         loop = asyncio.get_running_loop()
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
         # Use run_in_executor to asyncally run the sync method model.forward().
         logits = await loop.run_in_executor(
             None,
-            self.model,
-            batch,
+            model_executable,
+            input_token_ids,
+            output_tensor,
+            input_meta_data,
             self.k_cache,
             self.v_cache,
         )
 
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
 
-        self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )
+        self.request_handler.append_next_tokens(next_tokens)
 
         finished_sequences = self.request_handler.update()
+
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)
 
-        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
+        return finished_sequences, not self.request_handler.running_list.is_empty()
+
+    def add_single_request(self, request_id: int, prompt: str, prompt_token_ids, generation_config=None):
+        prompts = [prompt]
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        self.add_request(request_ids=request_id, prompts=prompts, prompts_token_ids=prompt_token_ids, **gen_config_dict)
 
 
 class AsyncInferenceEngine:
@@ -240,7 +259,6 @@ class AsyncInferenceEngine:
             for new_request in new_requests:
                 self.engine.add_single_request(**new_request)
         newly_finished_seqs, has_running_requests = await self.engine.async_step()
-
         for seq in newly_finished_seqs:
             self._request_tracer.process_finished_request(seq)
@@ -273,6 +291,7 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
+        generation_config=None,
     ) -> RequstStream:
         """
         Add a request to the background tracker(waiting queue), start the background loop if needed.
@@ -286,6 +305,7 @@ class AsyncInferenceEngine:
             request_id,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
+            generation_config=generation_config,
         )
 
         return stream
@@ -294,13 +314,16 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
+        generation_config=None,
     ) -> AsyncIterator[str]:
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
         """
         try:
-            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
+            stream = await self.add_request(
+                request_id, prompt, prompt_token_ids=prompt_token_ids, generation_config=generation_config
+            )
             return await stream.get_result()
 
         except (Exception, asyncio.CancelledError) as e:
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 047d7d79f..73ba08750 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -154,7 +154,6 @@ class InferenceEngine:
             else:
                 model_type = "nopadding_" + self.model_config.model_type
             model_policy = model_policy_map[model_type]()
-
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
@@ -589,7 +588,7 @@ class InferenceEngine:
     def add_request(
         self,
         request_ids: Union[List[int], int] = None,
-        prompts: List[str] = None,
+        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         **kwargs,
     ) -> None:
diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py
index dfbd2c906..91c77ed35 100644
--- a/colossalai/inference/server/api_server.py
+++ b/colossalai/inference/server/api_server.py
@@ -20,10 +20,12 @@ from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.server.chat_service import ChatServing
 from colossalai.inference.server.completion_service import CompletionServing
 from colossalai.inference.server.utils import id_generator
+from colossalai.inference.utils import find_available_ports
 
 from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa
@@ -54,8 +56,9 @@ async def generate(request: Request) -> Response:
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
-    stream = request_dict.pop("stream", "false").lower()
-
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     request_id = id_generator()
     generation_config = get_generation_config(request_dict)
     results = engine.generate(request_id, prompt, generation_config=generation_config)
@@ -66,7 +69,7 @@ async def generate(request: Request) -> Response:
             ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(stream_results())
 
     # Non-streaming case
@@ -86,12 +89,14 @@ async def generate(request: Request) -> Response:
 @app.post("/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     result = await completion_serving.create_completion(request, generation_config)
 
     ret = {"request_id": result.request_id, "text": result.output}
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
     else:
         return JSONResponse(content=ret)
@@ -101,10 +106,12 @@ async def create_chat(request: Request):
     request_dict = await request.json()
-    stream = request_dict.get("stream", "false").lower()
+    stream = request_dict.get("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     message = await chat_serving.create_chat(request, generation_config)
 
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=message, media_type="text/event-stream")
     else:
         ret = {"role": message.role, "text": message.content}
@@ -115,27 +122,29 @@ def get_generation_config(request):
     generation_config = async_engine.engine.generation_config
     for arg in request:
         if hasattr(generation_config, arg):
-            generation_config[arg] = request[arg]
+            setattr(generation_config, arg, request[arg])
     return generation_config
 
 
 def add_engine_config(parser):
-    parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
-
     parser.add_argument(
-        "--max-model-len",
-        type=int,
-        default=None,
-        help="model context length. If unspecified, " "will be automatically derived from the model.",
+        "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use"
     )
 
-    # Parallel arguments
-    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
+    # Parallel arguments not supported now
 
     # KV cache arguments
     parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
 
     parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
+    parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
+
+    parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")
+
+    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+
+    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+
     # generation arguments
     parser.add_argument(
         "--prompt_template",
@@ -150,7 +159,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
 
     parser.add_argument("--host", type=str, default="127.0.0.1")
-    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--port", type=int, default=8000, help="port of FastAPI server.")
     parser.add_argument("--ssl-keyfile", type=str, default=None)
     parser.add_argument("--ssl-certfile", type=str, default=None)
     parser.add_argument(
@@ -164,6 +173,7 @@ def parse_args():
         "specified, the model name will be the same as " "the huggingface name.",
     )
+
     parser.add_argument(
         "--chat-template",
         type=str,
         default=None,
@@ -184,13 +194,21 @@ if __name__ == "__main__":
     args = parse_args()
     inference_config = InferenceConfig.from_dict(vars(args))
-    model = AutoModelForCausalLM.from_pretrained(args.model)
     tokenizer = AutoTokenizer.from_pretrained(args.model)
+    colossalai_backend_port = find_available_ports(1)[0]
+    colossalai.launch(
+        rank=0,
+        world_size=1,
+        host=args.host,
+        port=colossalai_backend_port,
+        backend="nccl",
+    )
+    model = AutoModelForCausalLM.from_pretrained(args.model)
     async_engine = AsyncInferenceEngine(
-        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
+        start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config
     )
     engine = async_engine.engine
-    completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
+    completion_serving = CompletionServing(async_engine, model.__class__.__name__)
     chat_serving = ChatServing(
         async_engine,
         served_model=model.__class__.__name__,
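
For reference, a hypothetical client call against the updated /generate route. The payload keys follow the handler above: "prompt" is required, "stream" may now be a JSON bool as well as a string, and any remaining key that matches an attribute of the engine's generation config (e.g. max_new_tokens) is applied via setattr(). The requests dependency and the concrete values are assumptions, not part of the patch:

    import requests  # assumed to be available on the client side

    payload = {
        "prompt": "What is the longest river in the world?",
        "stream": False,       # a JSON bool now works as well as the string "false"
        "max_new_tokens": 32,  # picked up by get_generation_config() via setattr
    }
    resp = requests.post("http://127.0.0.1:8000/generate", json=payload)
    print(resp.text)
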
diff --git a/colossalai/inference/server/completion_service.py b/colossalai/inference/server/completion_service.py
index 61833b031..16111dad4 100644
--- a/colossalai/inference/server/completion_service.py
+++ b/colossalai/inference/server/completion_service.py
@@ -23,7 +23,7 @@ class CompletionServing:
         # it is not a intuitive way
         self.engine.engine.generation_config = generation_config
 
-        result_generator = self.engine.generate(request_id, prompt=prompt)
+        result_generator = self.engine.generate(request_id, prompt=prompt, generation_config=generation_config)
 
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
diff --git a/examples/inference/client/run_locust.sh b/examples/inference/client/run_locust.sh
index fe742fda9..ab0a267de 100644
--- a/examples/inference/client/run_locust.sh
+++ b/examples/inference/client/run_locust.sh
@@ -6,8 +6,9 @@ model_path=${1:-"lmsys/vicuna-7b-v1.3"}
 chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
 
 echo "Model Path: $model_path"
+echo "Chat Template" "${chat_template}"
 echo "Starting server..."
-python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template &
+python -m colossalai.inference.server.api_server --model $model_path --chat-template "${chat_template}" &
 SERVER_PID=$!
 
 # waiting time
@@ -17,9 +18,9 @@ sleep 60
 echo "Starting Locust..."
 echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
 echo "Test completion api first"
-locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10
 echo "Test chat api"
-locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10
 # kill Server
 echo "Stopping server..."
 kill $SERVER_PID
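
Finally, a sketch of a /chat request matching the ChatML-style chat_template passed to the server in run_locust.sh. The OpenAI-style "messages" list is an assumption inferred from that template iterating over messages; the actual ChatServing schema may differ:

    import requests  # assumed client-side dependency

    # Message format is assumed from the ChatML-style template in run_locust.sh.
    payload = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
        ],
        "stream": "false",  # string and boolean forms are both accepted by the handler
    }
    resp = requests.post("http://127.0.0.1:8000/chat", json=payload)
    print(resp.json())  # non-streaming responses carry {"role": ..., "text": ...}
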