"""A Python client for the DB-GPT API."""

import json
from typing import Any, AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse

import httpx
from fastchat.protocol.api_protocol import ChatCompletionResponse

from dbgpt.app.openapi.api_view_model import ChatCompletionStreamResponse
from dbgpt.client.schemas import ChatCompletionRequestBody

CLIENT_API_PATH = "/api"
CLIENT_SERVE_PATH = "/serve"
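
# How request URLs are composed (derived from the path handling below):
#   chat APIs:  {api_base}{CLIENT_API_PATH}/{version}/chat/completions
#   serve APIs: {api_base}{CLIENT_API_PATH}/{version}{CLIENT_SERVE_PATH}{path}

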
class ClientException(Exception):
    """ClientException is raised when an error occurs in the client."""

    def __init__(self, status=None, reason=None, http_resp=None):
        """Build a client exception.

        Args:
            status: Optional[int], the HTTP status code.
            reason: Optional[str], the reason for the exception.
            http_resp: Optional[httpx.Response], the HTTP response object.
        """
        # Decode a JSON-encoded reason when possible; leave non-JSON or
        # absent reasons untouched.
        if isinstance(reason, (str, bytes)):
            try:
                reason = json.loads(reason)
            except ValueError:
                pass
        if http_resp:
            self.status = http_resp.status_code
            self.reason = http_resp.content
            self.body = http_resp.content
            self.headers = http_resp.headers
        else:
            self.status = status
            self.reason = reason
            self.body = None
            self.headers = None

    def __str__(self):
        """Build a custom error message for the exception."""
        error_message = "({0})\nReason: {1}\n".format(self.status, self.reason)
        if self.headers:
            error_message += "HTTP response headers: {0}\n".format(self.headers)
        if self.body:
            error_message += "HTTP response body: {0}\n".format(self.body)
        return error_message


class Client:
    """The client for the DB-GPT API."""

    def __init__(
        self,
        api_base: Optional[str] = "http://localhost:5000",
        api_key: Optional[str] = None,
        version: Optional[str] = "v2",
        timeout: Optional[httpx._types.TimeoutTypes] = 120,
    ):
        """Create a client for the DB-GPT API.

        Args:
            api_base: Optional[str], a full URL for the DB-GPT API.
                Defaults to "http://localhost:5000".
            api_key: Optional[str], the DB-GPT API key to use for
                authentication. Defaults to None.
            version: Optional[str], the API version. Defaults to "v2".
            timeout: Optional[httpx._types.TimeoutTypes], the timeout to use.
                Defaults to 120. In most cases, pass in a float number to
                specify the timeout in seconds.
        Returns:
            None
        Raises:
            ValueError: If api_base is not a valid URL.

        Examples:
        --------
        .. code-block:: python

            from dbgpt.client.client import Client

            DBGPT_API_BASE = "http://localhost:5000"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        """
        if not is_valid_url(api_base):
            raise ValueError(f"api url {api_base} does not exist or is not accessible.")
        self._api_key = api_key
        self._version = version
        # Strip any trailing slash so the composed URL has no double slashes.
        self._api_url = api_base.rstrip("/") + CLIENT_API_PATH + "/" + version
        self._timeout = timeout
        headers = {"Authorization": f"Bearer {self._api_key}"} if self._api_key else {}
        self._http_client = httpx.AsyncClient(
            headers=headers, timeout=timeout if timeout else httpx.Timeout(None)
        )

    async def chat(
        self,
        model: str,
        messages: Union[str, List[str]],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        chat_mode: Optional[str] = None,
        chat_param: Optional[str] = None,
        conv_uid: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        span_id: Optional[str] = None,
        incremental: bool = True,
        enable_vis: bool = True,
    ) -> ChatCompletionResponse:
        """Chat completion.

        Args:
            model: str, the model name.
            messages: Union[str, List[str]], the user input messages.
            temperature: Optional[float], the sampling temperature to use,
                between 0 and 2. Higher values like 0.8 make the output more
                random, while lower values like 0.2 make it more focused and
                deterministic.
            max_new_tokens: Optional[int], the maximum number of tokens that
                can be generated in the chat completion.
            chat_mode: Optional[str], the chat mode.
            chat_param: Optional[str], the chat param of the chat mode.
            conv_uid: Optional[str], the conversation id of the model inference.
            user_name: Optional[str], the user name of the model inference.
            sys_code: Optional[str], the system code of the model inference.
            span_id: Optional[str], the span id of the model inference.
            incremental: bool, whether the content is returned incrementally
                or in full each time. Defaults to True (incremental).
            enable_vis: bool, whether the response content includes vis tags.
        Returns:
            ChatCompletionResponse: The chat completion response.
        Raises:
            ClientException: If the server returns a non-200 status code.
        Examples:
        --------
        .. code-block:: python

            from dbgpt.client.client import Client

            DBGPT_API_BASE = "http://localhost:5000"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        """
        request = ChatCompletionRequestBody(
            model=model,
            messages=messages,
            stream=False,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=chat_mode,
            chat_param=chat_param,
            conv_uid=conv_uid,
            user_name=user_name,
            sys_code=sys_code,
            span_id=span_id,
            incremental=incremental,
            enable_vis=enable_vis,
        )
        response = await self._http_client.post(
            self._api_url + "/chat/completions", json=request.dict()
        )
        if response.status_code == 200:
            json_data = json.loads(response.text)
            return ChatCompletionResponse(**json_data)
        else:
            # Surface server-side errors through ClientException so the
            # declared return type holds.
            raise ClientException(http_resp=response)

    async def chat_stream(
        self,
        model: str,
        messages: Union[str, List[str]],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        chat_mode: Optional[str] = None,
        chat_param: Optional[str] = None,
        conv_uid: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        span_id: Optional[str] = None,
        incremental: bool = True,
        enable_vis: bool = True,
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """Chat stream completion.

        Args:
            model: str, the model name.
            messages: Union[str, List[str]], the user input messages.
            temperature: Optional[float], the sampling temperature to use,
                between 0 and 2. Higher values like 0.8 make the output more
                random, while lower values like 0.2 make it more focused and
                deterministic.
            max_new_tokens: Optional[int], the maximum number of tokens that
                can be generated in the chat completion.
            chat_mode: Optional[str], the chat mode.
            chat_param: Optional[str], the chat param of the chat mode.
            conv_uid: Optional[str], the conversation id of the model inference.
            user_name: Optional[str], the user name of the model inference.
            sys_code: Optional[str], the system code of the model inference.
            span_id: Optional[str], the span id of the model inference.
            incremental: bool, whether the content is returned incrementally
                or in full each time. Defaults to True (incremental).
            enable_vis: bool, whether the response content includes vis tags.
        Returns:
            AsyncGenerator[ChatCompletionStreamResponse, None]: The chat
                completion stream.
        Examples:
        --------
        .. code-block:: python

            from dbgpt.client.client import Client

            DBGPT_API_BASE = "http://localhost:5000"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            async for chunk in client.chat_stream(
                model="chatgpt_proxyllm", messages="Hello?"
            ):
                print(chunk)
        """
        request = ChatCompletionRequestBody(
            model=model,
            messages=messages,
            stream=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=chat_mode,
            chat_param=chat_param,
            conv_uid=conv_uid,
            user_name=user_name,
            sys_code=sys_code,
            span_id=span_id,
            incremental=incremental,
            enable_vis=enable_vis,
        )
        async with self._http_client.stream(
            method="POST",
            url=self._api_url + "/chat/completions",
            json=request.dict(),
        ) as response:
            if response.status_code == 200:
                # The server replies with Server-Sent Events: each payload is
                # prefixed with "data:" and the stream ends with
                # "data: [DONE]".
                async for line in response.aiter_lines():
                    try:
                        # aiter_lines() strips trailing newlines, so compare
                        # against the bare sentinel.
                        if line.strip() == "data: [DONE]":
                            break
                        if line.startswith("data:"):
                            json_data = json.loads(line[len("data:") :].strip())
                            yield ChatCompletionStreamResponse(**json_data)
                    except Exception as e:
                        yield f"data:[SERVER_ERROR]{str(e)}\n\n"
            else:
                try:
                    error = await response.aread()
                    yield json.loads(error)
                except Exception as e:
                    yield f"data:[SERVER_ERROR]{str(e)}\n\n"

    async def get(self, path: str, *args):
        """Get method.

        Args:
            path: str, the path to get.
            args: Any, the arguments to pass to the get method.
        """
        try:
            response = await self._http_client.get(
                self._api_url + CLIENT_SERVE_PATH + path,
                *args,
            )
            return response
        finally:
            # The underlying AsyncClient is closed after the request, so each
            # Client instance serves a single serve-API call.
            await self._http_client.aclose()

    async def post(self, path: str, args):
        """Post method.

        Args:
            path: str, the path to post.
            args: Any, the arguments to pass to the post method.
        """
        try:
            return await self._http_client.post(
                self._api_url + CLIENT_SERVE_PATH + path,
                json=args,
            )
        finally:
            await self._http_client.aclose()

    async def post_param(self, path: str, args):
        """Post method passing args as query parameters.

        Args:
            path: str, the path to post.
            args: Any, the arguments to pass as query parameters.
        """
        try:
            return await self._http_client.post(
                self._api_url + CLIENT_SERVE_PATH + path,
                params=args,
            )
        finally:
            await self._http_client.aclose()

    async def patch(self, path: str, *args):
        """Patch method.

        Args:
            path: str, the path to patch.
            args: Any, the arguments to pass to the patch method.
        """
        # Await the request so a response (not a coroutine) is returned.
        return await self._http_client.patch(
            self._api_url + CLIENT_SERVE_PATH + path, *args
        )

    async def put(self, path: str, args):
        """Put method.

        Args:
            path: str, the path to put.
            args: Any, the arguments to pass to the put method.
        """
        try:
            return await self._http_client.put(
                self._api_url + CLIENT_SERVE_PATH + path, json=args
            )
        finally:
            await self._http_client.aclose()

    async def delete(self, path: str, *args):
        """Delete method.

        Args:
            path: str, the path to delete.
            args: Any, the arguments to pass to the delete method.
        """
        try:
            return await self._http_client.delete(
                self._api_url + CLIENT_SERVE_PATH + path, *args
            )
        finally:
            await self._http_client.aclose()

    async def head(self, path: str, *args):
        """Head method.

        Args:
            path: str, the path to head.
            args: Any, the arguments to pass to the head method.
        """
        # Await the request so a response (not a coroutine) is returned.
        return await self._http_client.head(self._api_url + path, *args)


def is_valid_url(api_url: Any) -> bool:
    """Check if the given URL is valid.

    Args:
        api_url: Any, the URL to check.
    Returns:
        bool: True if the URL is valid, False otherwise.
    """
    if not isinstance(api_url, str):
        return False
    parsed = urlparse(api_url)
    return parsed.scheme != "" and parsed.netloc != ""
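

# A minimal usage sketch (illustrative, not part of the upstream module). It
# assumes a DB-GPT server on http://localhost:5000 with API key "dbgpt" and a
# served model named "chatgpt_proxyllm".
if __name__ == "__main__":
    import asyncio

    async def _demo():
        client = Client(api_base="http://localhost:5000", api_key="dbgpt")
        # Non-streaming completion.
        res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        print(res)
        # Streaming completion: consume the async generator chunk by chunk.
        async for chunk in client.chat_stream(
            model="chatgpt_proxyllm", messages="Hello?"
        ):
            print(chunk)

    asyncio.run(_demo())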