Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00.
[Inference]Fix readme and example for API server (#5742)
* fix chatapi readme and example * updating doc * add an api and change the doc * remove * add credits and del 'API' heading * readme * readme
This commit is contained in:
@@ -30,7 +30,6 @@ from colossalai.inference.utils import find_available_ports
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5  # seconds.

# Model names advertised by the GET /models endpoint.
supported_models_dict = {"Llama_Models": ("llama2-7b",)}

# Prompt-template identifiers accepted by the CLI / chat service.
prompt_template_choices = ["llama", "vicuna"]

# Engine and serving objects; None until initialized at server startup.
async_engine = None

chat_serving = None
@@ -39,15 +38,25 @@ completion_serving = None
|
||||
app = FastAPI()  # FastAPI application that the route handlers below attach to.
|
||||
|
||||
|
||||
# NOTE: (CjhHa1) models are still under development, need to be updated
@app.get("/models")
def get_available_models() -> Response:
    """Return the mapping of supported model families to model names as JSON."""
    return JSONResponse(content=supported_models_dict)
|
||||
@app.get("/ping")
|
||||
def health_check() -> JSONResponse:
|
||||
"""Health Check for server."""
|
||||
return JSONResponse({"status": "Healthy"})
|
||||
|
||||
|
||||
@app.get("/engine_check")
|
||||
def engine_check() -> bool:
|
||||
"""Check if the background loop is running."""
|
||||
loop_status = async_engine.background_loop_status
|
||||
if loop_status == False:
|
||||
return JSONResponse({"status": "Error"})
|
||||
return JSONResponse({"status": "Running"})
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(request: Request) -> Response:
|
||||
"""Generate completion for the request.
|
||||
NOTE: THIS API IS USED ONLY FOR TESTING, DO NOT USE THIS IF YOU ARE IN ACTUAL APPLICATION.
|
||||
|
||||
A request should be a JSON object with the following fields:
|
||||
- prompts: the prompts to use for the generation.
|
||||
@@ -133,7 +142,7 @@ def add_engine_config(parser):
|
||||
# Parallel arguments not supported now
|
||||
|
||||
# KV cache arguments
|
||||
parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
|
||||
parser.add_argument("--block_size", type=int, default=16, choices=[16, 32], help="token block size")
|
||||
|
||||
parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
|
||||
|
||||
|
Reference in New Issue
Block a user