model server fix

tuyang.yhj 2023-05-13 21:42:56 +08:00
parent c9b440e381
commit 8afee1070e


@@ -36,8 +36,8 @@ class PromptRequest(BaseModel):
     prompt: str
     temperature: float
     max_new_tokens: int
-    stop: Optional[List[str]] = None
+    model: str
+    stop: str = None
 
 class StreamRequest(BaseModel):
     model: str
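
For reference, a minimal sketch of the request model as it reads after this change, assembled from the hunk context above; the import line is assumed rather than shown in the diff:

from pydantic import BaseModel

# PromptRequest after the fix: plain string fields, with stop defaulting to None.
# Sketch assembled from the diff context, not copied from the full file.
class PromptRequest(BaseModel):
    prompt: str
    temperature: float
    max_new_tokens: int
    model: str
    stop: str = None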
@@ -93,9 +93,14 @@ async def api_generate_stream(request: Request):
     return StreamingResponse(generator, background=background_tasks)
 
 @app.post("/generate")
-def generate(prompt_request: Request):
-    params = prompt_request.json()
+def generate(prompt_request: PromptRequest):
+    print(prompt_request)
+    params = {
+        "prompt": prompt_request.prompt,
+        "temperature": prompt_request.temperature,
+        "max_new_tokens": prompt_request.max_new_tokens,
+        "stop": prompt_request.stop
+    }
     print("Receive prompt: ", params["prompt"])
     output = generate_output(model, tokenizer, params, DEVICE)
     print("Output: ", output)