From 53ba6259d2fcb3e46b2e1c664e17983ba3ba99d4 Mon Sep 17 00:00:00 2001
From: Fangyin Cheng
Date: Fri, 18 Oct 2024 14:02:53 +0800
Subject: [PATCH] feat(model): Passing stop parameter to proxyllm (#2077)

---
 dbgpt/core/interface/llm.py          |  2 +-
 dbgpt/model/cluster/apiserver/api.py | 13 +++++++++++++
 dbgpt/model/cluster/base.py          |  4 ++--
 dbgpt/model/parameter.py             |  3 +++
 dbgpt/model/proxy/llms/chatgpt.py    |  3 +++
 dbgpt/model/proxy/llms/deepseek.py   |  1 +
 dbgpt/model/proxy/llms/gemini.py     |  1 +
 dbgpt/model/proxy/llms/moonshot.py   |  1 +
 dbgpt/model/proxy/llms/spark.py      |  1 +
 dbgpt/model/proxy/llms/tongyi.py     |  2 ++
 dbgpt/model/proxy/llms/yi.py         |  1 +
 dbgpt/model/proxy/llms/zhipu.py      |  1 +
 dbgpt/rag/knowledge/docx.py          |  2 +-
 13 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py
index 94de92a03..edac25b67 100644
--- a/dbgpt/core/interface/llm.py
+++ b/dbgpt/core/interface/llm.py
@@ -201,7 +201,7 @@ class ModelRequest:
 
     max_new_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
-    stop: Optional[str] = None
+    stop: Optional[Union[str, List[str]]] = None
     """The stop condition of the model inference."""
     stop_token_ids: Optional[List[int]] = None
     """The stop token ids of the model inference."""
diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py
index cf44a2c5e..1bed6057b 100644
--- a/dbgpt/model/cluster/apiserver/api.py
+++ b/dbgpt/model/cluster/apiserver/api.py
@@ -60,6 +60,7 @@ class APIServerException(Exception):
 class APISettings(BaseModel):
     api_keys: Optional[List[str]] = None
     embedding_bach_size: int = 4
+    ignore_stop_exceeds_error: bool = False
 
 
 api_settings = APISettings()
@@ -146,6 +147,15 @@ def check_requests(request) -> Optional[JSONResponse]:
             ErrorCode.PARAM_OUT_OF_RANGE,
             f"{request.stop} is not valid under any of the given schemas - 'stop'",
         )
+    if request.stop and isinstance(request.stop, list) and len(request.stop) > 4:
+        # https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop
+        if not api_settings.ignore_stop_exceeds_error:
+            return create_error_response(
+                ErrorCode.PARAM_OUT_OF_RANGE,
+                f"Invalid 'stop': array too long. Expected an array with maximum length 4, but got an array with length {len(request.stop)} instead.",
+            )
+        else:
+            request.stop = request.stop[:4]
 
     return None
 
@@ -581,6 +591,7 @@ def initialize_apiserver(
     port: int = None,
     api_keys: List[str] = None,
     embedding_batch_size: Optional[int] = None,
+    ignore_stop_exceeds_error: bool = False,
 ):
     import os
 
@@ -614,6 +625,7 @@
 
     if embedding_batch_size:
         api_settings.embedding_bach_size = embedding_batch_size
+    api_settings.ignore_stop_exceeds_error = ignore_stop_exceeds_error
 
     app.include_router(router, prefix="/api", tags=["APIServer"])
 
@@ -664,6 +676,7 @@ def run_apiserver():
         port=apiserver_params.port,
         api_keys=api_keys,
         embedding_batch_size=apiserver_params.embedding_batch_size,
+        ignore_stop_exceeds_error=apiserver_params.ignore_stop_exceeds_error,
     )
 
 
diff --git a/dbgpt/model/cluster/base.py b/dbgpt/model/cluster/base.py
index 4106d6c97..7bd780797 100644
--- a/dbgpt/model/cluster/base.py
+++ b/dbgpt/model/cluster/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from dbgpt._private.pydantic import BaseModel
 from dbgpt.core.interface.message import ModelMessage
@@ -15,7 +15,7 @@ class PromptRequest(BaseModel):
     prompt: str = None
     temperature: float = None
     max_new_tokens: int = None
-    stop: str = None
+    stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: List[int] = []
     context_len: int = None
     echo: bool = True
diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py
index ae8168dac..065f74124 100644
--- a/dbgpt/model/parameter.py
+++ b/dbgpt/model/parameter.py
@@ -167,6 +167,9 @@ class ModelAPIServerParameters(BaseServerParameters):
     embedding_batch_size: Optional[int] = field(
         default=None, metadata={"help": "Embedding batch size"}
     )
+    ignore_stop_exceeds_error: Optional[bool] = field(
+        default=False, metadata={"help": "Ignore exceeds stop words error"}
+    )
 
     log_file: Optional[str] = field(
         default="dbgpt_model_apiserver.log",
diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py
index afbb22ace..1b2c2135a 100755
--- a/dbgpt/model/proxy/llms/chatgpt.py
+++ b/dbgpt/model/proxy/llms/chatgpt.py
@@ -39,6 +39,7 @@ async def chatgpt_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
@@ -188,6 +189,8 @@ class OpenAILLMClient(ProxyLLMClient):
             payload["temperature"] = request.temperature
         if request.max_new_tokens:
             payload["max_tokens"] = request.max_new_tokens
+        if request.stop:
+            payload["stop"] = request.stop
         return payload
 
     async def generate(
diff --git a/dbgpt/model/proxy/llms/deepseek.py b/dbgpt/model/proxy/llms/deepseek.py
index fe614a8fe..6823acc51 100644
--- a/dbgpt/model/proxy/llms/deepseek.py
+++ b/dbgpt/model/proxy/llms/deepseek.py
@@ -27,6 +27,7 @@ async def deepseek_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py
index 7a50b5e93..f37f0b2d2 100644
--- a/dbgpt/model/proxy/llms/gemini.py
+++ b/dbgpt/model/proxy/llms/gemini.py
@@ -46,6 +46,7 @@ def gemini_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
diff --git a/dbgpt/model/proxy/llms/moonshot.py b/dbgpt/model/proxy/llms/moonshot.py
index e4eac390a..ecf6474fd 100644
--- a/dbgpt/model/proxy/llms/moonshot.py
+++ b/dbgpt/model/proxy/llms/moonshot.py
@@ -26,6 +26,7 @@ async def moonshot_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
diff --git a/dbgpt/model/proxy/llms/spark.py b/dbgpt/model/proxy/llms/spark.py
index 2f0847ead..8c31a70f8 100644
--- a/dbgpt/model/proxy/llms/spark.py
+++ b/dbgpt/model/proxy/llms/spark.py
@@ -47,6 +47,7 @@ def spark_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
diff --git a/dbgpt/model/proxy/llms/tongyi.py b/dbgpt/model/proxy/llms/tongyi.py
index 40143db76..0709f3fec 100644
--- a/dbgpt/model/proxy/llms/tongyi.py
+++ b/dbgpt/model/proxy/llms/tongyi.py
@@ -21,6 +21,7 @@ def tongyi_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
@@ -96,6 +97,7 @@ class TongyiLLMClient(ProxyLLMClient):
             top_p=0.8,
             stream=True,
             result_format="message",
+            stop=request.stop,
         )
         for r in res:
             if r:
diff --git a/dbgpt/model/proxy/llms/yi.py b/dbgpt/model/proxy/llms/yi.py
index ee795eae8..990b1e489 100644
--- a/dbgpt/model/proxy/llms/yi.py
+++ b/dbgpt/model/proxy/llms/yi.py
@@ -26,6 +26,7 @@ async def yi_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py
index bf33f32ef..294d8d1a7 100644
--- a/dbgpt/model/proxy/llms/zhipu.py
+++ b/dbgpt/model/proxy/llms/zhipu.py
@@ -28,6 +28,7 @@ def zhipu_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
diff --git a/dbgpt/rag/knowledge/docx.py b/dbgpt/rag/knowledge/docx.py
index 7c1ecaa9f..15cef353a 100644
--- a/dbgpt/rag/knowledge/docx.py
+++ b/dbgpt/rag/knowledge/docx.py
@@ -64,7 +64,7 @@ class DocxKnowledge(Knowledge):
             documents = self._loader.load()
         else:
             docs = []
-            _SerializedRelationships.load_from_xml = load_from_xml_v2 # type: ignore
+            _SerializedRelationships.load_from_xml = load_from_xml_v2  # type: ignore
             doc = docx.Document(self._path)
             content = []
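
Note: the following is a minimal, hypothetical usage sketch and is not part of the patch. It assumes a locally running DB-GPT API server with an OpenAI-compatible chat endpoint; the host, port, URL path, model name, and API key are placeholder assumptions. It illustrates the behavior the patch enables: `stop` may now be a single string or a list of strings, a list longer than 4 entries is rejected with PARAM_OUT_OF_RANGE, and when the server is started with ignore_stop_exceeds_error the list is silently truncated to its first 4 entries instead.

    # Hypothetical client-side sketch; endpoint, model name, and API key are
    # assumptions for illustration, not taken from this patch.
    import requests

    resp = requests.post(
        "http://127.0.0.1:8100/api/v1/chat/completions",  # assumed address/path
        headers={"Authorization": "Bearer YOUR_API_KEY"},  # assumed auth setup
        json={
            "model": "chatgpt_proxyllm",  # assumed proxy model name
            "messages": [{"role": "user", "content": "Say hello, then stop."}],
            # A single string or a list of up to 4 strings is accepted now that
            # ModelRequest.stop is Optional[Union[str, List[str]]]; longer lists
            # fail validation unless ignore_stop_exceeds_error is enabled.
            "stop": ["\nUser:", "\nObservation:"],
            "stream": False,
        },
        timeout=60,
    )
    print(resp.json())

On the server side the same value flows through PromptRequest.stop and ModelRequest.stop into each proxy client, e.g. the OpenAI client forwards it as payload["stop"] and the Tongyi client passes stop=request.stop.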