# DB-GPT/dbgpt/client/client.py
"""This module contains the client for the DB-GPT API."""
import atexit
import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from urllib.parse import urlparse
import httpx
from dbgpt._private.pydantic import model_to_dict
from dbgpt.core.schema.api import ChatCompletionResponse, ChatCompletionStreamResponse
from .schema import ChatCompletionRequestBody
CLIENT_API_PATH = "api"
CLIENT_SERVE_PATH = "serve"


class ClientException(Exception):
    """ClientException is raised when an error occurs in the client."""

    def __init__(self, status=None, reason=None, http_resp=None):
        """Initialize the ClientException.

        Args:
            status: Optional[int], the HTTP status code.
            reason: Optional[str], the reason for the exception.
            http_resp: Optional[httpx.Response], the HTTP response object.
        """
        self.status = status
        self.reason = reason
        self.http_resp = http_resp
        self.headers = http_resp.headers if http_resp else None
        self.body = http_resp.text if http_resp else None

    def __str__(self):
        """Return the error message."""
        error_message = "({0})\nReason: {1}\n".format(self.status, self.reason)
        if self.headers:
            error_message += "HTTP response headers: {0}\n".format(self.headers)
        if self.body:
            error_message += "HTTP response body: {0}\n".format(self.body)
        return error_message
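
# Illustrative sketch (not from the original file): how `ClientException`
# renders when printed. The status and reason values are hypothetical.
#
#     try:
#         raise ClientException(status=401, reason="Unauthorized")
#     except ClientException as e:
#         print(e)
#         # (401)
#         # Reason: Unauthorized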
"""Client API."""
class Client:
"""The client for the DB-GPT API."""

    def __init__(
        self,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        version: str = "v2",
        timeout: Optional[httpx._types.TimeoutTypes] = 120,
    ):
        """Create the client.

        Args:
            api_base: Optional[str], a full URL for the DB-GPT API.
                Defaults to "http://localhost:5670/api/v2".
            api_key: Optional[str], the DB-GPT API key to use for
                authentication. Defaults to None.
            version: str, the API version. Defaults to "v2".
            timeout: Optional[httpx._types.TimeoutTypes], the timeout to use.
                Defaults to 120 seconds. In most cases, pass a float to
                specify the timeout in seconds.
        Returns:
            None
        Raises:
            ValueError: If the API URL is not a valid URL.
        Examples:
            --------
            .. code-block:: python

                from dbgpt.client import Client

                DBGPT_API_BASE = "http://localhost:5670/api/v2"
                DBGPT_API_KEY = "dbgpt"
                client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
                res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        """
        if not api_base:
            api_base = os.getenv(
                "DBGPT_API_BASE", f"http://localhost:5670/{CLIENT_API_PATH}/{version}"
            )
        if not api_key:
            api_key = os.getenv("DBGPT_API_KEY")
        if api_base and is_valid_url(api_base):
            self._api_url = api_base
        else:
            raise ValueError(f"api url {api_base} is not a valid URL.")
        self._api_key = api_key
        self._version = version
        self._timeout = timeout
        headers = {"Authorization": f"Bearer {self._api_key}"} if self._api_key else {}
        self._http_client = httpx.AsyncClient(
            headers=headers, timeout=timeout if timeout else httpx.Timeout(None)
        )
        atexit.register(self.close)

    def _base_url(self):
        """Return the scheme://host[:port] portion of the API URL."""
        parsed_url = urlparse(self._api_url)
        host = parsed_url.hostname
        scheme = parsed_url.scheme
        port = parsed_url.port
        if port:
            return f"{scheme}://{host}:{port}"
        return f"{scheme}://{host}"

    async def chat(
        self,
        model: str,
        messages: Union[str, List[str]],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        chat_mode: Optional[str] = None,
        chat_param: Optional[str] = None,
        conv_uid: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        span_id: Optional[str] = None,
        incremental: bool = True,
        enable_vis: bool = True,
        **kwargs,
    ) -> ChatCompletionResponse:
        """Chat Completion.

        Args:
            model: str, The model name.
            messages: Union[str, List[str]], The user input messages.
            temperature: Optional[float], What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.
            max_new_tokens: Optional[int], The maximum number of tokens that
                can be generated in the chat completion.
            chat_mode: Optional[str], The chat mode.
            chat_param: Optional[str], The chat param of the chat mode.
            conv_uid: Optional[str], The conversation id of the model inference.
            user_name: Optional[str], The user name of the model inference.
            sys_code: Optional[str], The system code of the model inference.
            span_id: Optional[str], The span id of the model inference.
            incremental: bool, Controls whether the content is returned
                incrementally or in full each time. Defaults to True
                (incremental).
            enable_vis: bool, Whether to include the vis label in the
                response content. Defaults to True.
        Returns:
            ChatCompletionResponse: The chat completion response.
        Examples:
            --------
            .. code-block:: python

                from dbgpt.client import Client

                DBGPT_API_BASE = "http://localhost:5670/api/v2"
                DBGPT_API_KEY = "dbgpt"
                client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
                res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        """
        request = ChatCompletionRequestBody(
            model=model,
            messages=messages,
            stream=False,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=chat_mode,
            chat_param=chat_param,
            conv_uid=conv_uid,
            user_name=user_name,
            sys_code=sys_code,
            span_id=span_id,
            incremental=incremental,
            enable_vis=enable_vis,
        )
        response = await self._http_client.post(
            self._api_url + "/chat/completions", json=model_to_dict(request)
        )
        if response.status_code == 200:
            json_data = json.loads(response.text)
            return ChatCompletionResponse(**json_data)
        # On a non-200 status, the error payload is returned as a plain dict.
        return json.loads(response.content)
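
    # Illustrative sketch: reading the reply text from a successful call.
    # Assumes `ChatCompletionResponse` exposes an OpenAI-style
    # `choices[0].message.content` shape; verify against the actual schema.
    #
    #     res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
    #     print(res.choices[0].message.content)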

    async def chat_stream(
        self,
        model: str,
        messages: Union[str, List[str]],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        chat_mode: Optional[str] = None,
        chat_param: Optional[str] = None,
        conv_uid: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        span_id: Optional[str] = None,
        incremental: bool = True,
        enable_vis: bool = True,
        **kwargs,
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """Chat Stream Completion.

        Args:
            model: str, The model name.
            messages: Union[str, List[str]], The user input messages.
            temperature: Optional[float], What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.
            max_new_tokens: Optional[int], The maximum number of tokens that
                can be generated in the chat completion.
            chat_mode: Optional[str], The chat mode.
            chat_param: Optional[str], The chat param of the chat mode.
            conv_uid: Optional[str], The conversation id of the model inference.
            user_name: Optional[str], The user name of the model inference.
            sys_code: Optional[str], The system code of the model inference.
            span_id: Optional[str], The span id of the model inference.
            incremental: bool, Controls whether the content is returned
                incrementally or in full each time. Defaults to True
                (incremental).
            enable_vis: bool, Whether to include the vis label in the
                response content. Defaults to True.
        Returns:
            AsyncGenerator[ChatCompletionStreamResponse, None]: The chat
                completion stream response.
        Examples:
            --------
            .. code-block:: python

                from dbgpt.client import Client

                DBGPT_API_BASE = "http://localhost:5670/api/v2"
                DBGPT_API_KEY = "dbgpt"
                client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
                async for res in client.chat_stream(
                    model="chatgpt_proxyllm", messages="Hello?"
                ):
                    print(res)
        """
        request = ChatCompletionRequestBody(
            model=model,
            messages=messages,
            stream=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=chat_mode,
            chat_param=chat_param,
            conv_uid=conv_uid,
            user_name=user_name,
            sys_code=sys_code,
            span_id=span_id,
            incremental=incremental,
            enable_vis=enable_vis,
        )
        async for chat_completion_response in self._chat_stream(
            model_to_dict(request)
        ):
            yield chat_completion_response
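
    # Illustrative sketch: `chat_stream` is an async generator, so it is
    # consumed with `async for` rather than awaited. The OpenAI-style
    # `choices[0].delta.content` chunk shape is an assumption; verify
    # against the actual schema.
    #
    #     async for chunk in client.chat_stream(
    #         model="chatgpt_proxyllm", messages="Hello?"
    #     ):
    #         print(chunk.choices[0].delta.content, end="")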

    async def _chat_stream(
        self, data: Dict[str, Any]
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """Chat Stream Completion.

        Args:
            data: dict, The data to send to the API.
        Returns:
            AsyncGenerator[ChatCompletionStreamResponse, None]: The chat
                completion stream response.
        """
        async with self._http_client.stream(
            method="POST",
            url=self._api_url + "/chat/completions",
            json=data,
            headers={},
        ) as response:
            if response.status_code == 200:
                sse_data = ""
                async for line in response.aiter_lines():
                    try:
                        if line.strip() == "data: [DONE]":
                            break
                        if line.startswith("data:"):
                            # Strip the "data:" prefix, with or without a
                            # trailing space, to recover the JSON payload.
                            if line.startswith("data: "):
                                sse_data = line[len("data: ") :]
                            else:
                                sse_data = line[len("data:") :]
                            json_data = json.loads(sse_data)
                            yield ChatCompletionStreamResponse(**json_data)
                    except Exception as e:
                        raise Exception(
                            f"Failed to parse SSE data: {e}, sse_data: {sse_data}"
                        )
            else:
                # On a non-200 status, the error payload is yielded as a
                # plain dict.
                error = await response.aread()
                yield json.loads(error)
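
    # The parser above expects Server-Sent Events shaped like the following
    # (payload abbreviated for illustration):
    #
    #     data: {"id": "...", "choices": [...]}
    #     data: [DONE]
    #
    # Each "data:" line carries one JSON-encoded stream chunk, and the
    # literal "data: [DONE]" line terminates the stream.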

    async def get(self, path: str, *args, **kwargs):
        """Get method.

        Args:
            path: str, The path to get.
            args: Any, The arguments to pass to the get method.
            kwargs: Any, The query parameters; None values are dropped.
        """
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        response = await self._http_client.get(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            *args,
            params=kwargs,
        )
        return response

    async def post(self, path: str, args):
        """Post method.

        Args:
            path: str, The path to post.
            args: Any, The JSON body to pass to the post method.
        """
        return await self._http_client.post(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            json=args,
        )

    async def post_param(self, path: str, args):
        """Post method (query parameters).

        Args:
            path: str, The path to post.
            args: Any, The query parameters to pass to the post method.
        """
        return await self._http_client.post(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            params=args,
        )
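
    # Illustrative sketch: the helpers above all target
    # f"{api_url}/{CLIENT_SERVE_PATH}{path}". The path and query parameter
    # below are hypothetical, not documented endpoints.
    #
    #     res = await client.get("/conversations", user_name="alice")
    #     data = res.json()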

    async def patch(self, path: str, *args):
        """Patch method.

        Args:
            path: str, The path to patch.
            args: Any, The arguments to pass to the patch method.
        """
        return await self._http_client.patch(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
        )

    async def put(self, path: str, args):
        """Put method.

        Args:
            path: str, The path to put.
            args: Any, The JSON body to pass to the put method.
        """
        return await self._http_client.put(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
        )

    async def delete(self, path: str, *args):
        """Delete method.

        Args:
            path: str, The path to delete.
            args: Any, The arguments to pass to delete.
        """
        return await self._http_client.delete(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
        )

    async def head(self, path: str, *args):
        """Head method.

        Args:
            path: str, The path to head.
            args: Any, The arguments to pass to the head method.
        """
        # Note: unlike the other helpers, `head` targets the API URL
        # directly rather than the serve path.
        return await self._http_client.head(self._api_url + path, *args)

    def close(self):
        """Close the client."""
        from dbgpt.util import get_or_create_event_loop

        if not self._http_client.is_closed:
            loop = get_or_create_event_loop()
            loop.run_until_complete(self._http_client.aclose())

    async def aclose(self):
        """Close the client asynchronously."""
        await self._http_client.aclose()
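
# Illustrative sketch (not from the original file): an explicit async
# lifecycle. Inside a running event loop, prefer `aclose()`; the
# atexit-registered `close()` covers the synchronous fallback.
#
#     client = Client(api_key="dbgpt")
#     try:
#         res = await client.chat(model="chatgpt_proxyllm", messages="Hi")
#     finally:
#         await client.aclose()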


def is_valid_url(api_url: Any) -> bool:
    """Check if the given URL is valid.

    Args:
        api_url: Any, The URL to check.
    Returns:
        bool: True if the URL is valid, False otherwise.
    """
    if not isinstance(api_url, str):
        return False
    parsed = urlparse(api_url)
    return parsed.scheme != "" and parsed.netloc != ""
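
# Example (illustrative): `is_valid_url("http://localhost:5670/api/v2")`
# returns True, while `is_valid_url("example.com")` (no scheme) and
# `is_valid_url(None)` (not a str) both return False.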