diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py
index bfcd89264..f67b78a08 100644
--- a/applications/Chat/inference/server.py
+++ b/applications/Chat/inference/server.py
@@ -27,6 +27,7 @@ class GenerationTaskReq(BaseModel):
     top_k: Optional[int] = Field(default=None, gt=0, example=50)
     top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
     temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
+    repetition_penalty: Optional[float] = Field(default=None, gt=1.0, example=1.2)
 
 
 limiter = Limiter(key_func=get_remote_address)
@@ -55,6 +56,7 @@ app.add_middleware(
 
 def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
     inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
+    #TODO(ver217): streaming generation does not support repetition_penalty now
    model_kwargs = {
        'max_generate_tokens': max_new_tokens,
        'early_stopping': True,
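
Note: the TODO above marks repetition_penalty as unsupported in the streaming path. As background, here is a minimal sketch (not part of this diff) of how a repetition penalty can be applied per decoding step in a hand-rolled streaming loop, using the Hugging Face transformers RepetitionPenaltyLogitsProcessor. The "gpt2" model, the prompt, and greedy decoding are illustrative assumptions for brevity, not the server's actual implementation.

    # Illustrative sketch only -- not the code in this PR. Assumes Hugging Face
    # transformers; "gpt2" is a placeholder model and greedy decoding is used
    # purely to keep the loop short.
    import torch
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              RepetitionPenaltyLogitsProcessor)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
    # penalty must exceed 1.0, matching the gt=1.0 validator in the diff above
    penalizer = RepetitionPenaltyLogitsProcessor(penalty=1.2)

    with torch.no_grad():
        for _ in range(20):  # max_new_tokens
            logits = model(input_ids).logits[:, -1, :]
            # down-weight tokens that already appear in input_ids
            logits = penalizer(input_ids, logits)
            next_token = logits.argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # a streaming server would yield the decoded token to the client here
            print(tokenizer.decode(next_token[0]), end="", flush=True)

Because the processor only needs the token ids generated so far plus the current step's logits, it fits naturally into a token-by-token streaming loop like the one in generate_streamingly, which is presumably why the author left supporting it there as a TODO rather than a design change.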