Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-29 23:01:38 +00:00)

feat(model): Passing stop parameter to proxyllm (#2077)

Parent: cf192a5fb7
Commit: 53ba6259d2
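The commit threads a `stop` parameter from the API layer down to the proxy LLM clients: `ModelRequest` and `PromptRequest` now accept a single stop string or a list of them, the API server validates the list length, and each proxy provider forwards the value to its backend. As a rough illustration of what a stop condition does (illustrative sketch only, not DB-GPT code; `apply_stop` is a hypothetical helper):

from typing import List, Optional, Union


def apply_stop(text: str, stop: Optional[Union[str, List[str]]]) -> str:
    """Illustrative only: truncate `text` at the earliest stop sequence, if any."""
    if not stop:
        return text
    stop_list = [stop] if isinstance(stop, str) else stop
    cut = len(text)
    for s in stop_list:
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


print(apply_stop("SELECT 1;\nObservation: done", stop=["\nObservation:"]))
# -> "SELECT 1;"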
@@ -201,7 +201,7 @@ class ModelRequest:
     max_new_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""

-    stop: Optional[str] = None
+    stop: Optional[Union[str, List[str]]] = None
     """The stop condition of the model inference."""
     stop_token_ids: Optional[List[int]] = None
     """The stop token ids of the model inference."""
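With the widened field, a request may carry either one stop string or several. A minimal sketch of normalizing that shape before handing it to a backend (hypothetical `RequestSketch`, not the real `ModelRequest`):

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class RequestSketch:
    # Mirrors the widened field: a single stop string or a list of stop strings.
    stop: Optional[Union[str, List[str]]] = None

    def stop_as_list(self) -> List[str]:
        """Normalize to a list so downstream code handles one shape."""
        if self.stop is None:
            return []
        return [self.stop] if isinstance(self.stop, str) else list(self.stop)


assert RequestSketch(stop="###").stop_as_list() == ["###"]
assert RequestSketch(stop=["###", "\n\n"]).stop_as_list() == ["###", "\n\n"]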
@@ -60,6 +60,7 @@ class APIServerException(Exception):
 class APISettings(BaseModel):
     api_keys: Optional[List[str]] = None
     embedding_bach_size: int = 4
+    ignore_stop_exceeds_error: bool = False


 api_settings = APISettings()
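The new flag lives on the module-level settings object, so request validation can consult it at runtime. A rough sketch of that pattern under plain pydantic (names hypothetical):

from typing import List, Optional

from pydantic import BaseModel


class SettingsSketch(BaseModel):
    api_keys: Optional[List[str]] = None
    embedding_bach_size: int = 4  # spelling kept as in the source
    # When True, oversized 'stop' arrays are truncated instead of rejected.
    ignore_stop_exceeds_error: bool = False


settings = SettingsSketch()                 # shared, module-level instance
settings.ignore_stop_exceeds_error = True   # flipped once at server start-up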
@@ -146,6 +147,15 @@ def check_requests(request) -> Optional[JSONResponse]:
             ErrorCode.PARAM_OUT_OF_RANGE,
             f"{request.stop} is not valid under any of the given schemas - 'stop'",
         )
+    if request.stop and isinstance(request.stop, list) and len(request.stop) > 4:
+        # https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop
+        if not api_settings.ignore_stop_exceeds_error:
+            return create_error_response(
+                ErrorCode.PARAM_OUT_OF_RANGE,
+                f"Invalid 'stop': array too long. Expected an array with maximum length 4, but got an array with length {len(request.stop)} instead.",
+            )
+        else:
+            request.stop = request.stop[:4]

     return None
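The added check mirrors the OpenAI limit of at most four stop sequences: reject longer arrays by default, or truncate them when `ignore_stop_exceeds_error` is set. The same decision logic as a standalone sketch (hypothetical `check_stop` helper, not the server code):

from typing import List, Optional, Tuple, Union


def check_stop(
    stop: Optional[Union[str, List[str]]],
    ignore_stop_exceeds_error: bool = False,
) -> Tuple[Optional[str], Optional[Union[str, List[str]]]]:
    """Return (error_message, possibly_truncated_stop)."""
    if isinstance(stop, list) and len(stop) > 4:
        if not ignore_stop_exceeds_error:
            return (
                f"Invalid 'stop': array too long. Expected an array with "
                f"maximum length 4, but got an array with length {len(stop)} instead.",
                stop,
            )
        return None, stop[:4]  # silently keep only the first four
    return None, stop


err, stop = check_stop(["a", "b", "c", "d", "e"], ignore_stop_exceeds_error=True)
assert err is None and stop == ["a", "b", "c", "d"]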
@@ -581,6 +591,7 @@ def initialize_apiserver(
     port: int = None,
     api_keys: List[str] = None,
     embedding_batch_size: Optional[int] = None,
+    ignore_stop_exceeds_error: bool = False,
 ):
     import os

@@ -614,6 +625,7 @@ def initialize_apiserver(

     if embedding_batch_size:
         api_settings.embedding_bach_size = embedding_batch_size
+    api_settings.ignore_stop_exceeds_error = ignore_stop_exceeds_error

     app.include_router(router, prefix="/api", tags=["APIServer"])

@@ -664,6 +676,7 @@ def run_apiserver():
         port=apiserver_params.port,
         api_keys=api_keys,
         embedding_batch_size=apiserver_params.embedding_batch_size,
+        ignore_stop_exceeds_error=apiserver_params.ignore_stop_exceeds_error,
     )


@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from dbgpt._private.pydantic import BaseModel
 from dbgpt.core.interface.message import ModelMessage
@@ -15,7 +15,7 @@ class PromptRequest(BaseModel):
     prompt: str = None
     temperature: float = None
     max_new_tokens: int = None
-    stop: str = None
+    stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: List[int] = []
     context_len: int = None
     echo: bool = True
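Because the field is now typed `Optional[Union[str, List[str]]]`, both shapes validate at the API boundary. A quick pydantic sketch (hypothetical model name) of payloads that now parse:

from typing import List, Optional, Union

from pydantic import BaseModel


class PromptRequestSketch(BaseModel):
    prompt: Optional[str] = None
    stop: Optional[Union[str, List[str]]] = None


# Both of these now parse; previously only the plain string form was typed.
PromptRequestSketch(prompt="hi", stop="###")
PromptRequestSketch(prompt="hi", stop=["###", "\n\n", "Observation:"])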
@@ -167,6 +167,9 @@ class ModelAPIServerParameters(BaseServerParameters):
     embedding_batch_size: Optional[int] = field(
         default=None, metadata={"help": "Embedding batch size"}
     )
+    ignore_stop_exceeds_error: Optional[bool] = field(
+        default=False, metadata={"help": "Ignore exceeds stop words error"}
+    )

     log_file: Optional[str] = field(
         default="dbgpt_model_apiserver.log",
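How this dataclass field surfaces on the command line is not shown in the diff; DB-GPT's server parameters are usually driven from CLI arguments, and a boolean option of this kind would typically look roughly like the following argparse sketch (illustrative only, not DB-GPT's actual CLI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--ignore_stop_exceeds_error",
    action="store_true",
    default=False,
    help="Ignore exceeds stop words error",  # same help text as the field metadata
)
args = parser.parse_args(["--ignore_stop_exceeds_error"])
assert args.ignore_stop_exceeds_error is True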
@@ -39,6 +39,7 @@ async def chatgpt_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
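Each proxy provider builds its model request from the raw `params` dict, so the fix repeats as the same one-line addition for deepseek, gemini, moonshot, spark, tongyi, yi, and zhipu below: forward `params.get("stop")`. A generic sketch of that pattern (hypothetical `build_request_kwargs` helper):

from typing import Any, Dict


def build_request_kwargs(params: Dict[str, Any]) -> Dict[str, Any]:
    """Collect the generation options a proxy client forwards to its backend."""
    return {
        "temperature": params.get("temperature"),
        "max_new_tokens": params.get("max_new_tokens"),
        "stop": params.get("stop"),  # the newly forwarded option
    }


print(build_request_kwargs({"temperature": 0.7, "stop": ["\nObservation:"]}))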
@@ -188,6 +189,8 @@ class OpenAILLMClient(ProxyLLMClient):
             payload["temperature"] = request.temperature
         if request.max_new_tokens:
             payload["max_tokens"] = request.max_new_tokens
+        if request.stop:
+            payload["stop"] = request.stop
         return payload

     async def generate(
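For OpenAI-compatible backends the option lands in the request payload only when it is set, next to temperature and max_tokens. A minimal sketch of that payload assembly (hypothetical `Req`/`build_payload`, not the actual client):

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union


@dataclass
class Req:
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None


def build_payload(request: Req) -> Dict[str, Any]:
    payload: Dict[str, Any] = {}
    if request.temperature:
        payload["temperature"] = request.temperature
    if request.max_new_tokens:
        payload["max_tokens"] = request.max_new_tokens
    if request.stop:
        payload["stop"] = request.stop  # str or list, both accepted by the API
    return payload


assert build_payload(Req(temperature=0.5, stop=["###"])) == {
    "temperature": 0.5,
    "stop": ["###"],
}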
@@ -27,6 +27,7 @@ async def deepseek_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
@@ -46,6 +46,7 @@ def gemini_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
@@ -26,6 +26,7 @@ async def moonshot_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
@@ -47,6 +47,7 @@ def spark_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
@@ -21,6 +21,7 @@ def tongyi_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
@@ -96,6 +97,7 @@ class TongyiLLMClient(ProxyLLMClient):
             top_p=0.8,
             stream=True,
             result_format="message",
+            stop=request.stop,
         )
         for r in res:
             if r:
@@ -26,6 +26,7 @@ async def yi_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     async for r in client.generate_stream(request):
         yield r
@@ -28,6 +28,7 @@ def zhipu_generate_stream(
         temperature=params.get("temperature"),
         context=context,
         max_new_tokens=params.get("max_new_tokens"),
+        stop=params.get("stop"),
     )
     for r in client.sync_generate_stream(request):
         yield r
@@ -64,7 +64,7 @@ class DocxKnowledge(Knowledge):
             documents = self._loader.load()
         else:
             docs = []
             _SerializedRelationships.load_from_xml = load_from_xml_v2  # type: ignore
             doc = docx.Document(self._path)
             content = []
