From 3d211ff81b8036748798464e92a866a5ba5074bd Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:06:05 +0800 Subject: [PATCH] [Inference] Finish Online Serving Test, add streaming output api, continuous batching test and example (#5432) * finish online test and add examples * fix test_contionus_batching * fix some bugs * fix bash * fix * fix inference * finish revision * fix typos * revision --- colossalai/inference/core/async_engine.py | 125 +++++++----------- colossalai/inference/core/engine.py | 7 +- colossalai/inference/core/request_handler.py | 1 + colossalai/inference/server/api_server.py | 16 ++- .../inference/server/completion_service.py | 13 +- colossalai/inference/struct.py | 2 + .../kernel/triton/no_pad_rotary_embedding.py | 2 +- examples/inference/client/locustfile.py | 30 +++++ examples/inference/client/run_locust.sh | 24 ++++ examples/inference/run_llama_inference.py | 98 -------------- tests/test_infer/test_continuous_batching.py | 89 +++++++++++++ 11 files changed, 213 insertions(+), 194 deletions(-) create mode 100644 examples/inference/client/locustfile.py create mode 100644 examples/inference/client/run_locust.sh delete mode 100644 examples/inference/run_llama_inference.py create mode 100644 tests/test_infer/test_continuous_batching.py diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py index 5be36fada..e23d0b90f 100644 --- a/colossalai/inference/core/async_engine.py +++ b/colossalai/inference/core/async_engine.py @@ -1,13 +1,13 @@ import asyncio +import logging from functools import partial -from logging import Logger -from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type +from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type from colossalai.inference.core.engine import InferenceEngine - -class AsyncEngineDeadError(RuntimeError): - pass +# CLI logger +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger("colossalai-inference") def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None: @@ -18,54 +18,45 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac except asyncio.CancelledError: return except Exception as exc: - raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc - raise AsyncEngineDeadError(msg) + raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc + raise RuntimeError(msg) except Exception as exc: request_tracker.propagate_exception(exc) raise exc -class AsyncStream: +class RequstStream: """A stream of Output for a request that can be iterated over asynchronously.""" - def __init__(self, request_id: str) -> None: + def __init__(self, request_id: int) -> None: self.request_id = request_id - self._queue = asyncio.Queue() - self._finished = False + self._future = asyncio.Future() - def put(self, item) -> None: - if self._finished: - return - self._queue.put_nowait(item) + def set_result(self, result) -> None: + """Set final result and signal taht it's ready""" + if not self._future.done(): + self._future.set_result(result) - def finish(self) -> None: - self._queue.put_nowait(StopIteration) - self._finished = True + async def get_result(self): + """Wait for the result to be set and return it.""" + return await self._future @property def finished(self) -> bool: - return self._finished - - def 
__aiter__(self): - return self - - async def __anext__(self): - result = await self._queue.get() - if result is StopIteration: - raise StopAsyncIteration - elif isinstance(result, Exception): - raise result - return result + """Check if the stream has finished by checking if the future is done.""" + return self._future.done() -class RequestTracker: - """Synchronous abstraction for tracking requests.""" +class Tracer: + """ + Recording new requests and finished requests. + """ def __init__(self) -> None: - self._request_streams: Dict[str, AsyncStream] = {} + self._request_streams: Dict[int, RequstStream] = {} self._finished_requests: asyncio.Queue[int] = asyncio.Queue() - self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue() self.new_requests_event = None def __contains__(self, item): @@ -79,19 +70,21 @@ class RequestTracker: Propagate an exception to request streams (all if request_id is None). """ if request_id is not None: - self._request_streams[request_id].put(exc) + self._request_streams[request_id].set_result(exc) else: for stream in self._request_streams.values(): - stream.put(exc) + stream.set_result(exc) def process_finished_request(self, finished_request) -> None: """Process a finished request from the engine.""" request_id = finished_request.request_id - - self._request_streams[request_id].put(finished_request) + try: + self._request_streams[request_id].set_result(finished_request) + except: + raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check") self.abort_request(request_id) - def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream: + def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream: """ Add a request to be sent to the engine on the next background loop iteration. @@ -99,7 +92,7 @@ class RequestTracker: if request_id in self._request_streams: raise KeyError(f"Request {request_id} already exists.") - stream = AsyncStream(request_id) + stream = RequstStream(request_id) self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs})) self.new_requests_event.set() @@ -109,7 +102,7 @@ class RequestTracker: def abort_request(self, request_id: int, *, verbose: bool = False) -> None: """Abort a request during next background loop iteration.""" if verbose: - Logger.info(f"Aborted request {request_id}.") + logger.info(f"Aborted request {request_id}.") self._finished_requests.put_nowait(request_id) @@ -117,7 +110,7 @@ class RequestTracker: # The request has already finished or been aborted. return - self._request_streams[request_id].finish() + self._request_streams[request_id].set_result(None) def get_new_requests(self): """ @@ -134,30 +127,6 @@ class RequestTracker: return new_requests - def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]: - """Get the new requests and finished requests to be - sent to the engine.""" - new_requests: List[Dict] = [] - finished_requests: Set[int] = set() - - while not self._finished_requests.empty(): - request_id = self._finished_requests.get_nowait() - finished_requests.add(request_id) - self._request_streams.pop(request_id, None) - - while not self._new_requests.empty(): - stream, new_request = self._new_requests.get_nowait() - if stream.request_id in finished_requests: - # The request has already been aborted. 
- stream.finish() - continue - self._request_streams[stream.request_id] = stream - new_requests.append(new_request) - - self.new_requests_event.clear() - - return new_requests, finished_requests - async def wait_for_new_requests(self): await self.new_requests_event.wait() @@ -194,6 +163,8 @@ class _AsyncInferenceEngine(InferenceEngine): self.request_handler.search_tokens(self.generation_config, logits) # Return: List[Sequence] finished_sequences = self.request_handler.update() + for sequence in finished_sequences: + sequence.output = self.tokenizer.decode(sequence.output_token_id) return finished_sequences, self.request_handler.current_requests_in_batch() > 0 @@ -216,7 +187,7 @@ class AsyncInferenceEngine: # reference to the unshielded loop self._background_loop_unshielded = None self.start_engine_loop = start_engine_loop - self._request_tracker = RequestTracker() + self._request_tracer = Tracer() @property def background_loop_status(self): @@ -226,11 +197,11 @@ class AsyncInferenceEngine: if self.background_loop_status: raise RuntimeError("Existing loop is running") - self._request_tracker.init_event() + self._request_tracer.init_event() self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop()) self._background_loop_unshielded.add_done_callback( - partial(_raise_exception_on_finish, request_tracker=self._request_tracker) + partial(_raise_exception_on_finish, request_tracker=self._request_tracer) ) self.background_loop = asyncio.shield(self._background_loop_unshielded) @@ -243,12 +214,13 @@ class AsyncInferenceEngine: Returns True if there are in-progress requests. """ - new_requests = self._request_tracker.get_new_requests() + new_requests = self._request_tracer.get_new_requests() for new_request in new_requests: self.engine.add_single_request(**new_request) newly_finished_seqs, has_running_requests = await self.engine.async_step() + for seq in newly_finished_seqs: - self._request_tracker.process_finished_request(seq) + self._request_tracer.process_finished_request(seq) return has_running_requests @@ -264,13 +236,13 @@ class AsyncInferenceEngine: return self._abort(request_id) def _abort(self, request_id: int): - self._request_tracker.abort_request(request_id) + self._request_tracer.abort_request(request_id) async def run_engine_loop(self): processing_requests = False while True: if not processing_requests: - await self._request_tracker.wait_for_new_requests() + await self._request_tracer.wait_for_new_requests() processing_requests = await self.step() await asyncio.sleep(0) @@ -279,7 +251,7 @@ class AsyncInferenceEngine: request_id: int, prompt: Optional[str], prompt_token_ids: Optional[List[int]] = None, - ) -> AsyncStream: + ) -> RequstStream: """ Add a request to the background tracker(waitting queue), start the background loop if needed. 
""" @@ -288,7 +260,7 @@ class AsyncInferenceEngine: self.start_background_loop() else: raise RuntimeError("Background loop is not running.") - stream = self._request_tracker.add_request( + stream = self._request_tracer.add_request( request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, @@ -308,8 +280,7 @@ class AsyncInferenceEngine: """ try: stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids) - async for request_output in stream: - yield request_output + return await stream.get_result() except (Exception, asyncio.CancelledError) as e: # If there is an exception or coroutine is cancelled, abort the diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index c4b1f0165..6ea555996 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -535,10 +535,10 @@ class InferenceEngine: prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ "input_ids" ] - print(prompts_token_ids) if isinstance(prompts_token_ids, list): - pass + if isinstance(prompts_token_ids[0], torch.Tensor): + prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids] elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): prompts_token_ids = prompts_token_ids.tolist() else: @@ -565,7 +565,6 @@ class InferenceEngine: prompt = None else: prompt = prompts[i] - sequence = Sequence( request_id, prompt, @@ -646,8 +645,6 @@ class InferenceEngine: next_tokens = self.request_handler.search_tokens(self.generation_config, logits) self.request_handler.append_next_tokens(next_tokens) - print("in step", logits) - self.request_handler.search_tokens(self.generation_config, logits) finished_sequences = self.request_handler.update() diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index f5c06b39b..d4c4de299 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -209,6 +209,7 @@ class RequestHandler: break num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num) + # for now the recycle logic is not working remove_list.extend(lst[:num_seqs_to_add]) self.running_list.extend(lst[:num_seqs_to_add]) diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py index c182c5160..1d3a6b497 100644 --- a/colossalai/inference/server/api_server.py +++ b/colossalai/inference/server/api_server.py @@ -58,7 +58,7 @@ async def generate(request: Request) -> Response: # Streaming case def stream_results(): for request_output in results: - ret = {"text": request_output} + ret = {"text": request_output[len(prompt) :]} yield (json.dumps(ret) + "\0").encode("utf-8") if stream: @@ -71,7 +71,7 @@ async def generate(request: Request) -> Response: # Abort the request if the client disconnects. 
engine.abort(request_id) return Response(status_code=499) - final_output = request_output + final_output = request_output[len(prompt) :] assert final_output is not None ret = {"text": final_output} @@ -81,11 +81,15 @@ async def generate(request: Request) -> Response: @app.post("/v1/completion") async def create_completion(request: Request): request_dict = await request.json() + stream = request_dict.pop("stream", False) generation_config = get_generation_config(request_dict) - generator = await completion_serving.create_completion(request, generation_config) - output = tokenizer.decode(generator.output_token_id) - ret = {"request_id": generator.request_id, "text": output} - return ret + result = await completion_serving.create_completion(request, generation_config) + + ret = {"request_id": result.request_id, "text": result.output} + if stream: + return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream") + else: + return JSONResponse(content=ret) def get_generation_config(request): diff --git a/colossalai/inference/server/completion_service.py b/colossalai/inference/server/completion_service.py index bb2160009..61833b031 100644 --- a/colossalai/inference/server/completion_service.py +++ b/colossalai/inference/server/completion_service.py @@ -18,18 +18,17 @@ class CompletionServing: async def create_completion(self, request, generation_config): request_dict = await request.json() request_id = id_generator() + prompt = request_dict.pop("prompt") # it is not a intuitive way self.engine.engine.generation_config = generation_config result_generator = self.engine.generate(request_id, prompt=prompt) - final_res = None - async for res in result_generator: - if await request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(request_id) - return {"error_msg": "Client disconnected"} - final_res = res + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(request_id) + raise RuntimeError("Client disconnected") + final_res = await result_generator return final_res diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 9d33c62f8..0cbd62284 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -64,6 +64,7 @@ class Sequence: eos_token_id (int): The eos token id for this inference process. pad_token_id (int): The pad token id for this inference process. max_output_len (int): Maximum output length. 
+ output(str): The output of sequence """ request_id: int @@ -74,6 +75,7 @@ class Sequence: eos_token_id: int pad_token_id: int max_output_len: int = 256 + output: str = None def __post_init__(self): self.output_token_id = [] diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py index 4b294a399..b8dc35c59 100644 --- a/colossalai/kernel/triton/no_pad_rotary_embedding.py +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -635,7 +635,7 @@ def decoding_fused_rotary_embedding( """ q_total_tokens, q_head_num, head_dim = q.shape assert q.size(0) == k.size(0) == v.size(0) - assert q.size(1) == k.size(1) == v.size(1) + assert k.size(1) == v.size(1) assert k_cache.size(-1) == v_cache.size(-1) if head_dim >= 1024: diff --git a/examples/inference/client/locustfile.py b/examples/inference/client/locustfile.py new file mode 100644 index 000000000..7402a9c04 --- /dev/null +++ b/examples/inference/client/locustfile.py @@ -0,0 +1,30 @@ +from locust import HttpUser, between, tag, task + + +class QuickstartUser(HttpUser): + wait_time = between(1, 5) + + @tag("online-generation") + @task(5) + def completion(self): + self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"}) + + @tag("online-generation") + @task(5) + def completion_streaming(self): + self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) + + @tag("offline-generation") + @task(5) + def generate_stream(self): + self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"}) + + @tag("offline-generation") + @task(5) + def generate(self): + self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"}) + + @tag("online-generation", "offline-generation") + @task + def get_models(self): + self.client.get("/v0/models") diff --git a/examples/inference/client/run_locust.sh b/examples/inference/client/run_locust.sh new file mode 100644 index 000000000..31f4c962e --- /dev/null +++ b/examples/inference/client/run_locust.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +#argument1: model_path + +# launch server +model_path=${1:-"lmsys/vicuna-7b-v1.3"} +echo "Model Path: $model_path" +echo "Starting server..." +python -m colossalai.inference.server.api_server --model $model_path & +SERVER_PID=$! + +# waiting time +sleep 60 + +# Run Locust +echo "Starting Locust..." +echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information." +locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10 + +# kill Server +echo "Stopping server..." 
+kill $SERVER_PID + +echo "Test and server shutdown completely" diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py deleted file mode 100644 index b5228c64e..000000000 --- a/examples/inference/run_llama_inference.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse - -import torch -import torch.distributed as dist -from transformers import LlamaForCausalLM, LlamaTokenizer - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.inference import InferenceEngine -from colossalai.testing import spawn - -INPUT_TEXTS = [ - "What is the longest river in the world?", - "Explain the difference between process and thread in compouter science.", -] - - -def run_inference(args): - llama_model_path = args.model_path - llama_tokenize_path = args.tokenizer_path or args.model_path - - max_input_len = args.max_input_len - max_output_len = args.max_output_len - max_batch_size = args.batch_size - micro_batch_size = args.micro_batch_size - tp_size = args.tp_size - pp_size = args.pp_size - rank = dist.get_rank() - - tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left") - tokenizer.pad_token_id = tokenizer.eos_token_id - - if args.quant is None: - model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id) - elif args.quant == "gptq": - from auto_gptq import AutoGPTQForCausalLM - - model = AutoGPTQForCausalLM.from_quantized( - llama_model_path, inject_fused_attention=False, device=torch.cuda.current_device() - ) - elif args.quant == "smoothquant": - from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM - - model = SmoothLlamaForCausalLM.from_quantized(llama_model_path, model_basename=args.smoothquant_base_name) - model = model.cuda() - - engine = InferenceEngine( - tp_size=tp_size, - pp_size=pp_size, - model=model, - max_input_len=max_input_len, - max_output_len=max_output_len, - max_batch_size=max_batch_size, - micro_batch_size=micro_batch_size, - quant=args.quant, - dtype=args.dtype, - ) - - inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True) - inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()} - outputs = engine.generate(inputs) - - if rank == 0: - output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) - for input_text, output_text in zip(INPUT_TEXTS, output_texts): - print(f"Input: {input_text}") - print(f"Output: {output_text}") - - -def run_tp_pipeline_inference(rank, world_size, port, args): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_inference(args) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True) - parser.add_argument("-i", "--input", default="What is the longest river in the world?") - parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None) - parser.add_argument( - "-q", - "--quant", - type=str, - choices=["gptq", "smoothquant"], - default=None, - help="quantization type: 'gptq' or 'smoothquant'", - ) - parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name") - parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") - parser.add_argument("-b", 
"--batch_size", type=int, default=4, help="Maximum batch size") - parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length") - parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length") - parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size") - parser.add_argument("--dtype", default="fp16", type=str) - - args = parser.parse_args() - spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args) diff --git a/tests/test_infer/test_continuous_batching.py b/tests/test_infer/test_continuous_batching.py new file mode 100644 index 000000000..0b0d92c7c --- /dev/null +++ b/tests/test_infer/test_continuous_batching.py @@ -0,0 +1,89 @@ +import random + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM + +import colossalai +from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def generate_inputs(num_sequences, min_length, max_length): + sequences = [] + for _ in range(num_sequences): + length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item() + # generating randomly lengthed sequences + sequence = torch.randint(10, 30000, size=(length,)) + sequences.append(sequence) + return sequences + + +@parameterize( + "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50 +) +def check_inference_engine(use_engine=False, prompt_template=None): + setup_seed(20) + tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half() + model = model.eval() + + inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len) + + if use_engine: + inference_config = InferenceConfig( + max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template + ) + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + assert inference_engine.generation_config.max_new_tokens == max_output_len + inference_engine.add_request(prompts_token_ids=inputs_token_ids) + assert inference_engine.request_handler._has_waiting() + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + outputs = inference_engine.generate(generation_config=generation_config) + else: + if prompt_template: + # apply prompt template + inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] + inputs = inputs.cuda() + generation_config = GenerationConfig( + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=max_output_len, + ) + outputs = model.generate(inputs, generation_config=generation_config) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + assert len(outputs) == 10 * max_batch_size + + +@parameterize("prompt_template", [None, "llama"]) +def 
check_continuous_batching(prompt_template): + check_inference_engine(use_engine=True, prompt_template=prompt_template) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") + check_continuous_batching() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_continuous_batching(): + spawn(run_dist, 1) + + +if __name__ == "__main__": + test_continuous_batching()
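
# ---------------------------------------------------------------------------
# Minimal manual client sketch for the routes this patch adds (not part of the
# commit above). Assumptions: the API server started by
# examples/inference/client/run_locust.sh is listening on 127.0.0.1:8000, the
# `requests` package is installed, and the response shapes follow the dicts
# built in colossalai/inference/server/api_server.py
# ({"request_id": ..., "text": ...} for /v1/completion, {"text": ...} chunks
# terminated by "\0" for streaming /generate).
# ---------------------------------------------------------------------------
import json

import requests

BASE_URL = "http://127.0.0.1:8000"

# Non-streaming completion: api_server returns a plain JSON body.
resp = requests.post(
    f"{BASE_URL}/v1/completion",
    json={"prompt": "hello, who are you? ", "stream": "False"},
)
print(resp.json()["text"])

# Streaming generation: stream_results() yields NUL-terminated JSON chunks,
# so split the byte stream on b"\0" and decode each chunk separately.
with requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Can you help me? ", "stream": "True"},
    stream=True,
) as r:
    for chunk in r.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk.decode("utf-8"))["text"])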