[Inference] Finish Online Serving Test, add streaming output api, continuous batching test and example (#5432)

* finish online test and add examples

* fix test_contionus_batching

* fix some bugs

* fix bash

* fix

* fix inference

* finish revision

* fix typos

* revision
Jianghai 2024-03-18 17:06:05 +08:00 committed by CjhHa1
parent 69cd7e069d
commit de378cd2ab
10 changed files with 214 additions and 94 deletions

View File

@@ -1,13 +1,13 @@
 import asyncio
+import logging
 from functools import partial
-from logging import Logger
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type

 from colossalai.inference.core.engine import InferenceEngine

-class AsyncEngineDeadError(RuntimeError):
-    pass
+# CLI logger
+logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger("colossalai-inference")

 def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
@@ -18,54 +18,45 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
         except asyncio.CancelledError:
             return
         except Exception as exc:
-            raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc
-        raise AsyncEngineDeadError(msg)
+            raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
+        raise RuntimeError(msg)
     except Exception as exc:
         request_tracker.propagate_exception(exc)
         raise exc

-class AsyncStream:
+class RequstStream:
     """A stream of Output for a request that can be
     iterated over asynchronously."""

-    def __init__(self, request_id: str) -> None:
+    def __init__(self, request_id: int) -> None:
         self.request_id = request_id
-        self._queue = asyncio.Queue()
-        self._finished = False
+        self._future = asyncio.Future()

-    def put(self, item) -> None:
-        if self._finished:
-            return
-        self._queue.put_nowait(item)
+    def set_result(self, result) -> None:
+        """Set final result and signal taht it's ready"""
+        if not self._future.done():
+            self._future.set_result(result)

-    def finish(self) -> None:
-        self._queue.put_nowait(StopIteration)
-        self._finished = True
+    async def get_result(self):
+        """Wait for the result to be set and return it."""
+        return await self._future

     @property
     def finished(self) -> bool:
-        return self._finished
+        """Check if the stream has finished by checking if the future is done."""
+        return self._future.done()
-
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        result = await self._queue.get()
-        if result is StopIteration:
-            raise StopAsyncIteration
-        elif isinstance(result, Exception):
-            raise result
-        return result

-class RequestTracker:
-    """Synchronous abstraction for tracking requests."""
+class Tracer:
+    """
+    Recording new requests and finished requests.
+    """

     def __init__(self) -> None:
-        self._request_streams: Dict[str, AsyncStream] = {}
+        self._request_streams: Dict[int, RequstStream] = {}
         self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
-        self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue()
+        self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
         self.new_requests_event = None

     def __contains__(self, item):
@@ -79,19 +70,21 @@ class RequestTracker:
         Propagate an exception to request streams (all if request_id is None).
         """
         if request_id is not None:
-            self._request_streams[request_id].put(exc)
+            self._request_streams[request_id].set_result(exc)
         else:
             for stream in self._request_streams.values():
-                stream.put(exc)
+                stream.set_result(exc)

     def process_finished_request(self, finished_request) -> None:
         """Process a finished request from the engine."""
         request_id = finished_request.request_id
-        self._request_streams[request_id].put(finished_request)
+        try:
+            self._request_streams[request_id].set_result(finished_request)
+        except:
+            raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check")
         self.abort_request(request_id)

-    def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream:
+    def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
         """
         Add a request to be sent to the engine on the next background
         loop iteration.
@@ -99,7 +92,7 @@ class RequestTracker:
         if request_id in self._request_streams:
             raise KeyError(f"Request {request_id} already exists.")

-        stream = AsyncStream(request_id)
+        stream = RequstStream(request_id)
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
         self.new_requests_event.set()
@@ -109,7 +102,7 @@ class RequestTracker:
     def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
-            Logger.info(f"Aborted request {request_id}.")
+            logger.info(f"Aborted request {request_id}.")

         self._finished_requests.put_nowait(request_id)
@@ -117,7 +110,7 @@ class RequestTracker:
             # The request has already finished or been aborted.
             return

-        self._request_streams[request_id].finish()
+        self._request_streams[request_id].set_result(None)

     def get_new_requests(self):
         """
@@ -134,30 +127,6 @@ class RequestTracker:
         return new_requests

-    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
-        """Get the new requests and finished requests to be
-        sent to the engine."""
-        new_requests: List[Dict] = []
-        finished_requests: Set[int] = set()
-
-        while not self._finished_requests.empty():
-            request_id = self._finished_requests.get_nowait()
-            finished_requests.add(request_id)
-            self._request_streams.pop(request_id, None)
-
-        while not self._new_requests.empty():
-            stream, new_request = self._new_requests.get_nowait()
-            if stream.request_id in finished_requests:
-                # The request has already been aborted.
-                stream.finish()
-                continue
-            self._request_streams[stream.request_id] = stream
-            new_requests.append(new_request)
-
-        self.new_requests_event.clear()
-
-        return new_requests, finished_requests
-
     async def wait_for_new_requests(self):
         await self.new_requests_event.wait()
@@ -194,6 +163,8 @@ class _AsyncInferenceEngine(InferenceEngine):
         self.request_handler.search_tokens(self.generation_config, logits)
         # Return: List[Sequence]
         finished_sequences = self.request_handler.update()
+        for sequence in finished_sequences:
+            sequence.output = self.tokenizer.decode(sequence.output_token_id)

         return finished_sequences, self.request_handler.current_requests_in_batch() > 0
@@ -216,7 +187,7 @@ class AsyncInferenceEngine:
         # reference to the unshielded loop
         self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker = RequestTracker()
+        self._request_tracer = Tracer()

     @property
     def background_loop_status(self):
@@ -226,11 +197,11 @@ class AsyncInferenceEngine:
         if self.background_loop_status:
             raise RuntimeError("Existing loop is running")

-        self._request_tracker.init_event()
+        self._request_tracer.init_event()

         self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
-            partial(_raise_exception_on_finish, request_tracker=self._request_tracker)
+            partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
         )
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -243,12 +214,13 @@ class AsyncInferenceEngine:
         Returns True if there are in-progress requests.
         """
-        new_requests = self._request_tracker.get_new_requests()
+        new_requests = self._request_tracer.get_new_requests()
         for new_request in new_requests:
             self.engine.add_single_request(**new_request)
         newly_finished_seqs, has_running_requests = await self.engine.async_step()

         for seq in newly_finished_seqs:
-            self._request_tracker.process_finished_request(seq)
+            self._request_tracer.process_finished_request(seq)

         return has_running_requests
@@ -264,13 +236,13 @@ class AsyncInferenceEngine:
         return self._abort(request_id)

     def _abort(self, request_id: int):
-        self._request_tracker.abort_request(request_id)
+        self._request_tracer.abort_request(request_id)

     async def run_engine_loop(self):
         processing_requests = False
         while True:
             if not processing_requests:
-                await self._request_tracker.wait_for_new_requests()
+                await self._request_tracer.wait_for_new_requests()
             processing_requests = await self.step()
             await asyncio.sleep(0)
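The run_engine_loop()/step() pair above is the core of the continuous-batching serving loop: the loop blocks on an asyncio event while idle, then keeps calling step() for as long as the engine reports running requests, yielding control with asyncio.sleep(0) between iterations. Below is a minimal, self-contained sketch of that pattern using plain asyncio and stand-in names (fake_step, engine_loop); it does not use the ColossalAI classes.

import asyncio


async def fake_step(queue: asyncio.Queue) -> bool:
    """Stand-in for AsyncInferenceEngine.step(): do one unit of work,
    return True while there are still running requests."""
    if queue.empty():
        return False
    print("processed request", queue.get_nowait())
    return not queue.empty()


async def engine_loop(queue: asyncio.Queue, new_requests_event: asyncio.Event) -> None:
    processing_requests = False
    while True:
        if not processing_requests:
            # Idle: block until a producer announces new work.
            await new_requests_event.wait()
            new_requests_event.clear()
        processing_requests = await fake_step(queue)
        await asyncio.sleep(0)  # yield so producers/clients can run


async def main():
    queue, event = asyncio.Queue(), asyncio.Event()
    loop_task = asyncio.create_task(engine_loop(queue, event))
    for request_id in range(3):
        queue.put_nowait(request_id)
    event.set()
    await asyncio.sleep(0.1)  # let the loop drain the queue
    loop_task.cancel()
    try:
        await loop_task
    except asyncio.CancelledError:
        pass


asyncio.run(main())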
@@ -279,7 +251,7 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
-    ) -> AsyncStream:
+    ) -> RequstStream:
         """
         Add a request to the background tracker(waitting queue), start the background loop if needed.
         """
@@ -288,7 +260,7 @@ class AsyncInferenceEngine:
                self.start_background_loop()
            else:
                raise RuntimeError("Background loop is not running.")
-        stream = self._request_tracker.add_request(
+        stream = self._request_tracer.add_request(
             request_id,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
@@ -308,8 +280,7 @@ class AsyncInferenceEngine:
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
-            async for request_output in stream:
-                yield request_output
+            return await stream.get_result()

         except (Exception, asyncio.CancelledError) as e:
             # If there is an exception or coroutine is cancelled, abort the
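With this change, generate() no longer iterates an async stream of partial outputs; it awaits a single final result through RequstStream, which is a thin wrapper around asyncio.Future. The sketch below illustrates that future-based handoff with a hypothetical OneShotStream class (plain asyncio; the names are illustrative, not the library API).

import asyncio


class OneShotStream:
    """One producer sets a single final result; one consumer awaits it."""

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future = asyncio.Future()

    def set_result(self, result) -> None:
        # Ignore duplicate completions (e.g. finish followed by abort).
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self):
        return await self._future


async def main():
    stream = OneShotStream(request_id=0)
    # Pretend the engine completes the request shortly afterwards.
    asyncio.get_running_loop().call_later(0.01, stream.set_result, "generated text")
    print(await stream.get_result())  # -> "generated text"


asyncio.run(main())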

View File

@@ -620,10 +620,10 @@ class InferenceEngine:
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                 "input_ids"
             ]

-        print(prompts_token_ids)
         if isinstance(prompts_token_ids, list):
-            pass
+            if isinstance(prompts_token_ids[0], torch.Tensor):
+                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
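The new branch normalizes token ids that arrive as a list of tensors (for example from batch_encode_plus) into plain Python lists of ints, matching what the tensor/ndarray branch already did. A standalone sketch of that normalization logic follows; the helper name is an assumption, the branching mirrors the hunk above.

from typing import List, Union

import numpy as np
import torch


def normalize_token_ids(prompts_token_ids: Union[List, torch.Tensor, np.ndarray]) -> List[List[int]]:
    """Coerce per-prompt token ids into plain lists of ints (illustrative helper)."""
    if isinstance(prompts_token_ids, list):
        if isinstance(prompts_token_ids[0], torch.Tensor):
            # list of tensors -> list of int lists
            prompts_token_ids = [ids.tolist() for ids in prompts_token_ids]
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        prompts_token_ids = prompts_token_ids.tolist()
    else:
        raise TypeError(f"Unsupported prompts_token_ids type: {type(prompts_token_ids)}")
    return prompts_token_ids


print(normalize_token_ids([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]))
# -> [[1, 2, 3], [4, 5]]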
@@ -739,8 +739,6 @@ class InferenceEngine:
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         self.request_handler.append_next_tokens(next_tokens)

-        print("in step", logits)
-
         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()

View File

@@ -209,6 +209,7 @@ class RequestHandler:
                         break
                 num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
+                # for now the recycle logic is not working
                 remove_list.extend(lst[:num_seqs_to_add])
                 self.running_list.extend(lst[:num_seqs_to_add])

View File

@@ -58,7 +58,7 @@ async def generate(request: Request) -> Response:
     # Streaming case
     def stream_results():
         for request_output in results:
-            ret = {"text": request_output}
+            ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")

     if stream:
@@ -71,7 +71,7 @@ async def generate(request: Request) -> Response:
             # Abort the request if the client disconnects.
             engine.abort(request_id)
             return Response(status_code=499)
-        final_output = request_output
+        final_output = request_output[len(prompt) :]

     assert final_output is not None
     ret = {"text": final_output}
@@ -81,11 +81,15 @@ async def generate(request: Request) -> Response:
 @app.post("/v1/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
+    stream = request_dict.pop("stream", False)
     generation_config = get_generation_config(request_dict)
-    generator = await completion_serving.create_completion(request, generation_config)
-    output = tokenizer.decode(generator.output_token_id)
-    ret = {"request_id": generator.request_id, "text": output}
-    return ret
+    result = await completion_serving.create_completion(request, generation_config)
+    ret = {"request_id": result.request_id, "text": result.output}
+    if stream:
+        return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
+    else:
+        return JSONResponse(content=ret)

 def get_generation_config(request):
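On the client side, the /v1/completion route now honors a "stream" field and returns either a StreamingResponse or a JSONResponse. A hypothetical client call is sketched below: the host and port are assumptions (the benchmark script later in this diff points Locust at 127.0.0.1:8000), and the payloads mirror the ones used in the locustfile.

import json

import requests

BASE_URL = "http://127.0.0.1:8000"  # assumed default address of the api_server

# Non-streaming completion: one JSON body back.
resp = requests.post(f"{BASE_URL}/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
print(resp.json())

# Streaming completion: chunks are terminated by "\0", as in stream_results() above.
payload = {"prompt": "hello, who are you? ", "stream": "True"}
with requests.post(f"{BASE_URL}/v1/completion", json=payload, stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk))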

View File

@@ -18,18 +18,17 @@ class CompletionServing:
     async def create_completion(self, request, generation_config):
         request_dict = await request.json()
         request_id = id_generator()
         prompt = request_dict.pop("prompt")

         # it is not a intuitive way
         self.engine.engine.generation_config = generation_config
         result_generator = self.engine.generate(request_id, prompt=prompt)

-        final_res = None
-        async for res in result_generator:
-            if await request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
-                return {"error_msg": "Client disconnected"}
-            final_res = res
+        if await request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await self.engine.abort(request_id)
+            raise RuntimeError("Client disconnected")

+        final_res = await result_generator
         return final_res

View File

@@ -61,6 +61,7 @@ class Sequence:
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
         ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
+        output(str): The output of sequence
     """

     request_id: int
@@ -73,6 +74,7 @@ class Sequence:
     max_output_len: int = 256
     # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
     ignore_eos: bool = False
+    output: str = None

     def __post_init__(self):
         self.output_token_id = []
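The new output field carries the decoded text for a finished sequence; it is filled in async_step() via tokenizer.decode(sequence.output_token_id), so the API server can return result.output directly instead of re-decoding token ids. A small sketch of that decode step with a Hugging Face tokenizer follows; the TinyLlama checkpoint is only an example, taken from the test file later in this diff.

from transformers import AutoTokenizer

# Any HF tokenizer works the same way; this checkpoint is just an example.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

output_token_id = tokenizer.encode("Hello world", add_special_tokens=False)
print(output_token_id)                    # list of ints, e.g. token ids for "Hello world"
print(tokenizer.decode(output_token_id))  # -> "Hello world"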

View File

@@ -598,6 +598,8 @@ def decoding_fused_rotary_embedding(
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0) == v.size(0)
+    assert k.size(1) == v.size(1)
+    assert k_cache.size(-1) == v_cache.size(-1)

     if head_dim >= 512:
         num_warps = 16

View File

@@ -0,0 +1,30 @@
+from locust import HttpUser, between, tag, task
+
+
+class QuickstartUser(HttpUser):
+    wait_time = between(1, 5)
+
+    @tag("online-generation")
+    @task(5)
+    def completion(self):
+        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
+
+    @tag("online-generation")
+    @task(5)
+    def completion_streaming(self):
+        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
+
+    @tag("offline-generation")
+    @task(5)
+    def generate_stream(self):
+        self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})
+
+    @tag("offline-generation")
+    @task(5)
+    def generate(self):
+        self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"})
+
+    @tag("online-generation", "offline-generation")
+    @task
+    def get_models(self):
+        self.client.get("/v0/models")

View File

@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#argument1: model_path
+
+# launch server
+model_path=${1:-"lmsys/vicuna-7b-v1.3"}
+echo "Model Path: $model_path"
+echo "Starting server..."
+python -m colossalai.inference.server.api_server --model $model_path &
+SERVER_PID=$!
+
+# waiting time
+sleep 60
+
+# Run Locust
+echo "Starting Locust..."
+echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
+locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+
+# kill Server
+echo "Stopping server..."
+kill $SERVER_PID
+echo "Test and server shutdown completely"

View File

@@ -0,0 +1,89 @@
+import random
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+
+import colossalai
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def generate_inputs(num_sequences, min_length, max_length):
+    sequences = []
+
+    for _ in range(num_sequences):
+        length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item()
+        # generating randomly lengthed sequences
+        sequence = torch.randint(10, 30000, size=(length,))
+        sequences.append(sequence)
+
+    return sequences
+
+
+@parameterize(
+    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+)
+def check_inference_engine(use_engine=False, prompt_template=None):
+    setup_seed(20)
+    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
+    model = model.eval()
+
+    inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
+
+    if use_engine:
+        inference_config = InferenceConfig(
+            max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
+        )
+        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+        assert inference_engine.generation_config.max_new_tokens == max_output_len
+        inference_engine.add_request(prompts_token_ids=inputs_token_ids)
+        assert inference_engine.request_handler._has_waiting()
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        outputs = inference_engine.generate(generation_config=generation_config)
+    else:
+        if prompt_template:
+            # apply prompt template
+            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
+        inputs = inputs.cuda()
+        generation_config = GenerationConfig(
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=max_output_len,
+        )
+        outputs = model.generate(inputs, generation_config=generation_config)
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+    assert len(outputs) == 10 * max_batch_size
+
+
+@parameterize("prompt_template", [None, "llama"])
+def check_continuous_batching(prompt_template):
+    check_inference_engine(use_engine=True, prompt_template=prompt_template)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_continuous_batching()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_continuous_batching():
+    spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+    test_continuous_batching()