[Inference] ADD async and sync Api server using FastAPI (#5396)

* add api server

* fix

* add

* add completion service and fix bug

* add generation config

* revise shardformer

* fix bugs

* add docstrings and fix some bugs

* fix bugs and add choices for prompt template
This commit is contained in:
Jianghai
2024-03-01 14:47:36 +08:00
committed by CjhHa1
parent d482922035
commit 69cd7e069d
13 changed files with 789 additions and 25 deletions

View File

@@ -0,0 +1,35 @@
import asyncio
from colossalai.inference.core.async_engine import AsyncInferenceEngine
from .utils import id_generator
class CompletionServing:
def __init__(self, engine: AsyncInferenceEngine, served_model: str):
self.engine = engine
self.served_model = served_model
try:
asyncio.get_running_loop()
except RuntimeError:
pass
async def create_completion(self, request, generation_config):
request_dict = await request.json()
request_id = id_generator()
prompt = request_dict.pop("prompt")
# it is not a intuitive way
self.engine.engine.generation_config = generation_config
result_generator = self.engine.generate(request_id, prompt=prompt)
final_res = None
async for res in result_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return {"error_msg": "Client disconnected"}
final_res = res
return final_res