mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-24 19:08:58 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			409 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			409 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """This module contains the client for the DB-GPT API."""
 | |
| 
 | |
| import atexit
 | |
| import json
 | |
| import os
 | |
| from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 | |
| from urllib.parse import urlparse
 | |
| 
 | |
| import httpx
 | |
| 
 | |
| from dbgpt._private.pydantic import model_to_dict
 | |
| from dbgpt.core.schema.api import ChatCompletionResponse, ChatCompletionStreamResponse
 | |
| 
 | |
| from .schema import ChatCompletionRequestBody
 | |
| 
 | |
# Path segment for the versioned REST API; the default base URL is
# ``http://localhost:5670/<CLIENT_API_PATH>/<version>``.
CLIENT_API_PATH: str = "api"
# Path segment inserted between the API base URL and an endpoint path by the
# generic HTTP verb helpers on ``Client``.
CLIENT_SERVE_PATH: str = "serve"
 | |
| 
 | |
| 
 | |
class ClientException(Exception):
    """Error raised by the DB-GPT client when an API call fails.

    Attributes:
        status: the HTTP status code, if known.
        reason: a human-readable description of the failure.
        http_resp: the underlying HTTP response object, if one was supplied.
        headers: the response headers taken from ``http_resp`` (None without one).
        body: the response text taken from ``http_resp`` (None without one).
    """

    def __init__(self, status=None, reason=None, http_resp=None):
        """Build the exception and capture details from the HTTP response.

        Args:
            status: Optional[int], the HTTP status code.
            reason: Optional[str], why the call failed.
            http_resp: Optional[httpx.Response], the response that triggered the
                error; its ``headers`` and ``text`` are copied onto the exception.
        """
        self.status = status
        self.reason = reason
        self.http_resp = http_resp
        if http_resp:
            self.headers = http_resp.headers
            self.body = http_resp.text
        else:
            self.headers = None
            self.body = None

    def __str__(self):
        """Render status and reason, plus headers/body when they are non-empty."""
        pieces = [f"({self.status})\n", f"Reason: {self.reason}\n"]
        if self.headers:
            pieces.append(f"HTTP response headers: {self.headers}\n")
        if self.body:
            pieces.append(f"HTTP response body: {self.body}\n")
        return "".join(pieces)
 | |
| 
 | |
| 
 | |
# Client API.


class Client:
    """Asynchronous HTTP client for the DB-GPT API.

    All requests go through a single ``httpx.AsyncClient`` that sends a bearer
    token on every request when an API key is configured. The chat helpers call
    ``<api_base>/chat/completions``. The generic verb helpers (``get``, ``post``,
    ``post_param``, ``patch``, ``put``, ``delete``) call
    ``<api_base>/serve<path>``.
    """

    def __init__(
        self,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        version: str = "v2",
        timeout: Optional[httpx._types.TimeoutTypes] = 120,
    ):
        """Create the client.

        Args:
            api_base: Optional[str], a full URL for the DB-GPT API. When omitted,
                the ``DBGPT_API_BASE`` environment variable is used, falling back
                to ``http://localhost:5670/api/<version>``.
            api_key: Optional[str], the DB-GPT API key used for bearer-token
                authentication. When omitted, the ``DBGPT_API_KEY`` environment
                variable is used. If that is unset too, no ``Authorization``
                header is sent.
            version: str, the API version segment. It is only used to build the
                default ``api_base``. Defaults to ``"v2"``.
            timeout: Optional[httpx._types.TimeoutTypes], the timeout to use.
                Defaults to 120. In most cases, pass in a float number to specify
                the timeout in seconds. A falsy value (``None`` or ``0``) disables
                the timeout.

        Raises:
            ValueError: if the resolved ``api_base`` is not an absolute URL
                (it must have both a scheme and a host).

        Examples:
        --------
        .. code-block:: python

            from dbgpt.client import Client

            DBGPT_API_BASE = "http://localhost:5670/api/v2"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        """
        if not api_base:
            api_base = os.getenv(
                "DBGPT_API_BASE", f"http://localhost:5670/{CLIENT_API_PATH}/{version}"
            )
        if not api_key:
            api_key = os.getenv("DBGPT_API_KEY")
        if api_base and is_valid_url(api_base):
            self._api_url = api_base
        else:
            raise ValueError(f"api url {api_base} does not exist or is not accessible.")
        self._api_key = api_key
        self._version = version
        self._timeout = timeout
        headers = {"Authorization": f"Bearer {self._api_key}"} if self._api_key else {}
        self._http_client = httpx.AsyncClient(
            headers=headers, timeout=timeout if timeout else httpx.Timeout(None)
        )
        # Make sure the connection pool is released when the interpreter exits.
        atexit.register(self.close)

    def _base_url(self):
        """Return ``scheme://host[:port]`` of the configured API URL.

        The path, query, and any user-info part of the URL are dropped.
        """
        parsed_url = urlparse(self._api_url)
        host = parsed_url.hostname
        # ``hostname`` strips the brackets from IPv6 literals. Put them back so
        # the rebuilt URL stays valid (e.g. ``http://[::1]:5670``).
        if host and ":" in host:
            host = f"[{host}]"
        scheme = parsed_url.scheme
        port = parsed_url.port
        if port:
            return f"{scheme}://{host}:{port}"
        return f"{scheme}://{host}"

    async def chat(
        self,
        model: str,
        messages: Union[str, List[str]],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        chat_mode: Optional[str] = None,
        chat_param: Optional[str] = None,
        conv_uid: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        span_id: Optional[str] = None,
        incremental: bool = True,
        enable_vis: bool = True,
        **kwargs,
    ) -> ChatCompletionResponse:
        """
        Chat Completion.

        Args:
            model: str, The model name.
            messages: Union[str, List[str]], The user input messages.
            temperature: Optional[float], What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output more
                random, while lower values like 0.2 will make it more focused and
                deterministic.
            max_new_tokens: Optional[int], The maximum number of tokens that can be
                generated in the chat completion.
            chat_mode: Optional[str], The chat mode.
            chat_param: Optional[str], The chat param of chat mode.
            conv_uid: Optional[str], The conversation id of the model inference.
            user_name: Optional[str], The user name of the model inference.
            sys_code: Optional[str], The system code of the model inference.
            span_id: Optional[str], The span id of the model inference.
            incremental: bool, Whether the content is returned incrementally
                (True, the default here) or in full each time (False).
            enable_vis: bool, Whether the response content may include vis
                labels.
            **kwargs: Accepted for forward compatibility; currently ignored.

        Returns:
            ChatCompletionResponse: The chat completion response when the server
            answers with HTTP 200. For any other status, the decoded JSON error
            body is returned as-is (typically a dict).

        Examples:
        --------
        .. code-block:: python

            from dbgpt.client import Client

            DBGPT_API_BASE = "http://localhost:5670/api/v2"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            res = await client.chat(model="chatgpt_proxyllm", messages="Hello?")
        """
        request = ChatCompletionRequestBody(
            model=model,
            messages=messages,
            stream=False,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=chat_mode,
            chat_param=chat_param,
            conv_uid=conv_uid,
            user_name=user_name,
            sys_code=sys_code,
            span_id=span_id,
            incremental=incremental,
            enable_vis=enable_vis,
        )
        response = await self._http_client.post(
            self._api_url + "/chat/completions", json=model_to_dict(request)
        )
        if response.status_code == 200:
            json_data = json.loads(response.text)
            return ChatCompletionResponse(**json_data)
        # Error statuses are not raised; the server's JSON payload is handed back.
        return json.loads(response.content)

    async def chat_stream(
        self,
        model: str,
        messages: Union[str, List[str]],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        chat_mode: Optional[str] = None,
        chat_param: Optional[str] = None,
        conv_uid: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        span_id: Optional[str] = None,
        incremental: bool = True,
        enable_vis: bool = True,
        **kwargs,
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """
        Chat Stream Completion.

        Args:
            model: str, The model name.
            messages: Union[str, List[str]], The user input messages.
            temperature: Optional[float], What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output more
                random, while lower values like 0.2 will make it more focused and
                deterministic.
            max_new_tokens: Optional[int], The maximum number of tokens that can be
                generated in the chat completion.
            chat_mode: Optional[str], The chat mode.
            chat_param: Optional[str], The chat param of chat mode.
            conv_uid: Optional[str], The conversation id of the model inference.
            user_name: Optional[str], The user name of the model inference.
            sys_code: Optional[str], The system code of the model inference.
            span_id: Optional[str], The span id of the model inference.
            incremental: bool, Whether the content is returned incrementally
                (True, the default here) or in full each time (False).
            enable_vis: bool, Whether the response content may include vis
                labels.
            **kwargs: Accepted for forward compatibility; currently ignored.

        Yields:
            ChatCompletionStreamResponse: One item per server-sent event. If the
            server answers with a non-200 status, the decoded JSON error body is
            yielded once instead.

        Examples:
        --------
        .. code-block:: python

            from dbgpt.client import Client

            DBGPT_API_BASE = "http://localhost:5670/api/v2"
            DBGPT_API_KEY = "dbgpt"
            client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
            async for chunk in client.chat_stream(
                model="chatgpt_proxyllm", messages="Hello?"
            ):
                print(chunk)
        """
        request = ChatCompletionRequestBody(
            model=model,
            messages=messages,
            stream=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            chat_mode=chat_mode,
            chat_param=chat_param,
            conv_uid=conv_uid,
            user_name=user_name,
            sys_code=sys_code,
            span_id=span_id,
            incremental=incremental,
            enable_vis=enable_vis,
        )
        async for chat_completion_response in self._chat_stream(model_to_dict(request)):
            yield chat_completion_response

    async def _chat_stream(
        self, data: Dict[str, Any]
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """Chat Stream Completion.

        POSTs ``data`` to ``<api_base>/chat/completions`` and parses the response
        as server-sent events. Lines starting with ``data:`` are decoded as JSON.
        Other lines are skipped, and a ``data: [DONE]`` line ends the stream.

        Args:
            data: dict, The data to send to the API.

        Yields:
            ChatCompletionStreamResponse for each event on HTTP 200. On any other
            status, the decoded JSON error body is yielded once.

        Raises:
            Exception: if an event payload cannot be decoded or validated. The
                original error is chained as ``__cause__``.
        """
        async with self._http_client.stream(
            method="POST",
            url=self._api_url + "/chat/completions",
            json=data,
            headers={},
        ) as response:
            if response.status_code != 200:
                error = await response.aread()
                yield json.loads(error)
                return

            sse_data = ""
            async for line in response.aiter_lines():
                if line.strip() == "data: [DONE]":
                    break
                if not line.startswith("data:"):
                    continue
                # Accept both "data: {...}" and "data:{...}".
                sse_data = line[len("data:") :]
                if sse_data.startswith(" "):
                    sse_data = sse_data[1:]
                try:
                    chat_completion_response = ChatCompletionStreamResponse(
                        **json.loads(sse_data)
                    )
                except Exception as e:
                    raise Exception(
                        f"Failed to parse SSE data: {e}, sse_data: {sse_data}"
                    ) from e
                # Yield outside the try so errors thrown into the generator by
                # the consumer are not re-wrapped as parse failures.
                yield chat_completion_response

    async def get(self, path: str, *args, **kwargs):
        """Send a GET request to ``<api_base>/serve<path>``.

        Args:
            path: str, The path to get (expected to start with ``/``).
            args: Any, Extra positional arguments forwarded to the HTTP client.
            kwargs: Any, Query parameters. Entries whose value is None are
                dropped.

        Returns:
            The HTTP response object.
        """
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        response = await self._http_client.get(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            *args,
            params=kwargs,
        )
        return response

    async def post(self, path: str, args):
        """Send a POST request with a JSON body to ``<api_base>/serve<path>``.

        Args:
            path: str, The path to post.
            args: Any, The object serialized as the JSON request body.

        Returns:
            The HTTP response object.
        """
        return await self._http_client.post(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            json=args,
        )

    async def post_param(self, path: str, args):
        """Send a POST request with query parameters to ``<api_base>/serve<path>``.

        Args:
            path: str, The path to post.
            args: Any, The mapping sent as URL query parameters.

        Returns:
            The HTTP response object.
        """
        return await self._http_client.post(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}",
            params=args,
        )

    async def patch(self, path: str, *args):
        """Send a PATCH request to ``<api_base>/serve<path>``.

        Args:
            path: str, The path to patch.
            args: Any, Extra positional arguments forwarded to the HTTP client.
                NOTE(review): httpx's request options after ``url`` appear to be
                keyword-only, so positional extras are likely rejected; confirm.

        Returns:
            The HTTP response object.
        """
        # The request coroutine must be awaited; returning it un-awaited handed
        # callers a coroutine object instead of a response.
        return await self._http_client.patch(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
        )

    async def put(self, path: str, args):
        """Send a PUT request with a JSON body to ``<api_base>/serve<path>``.

        Args:
            path: str, The path to put.
            args: Any, The object serialized as the JSON request body.

        Returns:
            The HTTP response object.
        """
        return await self._http_client.put(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", json=args
        )

    async def delete(self, path: str, *args):
        """Send a DELETE request to ``<api_base>/serve<path>``.

        Args:
            path: str, The path to delete.
            args: Any, Extra positional arguments forwarded to the HTTP client.

        Returns:
            The HTTP response object.
        """
        return await self._http_client.delete(
            f"{self._api_url}/{CLIENT_SERVE_PATH}{path}", *args
        )

    async def head(self, path: str, *args):
        """Send a HEAD request to ``<api_base><path>``.

        NOTE(review): unlike the other verb helpers, this does not insert
        ``CLIENT_SERVE_PATH`` into the URL; confirm that this is intended.

        Args:
            path: str, The path to head.
            args: Any, Extra positional arguments forwarded to the HTTP client.

        Returns:
            The HTTP response object.
        """
        # Await the request instead of returning the bare coroutine.
        return await self._http_client.head(self._api_url + path, *args)

    def close(self):
        """Close the client synchronously (also registered with ``atexit``)."""
        from dbgpt.util import get_or_create_event_loop

        if not self._http_client.is_closed:
            loop = get_or_create_event_loop()
            loop.run_until_complete(self._http_client.aclose())

    async def aclose(self):
        """Close the client from async code."""
        await self._http_client.aclose()
 | |
| 
 | |
| 
 | |
def is_valid_url(api_url: Any) -> bool:
    """Report whether ``api_url`` is a string holding an absolute URL.

    An absolute URL here means one whose parsed form has both a scheme
    (e.g. ``http``) and a network location (host, optional port).

    Args:
        api_url: Any, The value to check; non-strings are never valid.

    Returns:
        bool: True when both scheme and netloc are present, False otherwise.
    """
    if isinstance(api_url, str):
        components = urlparse(api_url)
        return bool(components.scheme) and bool(components.netloc)
    return False
 |