Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-31 15:47:05 +00:00)
model server fix
commit 8afee1070e
parent c9b440e381
@@ -36,8 +36,8 @@ class PromptRequest(BaseModel):
     prompt: str
     temperature: float
     max_new_tokens: int
-    stop: Optional[List[str]] = None
-
+    model: str
+    stop: str = None
 
 class StreamRequest(BaseModel):
     model: str
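
The first hunk reworks the request schema: `stop` narrows from `Optional[List[str]]` to a single string with a `None` default, and a required `model` field is added. A minimal sketch of how the resulting model validates, assuming pydantic v1 semantics (where `stop: str = None` is treated as `Optional[str]`); the sample field values are illustrative, not taken from the commit:

# Sketch of PromptRequest as it reads after this commit (pydantic v1 style).
from pydantic import BaseModel, ValidationError


class PromptRequest(BaseModel):
    prompt: str
    temperature: float
    max_new_tokens: int
    model: str        # new required field
    stop: str = None  # now a single stop string, implicitly Optional in pydantic v1


# "model" is now required, so older payloads without it are rejected:
try:
    PromptRequest(prompt="hi", temperature=0.7, max_new_tokens=32)
except ValidationError as err:
    print(err)  # reports: field required -> model
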
@@ -93,9 +93,14 @@ async def api_generate_stream(request: Request):
     return StreamingResponse(generator, background=background_tasks)
 
-
 @app.post("/generate")
-def generate(prompt_request: Request):
-    params = prompt_request.json()
+def generate(prompt_request: PromptRequest):
+    print(prompt_request)
+    params = {
+        "prompt": prompt_request.prompt,
+        "temperature": prompt_request.temperature,
+        "max_new_tokens": prompt_request.max_new_tokens,
+        "stop": prompt_request.stop
+    }
     print("Receive prompt: ", params["prompt"])
     output = generate_output(model, tokenizer, params, DEVICE)
     print("Output: ", output)
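
The handler fix matters beyond style: the old signature took a raw Starlette `Request`, whose `.json()` method is a coroutine, so the synchronous `params = prompt_request.json()` never produced a dict. Typing the parameter as `PromptRequest` lets FastAPI parse and validate the JSON body before the handler runs, and the explicit dict feeds `generate_output` the plain params it expects. A minimal smoke test against the fixed endpoint; the host, port, and model name below are assumptions, not taken from the commit:

# Hypothetical client call; adjust host/port to wherever the model server runs.
import requests

payload = {
    "prompt": "What is DB-GPT?",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "model": "vicuna-13b",  # placeholder model name
    "stop": "###",
}
resp = requests.post("http://127.0.0.1:8000/generate", json=payload)
print(resp.status_code, resp.text)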