"""This module contains the client for the DB-GPT API.""" import atexit import json import os from typing import Any, AsyncGenerator, Dict, List, Optional, Union from urllib.parse import urlparse import httpx from dbgpt._private.pydantic import model_to_dict from dbgpt.core.schema.api import ChatCompletionResponse, ChatCompletionStreamResponse from .schema import ChatCompletionRequestBody CLIENT_API_PATH = "api" CLIENT_SERVE_PATH = "serve" class ClientException(Exception): """ClientException is raised when an error occurs in the client.""" def __init__(self, status=None, reason=None, http_resp=None): """Initialize the ClientException. Args: status: Optional[int], the HTTP status code. reason: Optional[str], the reason for the exception. http_resp: Optional[httpx.Response], the HTTP response object. """ self.status = status self.reason = reason self.http_resp = http_resp self.headers = http_resp.headers if http_resp else None self.body = http_resp.text if http_resp else None def __str__(self): """Return the error message.""" error_message = "({0})\n" "Reason: {1}\n".format(self.status, self.reason) if self.headers: error_message += "HTTP response headers: {0}\n".format(self.headers) if self.body: error_message += "HTTP response body: {0}\n".format(self.body) return error_message """Client API.""" class Client: """The client for the DB-GPT API.""" def __init__( self, api_base: Optional[str] = None, api_key: Optional[str] = None, version: str = "v2", timeout: Optional[httpx._types.TimeoutTypes] = 120, ): """Create the client. Args: api_base: Optional[str], a full URL for the DB-GPT API. Defaults to the `http://localhost:5670/api/v2`. api_key: Optional[str], The dbgpt api key to use for authentication. Defaults to None. timeout: Optional[httpx._types.TimeoutTypes]: The timeout to use. Defaults to None. In most cases, pass in a float number to specify the timeout in seconds. Returns: None Raise: ClientException Examples: -------- .. code-block:: python from dbgpt.client import Client DBGPT_API_BASE = "http://localhost:5670/api/v2" DBGPT_API_KEY = "dbgpt" client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY) client.chat(model="chatgpt_proxyllm", messages="Hello?") """ if not api_base: api_base = os.getenv( "DBGPT_API_BASE", f"http://localhost:5670/{CLIENT_API_PATH}/{version}" ) if not api_key: api_key = os.getenv("DBGPT_API_KEY") if api_base and is_valid_url(api_base): self._api_url = api_base else: raise ValueError(f"api url {api_base} does not exist or is not accessible.") self._api_key = api_key self._version = version self._timeout = timeout headers = {"Authorization": f"Bearer {self._api_key}"} if self._api_key else {} self._http_client = httpx.AsyncClient( headers=headers, timeout=timeout if timeout else httpx.Timeout(None) ) atexit.register(self.close) def _base_url(self): parsed_url = urlparse(self._api_url) host = parsed_url.hostname scheme = parsed_url.scheme port = parsed_url.port if port: return f"{scheme}://{host}:{port}" return f"{scheme}://{host}" async def chat( self, model: str, messages: Union[str, List[str]], temperature: Optional[float] = None, max_new_tokens: Optional[int] = None, chat_mode: Optional[str] = None, chat_param: Optional[str] = None, conv_uid: Optional[str] = None, user_name: Optional[str] = None, sys_code: Optional[str] = None, span_id: Optional[str] = None, incremental: bool = True, enable_vis: bool = True, **kwargs, ) -> ChatCompletionResponse: """ Chat Completion. Args: model: str, The model name. 


class Client:
    """The client for the DB-GPT API."""

    def __init__(
        self,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        version: str = "v2",
        timeout: Optional[httpx._types.TimeoutTypes] = 120,
    ):
        """Create the client.

        Args:
            api_base: Optional[str], a full URL for the DB-GPT API. Defaults to
                "http://localhost:5670/api/v2".
            api_key: Optional[str], The DB-GPT API key to use for
                authentication. Defaults to None.
            version: str, The API version. Defaults to "v2".
            timeout: Optional[httpx._types.TimeoutTypes], The timeout to use.
                Defaults to 120. In most cases, pass in a float number to
                specify the timeout in seconds.
        Returns:
            None
        Raises:
            ClientException

        Examples:
        --------
        .. code-block:: python

            from dbgpt.client import Client

            DBGPT_API_BASE = "http://localhost:5670/api/v2"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        """
        if not api_base:
            api_base = os.getenv(
                "DBGPT_API_BASE", f"http://localhost:5670/{CLIENT_API_PATH}/{version}"
            )
        if not api_key:
            api_key = os.getenv("DBGPT_API_KEY")
        if api_base and is_valid_url(api_base):
            self._api_url = api_base
        else:
            raise ValueError(f"api url {api_base} is not a valid URL.")
        self._api_key = api_key
        self._version = version
        self._timeout = timeout
        headers = {"Authorization": f"Bearer {self._api_key}"} if self._api_key else {}
        self._http_client = httpx.AsyncClient(
            headers=headers, timeout=timeout if timeout else httpx.Timeout(None)
        )
        # Make a best effort to release the connection pool at interpreter exit.
        atexit.register(self.close)

    def _base_url(self):
        """Return the scheme://host[:port] part of the configured API URL."""
        parsed_url = urlparse(self._api_url)
        host = parsed_url.hostname
        scheme = parsed_url.scheme
        port = parsed_url.port
        if port:
            return f"{scheme}://{host}:{port}"
        return f"{scheme}://{host}"

    async def chat(
        self,
        model: str,
        messages: Union[str, List[str]],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        chat_mode: Optional[str] = None,
        chat_param: Optional[str] = None,
        conv_uid: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        span_id: Optional[str] = None,
        incremental: bool = True,
        enable_vis: bool = True,
        **kwargs,
    ) -> ChatCompletionResponse:
        """Chat Completion.

        Args:
            model: str, The model name.
            messages: Union[str, List[str]], The user input messages.
            temperature: Optional[float], What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.
            max_new_tokens: Optional[int], The maximum number of tokens that
                can be generated in the chat completion.
            chat_mode: Optional[str], The chat mode.
            chat_param: Optional[str], The chat param of chat mode.
            conv_uid: Optional[str], The conversation id of the model inference.
            user_name: Optional[str], The user name of the model inference.
            sys_code: Optional[str], The system code of the model inference.
            span_id: Optional[str], The span id of the model inference.
            incremental: bool, Whether the content is returned incrementally
                or in full each time. Defaults to True (incremental).
            enable_vis: bool, Whether to include vis tags in the response
                content. Defaults to True.
        Returns:
            ChatCompletionResponse: The chat completion response.
        Examples:
        --------
        .. code-block:: python

            from dbgpt.client import Client

            DBGPT_API_BASE = "http://localhost:5670/api/v2"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        """
        request = ChatCompletionRequestBody(
            model=model,
            messages=messages,
            stream=False,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=chat_mode,
            chat_param=chat_param,
            conv_uid=conv_uid,
            user_name=user_name,
            sys_code=sys_code,
            span_id=span_id,
            incremental=incremental,
            enable_vis=enable_vis,
        )
        response = await self._http_client.post(
            self._api_url + "/chat/completions", json=model_to_dict(request)
        )
        if response.status_code == 200:
            json_data = json.loads(response.text)
            chat_completion_response = ChatCompletionResponse(**json_data)
            return chat_completion_response
        else:
            # Surface HTTP errors through the module's own exception instead of
            # returning a raw dict that contradicts the annotated return type.
            raise ClientException(
                status=response.status_code,
                reason=response.reason_phrase,
                http_resp=response,
            )

    async def chat_stream(
        self,
        model: str,
        messages: Union[str, List[str]],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        chat_mode: Optional[str] = None,
        chat_param: Optional[str] = None,
        conv_uid: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        span_id: Optional[str] = None,
        incremental: bool = True,
        enable_vis: bool = True,
        **kwargs,
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """Chat Stream Completion.

        Args:
            model: str, The model name.
            messages: Union[str, List[str]], The user input messages.
            temperature: Optional[float], What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.
            max_new_tokens: Optional[int], The maximum number of tokens that
                can be generated in the chat completion.
            chat_mode: Optional[str], The chat mode.
            chat_param: Optional[str], The chat param of chat mode.
            conv_uid: Optional[str], The conversation id of the model inference.
            user_name: Optional[str], The user name of the model inference.
            sys_code: Optional[str], The system code of the model inference.
            span_id: Optional[str], The span id of the model inference.
            incremental: bool, Whether the content is returned incrementally
                or in full each time. Defaults to True (incremental).
            enable_vis: bool, Whether to include vis tags in the response
                content. Defaults to True.
        Returns:
            AsyncGenerator[ChatCompletionStreamResponse, None]: The chat
                completion stream responses.
        Examples:
        --------
        .. code-block:: python

            from dbgpt.client import Client

            DBGPT_API_BASE = "http://localhost:5670/api/v2"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            async for chunk in client.chat_stream(
                model="chatgpt_proxyllm", messages="Hello?"
            ):
                print(chunk)
        """
        request = ChatCompletionRequestBody(
            model=model,
            messages=messages,
            stream=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=chat_mode,
            chat_param=chat_param,
            conv_uid=conv_uid,
            user_name=user_name,
            sys_code=sys_code,
            span_id=span_id,
            incremental=incremental,
            enable_vis=enable_vis,
        )
        async for chat_completion_response in self._chat_stream(
            model_to_dict(request)
        ):
            yield chat_completion_response

    async def _chat_stream(
        self, data: Dict[str, Any]
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """Chat Stream Completion.

        Args:
            data: dict, The data to send to the API.
        Returns:
            AsyncGenerator[ChatCompletionStreamResponse, None]: The chat
                completion stream responses.
        """
        async with self._http_client.stream(
            method="POST",
            url=self._api_url + "/chat/completions",
            json=data,
            headers={},
        ) as response:
            if response.status_code == 200:
                # The server replies with server-sent events; the expected wire
                # format is illustrated in the parsing sketch after this class.
                sse_data = ""
                async for line in response.aiter_lines():
                    try:
                        if line.strip() == "data: [DONE]":
                            break
                        if line.startswith("data:"):
                            if line.startswith("data: "):
                                sse_data = line[len("data: ") :]
                            else:
                                sse_data = line[len("data:") :]
                            json_data = json.loads(sse_data)
                            chat_completion_response = ChatCompletionStreamResponse(
                                **json_data
                            )
                            yield chat_completion_response
                    except Exception as e:
                        raise Exception(
                            f"Failed to parse SSE data: {e}, sse_data: {sse_data}"
                        ) from e
            else:
                # Surface HTTP errors through the module's own exception instead
                # of yielding a raw dict from a typed generator.
                error = await response.aread()
                raise ClientException(status=response.status_code, reason=error)

    async def get(self, path: str, *args, **kwargs):
        """Get method.

        Args:
            path: str, The path to get.
            args: Any, The arguments to pass to the get method.
        """
        # Drop unset query parameters so they are not sent as "None".
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        response = await self._http_client.get(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            *args,
            params=kwargs,
        )
        return response

    async def post(self, path: str, args):
        """Post method.

        Args:
            path: str, The path to post.
            args: Any, The arguments to pass to the post method.
        """
        return await self._http_client.post(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            json=args,
        )

    async def post_param(self, path: str, args):
        """Post method.

        Args:
            path: str, The path to post.
            args: Any, The arguments to pass to the post method.
        """
        return await self._http_client.post(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            params=args,
        )

    async def patch(self, path: str, *args):
        """Patch method.

        Args:
            path: str, The path to patch.
            args: Any, The arguments to pass to the patch method.
        """
        return await self._http_client.patch(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
        )

    async def put(self, path: str, args):
        """Put method.

        Args:
            path: str, The path to put.
            args: Any, The arguments to pass to the put method.
        """
        return await self._http_client.put(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
        )

    async def delete(self, path: str, *args):
        """Delete method.

        Args:
            path: str, The path to delete.
            args: Any, The arguments to pass to delete.
        """
        return await self._http_client.delete(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
        )

    async def head(self, path: str, *args):
        """Head method.

        Args:
            path: str, The path to head.
            args: Any, The arguments to pass to the head method.
        """
        return await self._http_client.head(self._api_url + path, *args)

    def close(self):
        """Close the client."""
        from dbgpt.util import get_or_create_event_loop

        if not self._http_client.is_closed:
            loop = get_or_create_event_loop()
            loop.run_until_complete(self._http_client.aclose())

    async def aclose(self):
        """Close the client."""
        await self._http_client.aclose()
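

# Illustrative only: a standalone sketch of the server-sent-events framing
# that Client._chat_stream consumes. Each event is a line of the form
# "data: <json>", and the literal "data: [DONE]" line terminates the stream.
# `parse_sse_line` is a hypothetical helper for demonstration, not part of
# the DB-GPT client API.
def parse_sse_line(line: str) -> Optional[Dict[str, Any]]:
    """Decode one SSE line to a dict, or return None for the [DONE] marker."""
    if line.strip() == "data: [DONE]":
        return None
    if line.startswith("data:"):
        return json.loads(line[len("data:") :].lstrip())
    raise ValueError(f"not an SSE data line: {line!r}")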


def is_valid_url(api_url: Any) -> bool:
    """Check if the given URL is valid.

    Args:
        api_url: Any, The URL to check.
    Returns:
        bool: True if the URL is valid, False otherwise.
    """
    if not isinstance(api_url, str):
        return False
    parsed = urlparse(api_url)
    return parsed.scheme != "" and parsed.netloc != ""
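

# A minimal usage sketch. It assumes a DB-GPT server listening on
# localhost:5670 and a deployed model named "chatgpt_proxyllm"; both are
# assumptions about your deployment, not guarantees made by this module.
async def _example() -> None:
    client = Client(api_base="http://localhost:5670/api/v2", api_key="dbgpt")
    try:
        # Full (non-stream) completion.
        res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        print(res)
        # Streamed completion: chunks arrive as ChatCompletionStreamResponse.
        async for chunk in client.chat_stream(
            model="chatgpt_proxyllm", messages="Hello?"
        ):
            print(chunk)
    finally:
        await client.aclose()


if __name__ == "__main__":
    import asyncio

    asyncio.run(_example())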