mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[Inference] ADD async and sync Api server using FastAPI (#5396)
* add api server * fix * add * add completion service and fix bug * add generation config * revise shardformer * fix bugs * add docstrings and fix some bugs * fix bugs and add choices for prompt template
This commit is contained in:
35
colossalai/inference/server/completion_service.py
Normal file
35
colossalai/inference/server/completion_service.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import asyncio
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
||||
|
||||
from .utils import id_generator
|
||||
|
||||
|
||||
class CompletionServing:
|
||||
def __init__(self, engine: AsyncInferenceEngine, served_model: str):
|
||||
self.engine = engine
|
||||
self.served_model = served_model
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def create_completion(self, request, generation_config):
|
||||
request_dict = await request.json()
|
||||
request_id = id_generator()
|
||||
prompt = request_dict.pop("prompt")
|
||||
|
||||
# it is not a intuitive way
|
||||
self.engine.engine.generation_config = generation_config
|
||||
result_generator = self.engine.generate(request_id, prompt=prompt)
|
||||
|
||||
final_res = None
|
||||
async for res in result_generator:
|
||||
if await request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(request_id)
|
||||
return {"error_msg": "Client disconnected"}
|
||||
final_res = res
|
||||
|
||||
return final_res
|
Reference in New Issue
Block a user