diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 716c670bf..2920495c5 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -36,8 +36,8 @@ class PromptRequest(BaseModel):
     prompt: str
     temperature: float
     max_new_tokens: int
-    stop: Optional[List[str]] = None
-
+    model: str
+    stop: str = None
 
 class StreamRequest(BaseModel):
     model: str
@@ -93,9 +93,14 @@ async def api_generate_stream(request: Request):
     return StreamingResponse(generator, background=background_tasks)
 
 @app.post("/generate")
-def generate(prompt_request: Request):
-
-    params = prompt_request.json()
+def generate(prompt_request: PromptRequest):
+    print(prompt_request)
+    params = {
+        "prompt": prompt_request.prompt,
+        "temperature": prompt_request.temperature,
+        "max_new_tokens": prompt_request.max_new_tokens,
+        "stop": prompt_request.stop
+    }
     print("Receive prompt: ", params["prompt"])
     output = generate_output(model, tokenizer, params, DEVICE)
     print("Output: ", output)
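
After this change, /generate validates the request body against PromptRequest instead of reading a raw Request, so the new model field is required and stop is a single string rather than an optional list. A minimal client sketch for the updated endpoint; the host, port, model name, and stop token below are illustrative assumptions, not taken from this diff:

    import requests

    # Hypothetical values: point this at wherever llmserver.py is running.
    payload = {
        "prompt": "Say hello.",
        "temperature": 0.7,
        "max_new_tokens": 512,
        "model": "vicuna-13b",  # assumed model name; now required by PromptRequest
        "stop": "###",          # single stop string, no longer Optional[List[str]]
    }

    resp = requests.post("http://127.0.0.1:8000/generate", json=payload)
    print(resp.text)

Taking a typed Pydantic parameter also fixes a latent bug in the old handler: Starlette's Request.json() is a coroutine, so the synchronous generate() was binding an unawaited coroutine to params instead of the parsed body, and params["prompt"] would have raised a TypeError.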