diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py
index 00dc87057..61a651fe7 100644
--- a/pilot/server/vicuna_server.py
+++ b/pilot/server/vicuna_server.py
@@ -33,6 +33,13 @@ class PromptRequest(BaseModel):
     stop: Optional[List[str]] = None
 
 
+class StreamRequest(BaseModel):
+    model: str
+    prompt: str
+    temperature: float
+    max_new_tokens: int
+    stop: str
+
 class EmbeddingRequest(BaseModel):
     prompt: str
 
@@ -64,10 +71,16 @@ def generate_stream_gate(params):
 
 
 @app.post("/generate_stream")
-def api_generate_stream(request: Request):
+def api_generate_stream(request: StreamRequest):
     global model_semaphore, global_counter
     global_counter += 1
-    params = request.json()
+    params = {
+        "prompt": request.prompt,
+        "model": request.model,
+        "temperature": request.temperature,
+        "max_new_tokens": request.max_new_tokens,
+        "stop": request.stop
+    }
     print(model, tokenizer, params, DEVICE)
     # if model_semaphore is None:
     #     model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)