rebuild params

This commit is contained in:
csunny 2023-04-30 21:54:49 +08:00
parent f0e17ed8f1
commit eef244fe92

View File

@ -33,6 +33,13 @@ class PromptRequest(BaseModel):
stop: Optional[List[str]] = None
class StreamRequest(BaseModel):
model: str
prompt: str
temperature: float
max_new_tokens: int
stop: str
class EmbeddingRequest(BaseModel):
prompt: str
@ -64,10 +71,16 @@ def generate_stream_gate(params):
@app.post("/generate_stream")
def api_generate_stream(request: Request):
def api_generate_stream(request: StreamRequest):
global model_semaphore, global_counter
global_counter += 1
params = request.json()
params = {
"prompt": request.prompt,
"model": request.model,
"temperature": request.temperature,
"max_new_tokens": request.max_new_tokens,
"stop": request.stop
}
print(model, tokenizer, params, DEVICE)
# if model_semaphore is None:
# model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)