model server fix

tuyang.yhj 2023-05-13 21:42:56 +08:00
parent c9b440e381
commit 8afee1070e


@@ -36,8 +36,8 @@ class PromptRequest(BaseModel):
     prompt: str
     temperature: float
     max_new_tokens: int
-    stop: Optional[List[str]] = None
+    model: str
+    stop: str = None
 
 class StreamRequest(BaseModel):
     model: str
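
For reference, the updated request model now carries exactly the fields the handler below reads. A minimal, self-contained sketch of how the schema validates an incoming JSON body; the payload values here are hypothetical and not taken from this commit:

from pydantic import BaseModel


class PromptRequest(BaseModel):
    prompt: str
    temperature: float
    max_new_tokens: int
    model: str
    stop: str = None  # single stop string; previously Optional[List[str]]


# Hypothetical payload, as FastAPI would parse it from the request body.
payload = {
    "prompt": "Hello",
    "temperature": 0.7,
    "max_new_tokens": 256,
    "model": "vicuna-13b",
    "stop": "###",
}
request = PromptRequest(**payload)
print(request.prompt, request.max_new_tokens)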
@@ -93,9 +93,14 @@ async def api_generate_stream(request: Request):
     return StreamingResponse(generator, background=background_tasks)
 
 @app.post("/generate")
-def generate(prompt_request: Request):
-    params = prompt_request.json()
+def generate(prompt_request: PromptRequest):
+    print(prompt_request)
+    params = {
+        "prompt": prompt_request.prompt,
+        "temperature": prompt_request.temperature,
+        "max_new_tokens": prompt_request.max_new_tokens,
+        "stop": prompt_request.stop
+    }
     print("Receive prompt: ", params["prompt"])
     output = generate_output(model, tokenizer, params, DEVICE)
     print("Output: ", output)