From a9668eb2e4b77e79ad4e16236dba67a97d8040ae Mon Sep 17 00:00:00 2001 From: Andriy Mulyar Date: Tue, 15 Aug 2023 12:06:49 -0400 Subject: [PATCH] Added optional top_p and top_k --- gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py b/gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py index 700650a5..a403faac 100644 --- a/gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py +++ b/gpt4all-api/gpt4all_api/app/api_v1/routes/completions.py @@ -2,7 +2,7 @@ import json from typing import List, Dict, Iterable, AsyncIterable import logging import time -from typing import Dict, List, Union +from typing import Dict, List, Union, Optional from uuid import uuid4 import aiohttp import asyncio @@ -24,8 +24,8 @@ class CompletionRequest(BaseModel): prompt: Union[List[str], str] = Field(..., description='The prompt to begin completing from.') max_tokens: int = Field(None, description='Max tokens to generate') temperature: float = Field(settings.temp, description='Model temperature') - top_p: float = Field(settings.top_p, description='top_p') - top_k: int = Field(settings.top_k, description='top_k') + top_p: Optional[float] = Field(settings.top_p, description='top_p') + top_k: Optional[int] = Field(settings.top_k, description='top_k') n: int = Field(1, description='How many completions to generate for each prompt') stream: bool = Field(False, description='Stream responses') repeat_penalty: float = Field(settings.repeat_penalty, description='Repeat penalty')