Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-05-31 03:15:40 +00:00)
[Inference] Finish Online Serving Test, add streaming output api, continuous batching test and example (#5432)
* finish online test and add examples
* fix test_contionus_batching
* fix some bugs
* fix bash
* fix
* fix inference
* finish revision
* fix typos
* revision
Parent: 69cd7e069d
Commit: de378cd2ab
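
The commit adds an online-serving path with a streaming output API, a continuous-batching test, and client examples. Before reading the diff, here is a rough sketch of the simplest client interaction; it is not part of the commit and assumes the API server has already been started as in examples/inference/client/run_locust.sh (python -m colossalai.inference.server.api_server --model <model_path>) and is listening on 127.0.0.1:8000:

import requests  # assumed HTTP client; any client that can POST JSON works

# Non-streaming completion: leaving out the "stream" field selects the plain JSON
# response, a single object of the form {"request_id": ..., "text": ...}
# (see the /v1/completion hunk further down).
resp = requests.post(
    "http://127.0.0.1:8000/v1/completion",
    json={"prompt": "hello, who are you? "},
)
print(resp.json()["text"])
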
@@ -1,13 +1,13 @@
 import asyncio
+import logging
 from functools import partial
-from logging import Logger
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type

 from colossalai.inference.core.engine import InferenceEngine

-
-class AsyncEngineDeadError(RuntimeError):
-    pass
+# CLI logger
+logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger("colossalai-inference")


 def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
@@ -18,54 +18,45 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
         except asyncio.CancelledError:
             return
         except Exception as exc:
-            raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc
-        raise AsyncEngineDeadError(msg)
+            raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
+        raise RuntimeError(msg)
     except Exception as exc:
         request_tracker.propagate_exception(exc)
         raise exc


-class AsyncStream:
+class RequstStream:
     """A stream of Output for a request that can be
     iterated over asynchronously."""

-    def __init__(self, request_id: str) -> None:
+    def __init__(self, request_id: int) -> None:
         self.request_id = request_id
-        self._queue = asyncio.Queue()
-        self._finished = False
+        self._future = asyncio.Future()

-    def put(self, item) -> None:
-        if self._finished:
-            return
-        self._queue.put_nowait(item)
+    def set_result(self, result) -> None:
+        """Set final result and signal taht it's ready"""
+        if not self._future.done():
+            self._future.set_result(result)

-    def finish(self) -> None:
-        self._queue.put_nowait(StopIteration)
-        self._finished = True
+    async def get_result(self):
+        """Wait for the result to be set and return it."""
+        return await self._future

     @property
     def finished(self) -> bool:
-        return self._finished
-
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        result = await self._queue.get()
-        if result is StopIteration:
-            raise StopAsyncIteration
-        elif isinstance(result, Exception):
-            raise result
-        return result
+        """Check if the stream has finished by checking if the future is done."""
+        return self._future.done()


-class RequestTracker:
-    """Synchronous abstraction for tracking requests."""
+class Tracer:
+    """
+    Recording new requests and finished requests.
+    """

     def __init__(self) -> None:
-        self._request_streams: Dict[str, AsyncStream] = {}
+        self._request_streams: Dict[int, RequstStream] = {}
         self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
-        self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue()
+        self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
         self.new_requests_event = None

     def __contains__(self, item):
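
The queue-based AsyncStream, which was iterated with async for, is replaced by RequstStream, a thin wrapper around a single asyncio.Future, and RequestTracker becomes Tracer. A minimal, self-contained sketch of that future-based hand-off (illustrative only; SingleResultStream is a stand-in name, not the class added by this commit):

import asyncio


class SingleResultStream:
    """Stand-in for the future-based RequstStream pattern used in this commit."""

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future: asyncio.Future = asyncio.Future()

    def set_result(self, result) -> None:
        # Producer side (the background engine loop): fulfil the future at most once.
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self):
        # Consumer side: await the single final result instead of draining a queue.
        return await self._future


async def main():
    stream = SingleResultStream(request_id=0)
    # Simulate the engine loop finishing the request a little later.
    asyncio.get_running_loop().call_later(0.1, stream.set_result, "generated text")
    print(await stream.get_result())


asyncio.run(main())

The practical consequence for callers shows up further down in this diff: AsyncInferenceEngine.generate now returns await stream.get_result() instead of yielding incremental outputs.
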
@@ -79,19 +70,21 @@ class RequestTracker:
         Propagate an exception to request streams (all if request_id is None).
         """
         if request_id is not None:
-            self._request_streams[request_id].put(exc)
+            self._request_streams[request_id].set_result(exc)
         else:
             for stream in self._request_streams.values():
-                stream.put(exc)
+                stream.set_result(exc)

     def process_finished_request(self, finished_request) -> None:
         """Process a finished request from the engine."""
         request_id = finished_request.request_id
-        self._request_streams[request_id].put(finished_request)
+        try:
+            self._request_streams[request_id].set_result(finished_request)
+        except:
+            raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check")
         self.abort_request(request_id)

-    def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream:
+    def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
         """
         Add a request to be sent to the engine on the next background
         loop iteration.
@@ -99,7 +92,7 @@ class RequestTracker:
         if request_id in self._request_streams:
             raise KeyError(f"Request {request_id} already exists.")

-        stream = AsyncStream(request_id)
+        stream = RequstStream(request_id)
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))

         self.new_requests_event.set()
@@ -109,7 +102,7 @@ class RequestTracker:
     def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
-            Logger.info(f"Aborted request {request_id}.")
+            logger.info(f"Aborted request {request_id}.")

         self._finished_requests.put_nowait(request_id)

@@ -117,7 +110,7 @@ class RequestTracker:
             # The request has already finished or been aborted.
             return

-        self._request_streams[request_id].finish()
+        self._request_streams[request_id].set_result(None)

     def get_new_requests(self):
         """
@@ -134,30 +127,6 @@ class RequestTracker:

         return new_requests

-    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
-        """Get the new requests and finished requests to be
-        sent to the engine."""
-        new_requests: List[Dict] = []
-        finished_requests: Set[int] = set()
-
-        while not self._finished_requests.empty():
-            request_id = self._finished_requests.get_nowait()
-            finished_requests.add(request_id)
-            self._request_streams.pop(request_id, None)
-
-        while not self._new_requests.empty():
-            stream, new_request = self._new_requests.get_nowait()
-            if stream.request_id in finished_requests:
-                # The request has already been aborted.
-                stream.finish()
-                continue
-            self._request_streams[stream.request_id] = stream
-            new_requests.append(new_request)
-
-        self.new_requests_event.clear()
-
-        return new_requests, finished_requests
-
     async def wait_for_new_requests(self):
         await self.new_requests_event.wait()

@@ -194,6 +163,8 @@ class _AsyncInferenceEngine(InferenceEngine):
         self.request_handler.search_tokens(self.generation_config, logits)
         # Return: List[Sequence]
         finished_sequences = self.request_handler.update()
+        for sequence in finished_sequences:
+            sequence.output = self.tokenizer.decode(sequence.output_token_id)

         return finished_sequences, self.request_handler.current_requests_in_batch() > 0

@@ -216,7 +187,7 @@ class AsyncInferenceEngine:
         # reference to the unshielded loop
         self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker = RequestTracker()
+        self._request_tracer = Tracer()

     @property
     def background_loop_status(self):
@@ -226,11 +197,11 @@ class AsyncInferenceEngine:
         if self.background_loop_status:
             raise RuntimeError("Existing loop is running")

-        self._request_tracker.init_event()
+        self._request_tracer.init_event()

         self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
-            partial(_raise_exception_on_finish, request_tracker=self._request_tracker)
+            partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
         )
         self.background_loop = asyncio.shield(self._background_loop_unshielded)

@@ -243,12 +214,13 @@ class AsyncInferenceEngine:

         Returns True if there are in-progress requests.
         """
-        new_requests = self._request_tracker.get_new_requests()
+        new_requests = self._request_tracer.get_new_requests()
         for new_request in new_requests:
             self.engine.add_single_request(**new_request)
         newly_finished_seqs, has_running_requests = await self.engine.async_step()
+
         for seq in newly_finished_seqs:
-            self._request_tracker.process_finished_request(seq)
+            self._request_tracer.process_finished_request(seq)

         return has_running_requests

@@ -264,13 +236,13 @@ class AsyncInferenceEngine:
         return self._abort(request_id)

     def _abort(self, request_id: int):
-        self._request_tracker.abort_request(request_id)
+        self._request_tracer.abort_request(request_id)

     async def run_engine_loop(self):
         processing_requests = False
         while True:
             if not processing_requests:
-                await self._request_tracker.wait_for_new_requests()
+                await self._request_tracer.wait_for_new_requests()
             processing_requests = await self.step()
             await asyncio.sleep(0)

@@ -279,7 +251,7 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
-    ) -> AsyncStream:
+    ) -> RequstStream:
         """
         Add a request to the background tracker(waitting queue), start the background loop if needed.
         """
@@ -288,7 +260,7 @@ class AsyncInferenceEngine:
             self.start_background_loop()
         else:
            raise RuntimeError("Background loop is not running.")
-        stream = self._request_tracker.add_request(
+        stream = self._request_tracer.add_request(
            request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
@@ -308,8 +280,7 @@ class AsyncInferenceEngine:
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
-            async for request_output in stream:
-                yield request_output
+            return await stream.get_result()

         except (Exception, asyncio.CancelledError) as e:
             # If there is an exception or coroutine is cancelled, abort the
@@ -620,10 +620,10 @@ class InferenceEngine:
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                 "input_ids"
             ]
-        print(prompts_token_ids)

         if isinstance(prompts_token_ids, list):
-            pass
+            if isinstance(prompts_token_ids[0], torch.Tensor):
+                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
@@ -739,8 +739,6 @@ class InferenceEngine:
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         self.request_handler.append_next_tokens(next_tokens)

-        print("in step", logits)
-
         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()

@@ -209,6 +209,7 @@ class RequestHandler:
                         break

                 num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
+                # for now the recycle logic is not working
                 remove_list.extend(lst[:num_seqs_to_add])
                 self.running_list.extend(lst[:num_seqs_to_add])

@@ -58,7 +58,7 @@ async def generate(request: Request) -> Response:
     # Streaming case
     def stream_results():
         for request_output in results:
-            ret = {"text": request_output}
+            ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")

     if stream:
@@ -71,7 +71,7 @@ async def generate(request: Request) -> Response:
             # Abort the request if the client disconnects.
             engine.abort(request_id)
             return Response(status_code=499)
-        final_output = request_output
+        final_output = request_output[len(prompt) :]

     assert final_output is not None
     ret = {"text": final_output}
@@ -81,11 +81,15 @@ async def generate(request: Request) -> Response:
 @app.post("/v1/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
+    stream = request_dict.pop("stream", False)
     generation_config = get_generation_config(request_dict)
-    generator = await completion_serving.create_completion(request, generation_config)
-    output = tokenizer.decode(generator.output_token_id)
-    ret = {"request_id": generator.request_id, "text": output}
-    return ret
+    result = await completion_serving.create_completion(request, generation_config)
+    ret = {"request_id": result.request_id, "text": result.output}
+    if stream:
+        return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
+    else:
+        return JSONResponse(content=ret)


 def get_generation_config(request):
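
With the new stream flag, /v1/completion either returns a plain JSON body or a response whose payload is a JSON object terminated by a NUL byte ("\0"), the same framing /generate uses for its streamed chunks. A client-side sketch for consuming the streamed form (endpoint, payload, and port follow locustfile.py and run_locust.sh below; the requests dependency is an assumption, not something this commit prescribes):

import json

import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/completion",
    json={"prompt": "hello, who are you? ", "stream": "True"},
    stream=True,
)
# Every payload is a JSON object followed by "\0", so split on the NUL delimiter.
for chunk in resp.iter_lines(delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk)
        print(data.get("text", data))
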
@@ -18,18 +18,17 @@ class CompletionServing:
     async def create_completion(self, request, generation_config):
         request_dict = await request.json()
         request_id = id_generator()

         prompt = request_dict.pop("prompt")

         # it is not a intuitive way
         self.engine.engine.generation_config = generation_config
         result_generator = self.engine.generate(request_id, prompt=prompt)

-        final_res = None
-        async for res in result_generator:
-            if await request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
-                return {"error_msg": "Client disconnected"}
-            final_res = res
+        if await request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await self.engine.abort(request_id)
+            raise RuntimeError("Client disconnected")

+        final_res = await result_generator
         return final_res
@@ -61,6 +61,7 @@ class Sequence:
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
         ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
+        output(str): The output of sequence
     """

     request_id: int
@@ -73,6 +74,7 @@ class Sequence:
     max_output_len: int = 256
     # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
     ignore_eos: bool = False
+    output: str = None

     def __post_init__(self):
         self.output_token_id = []
@@ -598,6 +598,8 @@ def decoding_fused_rotary_embedding(
     """
    q_total_tokens, q_head_num, head_dim = q.shape
    assert q.size(0) == k.size(0) == v.size(0)
+    assert k.size(1) == v.size(1)
+    assert k_cache.size(-1) == v_cache.size(-1)

    if head_dim >= 512:
        num_warps = 16
examples/inference/client/locustfile.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from locust import HttpUser, between, tag, task
+
+
+class QuickstartUser(HttpUser):
+    wait_time = between(1, 5)
+
+    @tag("online-generation")
+    @task(5)
+    def completion(self):
+        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
+
+    @tag("online-generation")
+    @task(5)
+    def completion_streaming(self):
+        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
+
+    @tag("offline-generation")
+    @task(5)
+    def generate_stream(self):
+        self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})
+
+    @tag("offline-generation")
+    @task(5)
+    def generate(self):
+        self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"})
+
+    @tag("online-generation", "offline-generation")
+    @task
+    def get_models(self):
+        self.client.get("/v0/models")
examples/inference/client/run_locust.sh (new file, 24 lines)
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#argument1: model_path
+
+# launch server
+model_path=${1:-"lmsys/vicuna-7b-v1.3"}
+echo "Model Path: $model_path"
+echo "Starting server..."
+python -m colossalai.inference.server.api_server --model $model_path &
+SERVER_PID=$!
+
+# waiting time
+sleep 60
+
+# Run Locust
+echo "Starting Locust..."
+echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
+locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+
+# kill Server
+echo "Stopping server..."
+kill $SERVER_PID
+
+echo "Test and server shutdown completely"
tests/test_infer/test_continuous_batching.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+import random
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+
+import colossalai
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def generate_inputs(num_sequences, min_length, max_length):
+    sequences = []
+    for _ in range(num_sequences):
+        length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item()
+        # generating randomly lengthed sequences
+        sequence = torch.randint(10, 30000, size=(length,))
+        sequences.append(sequence)
+    return sequences
+
+
+@parameterize(
+    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+)
+def check_inference_engine(use_engine=False, prompt_template=None):
+    setup_seed(20)
+    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
+    model = model.eval()
+
+    inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
+
+    if use_engine:
+        inference_config = InferenceConfig(
+            max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
+        )
+        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+        assert inference_engine.generation_config.max_new_tokens == max_output_len
+        inference_engine.add_request(prompts_token_ids=inputs_token_ids)
+        assert inference_engine.request_handler._has_waiting()
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        outputs = inference_engine.generate(generation_config=generation_config)
+    else:
+        if prompt_template:
+            # apply prompt template
+            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
+        inputs = inputs.cuda()
+        generation_config = GenerationConfig(
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=max_output_len,
+        )
+        outputs = model.generate(inputs, generation_config=generation_config)
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    assert len(outputs) == 10 * max_batch_size
+
+
+@parameterize("prompt_template", [None, "llama"])
+def check_continuous_batching(prompt_template):
+    check_inference_engine(use_engine=True, prompt_template=prompt_template)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_continuous_batching()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_continuous_batching():
+    spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+    test_continuous_batching()