fix:client mypy error

This commit is contained in:
aries_ckt 2024-03-18 22:12:25 +08:00
parent 4413ff682f
commit 7bc5c59a89
12 changed files with 124 additions and 66 deletions

View File

@ -5,6 +5,9 @@ exclude = /tests/
[mypy-dbgpt.app.*]
follow_imports = skip
[mypy-dbgpt.agent.*]
follow_imports = skip
[mypy-dbgpt.serve.*]
follow_imports = skip
@ -80,4 +83,7 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-clickhouse_connect.*]
ignore_missing_imports = True
[mypy-fastchat.protocol.api_protocol]
ignore_missing_imports = True

View File

@ -48,7 +48,7 @@ fmt: setup ## Format Python code
$(VENV_BIN)/blackdoc examples
# TODO: Use flake8 to enforce Python style guide.
# https://flake8.pycqa.org/en/latest/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/
# TODO: More package checks with flake8.
.PHONY: fmt-check
@ -57,7 +57,7 @@ fmt-check: setup ## Check Python code formatting and style without making change
$(VENV_BIN)/isort --check-only --extend-skip="examples/notebook" examples
$(VENV_BIN)/black --check --extend-exclude="examples/notebook" .
$(VENV_BIN)/blackdoc --check dbgpt examples
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/ dbgpt/client/
.PHONY: pre-commit
pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing
@ -73,7 +73,7 @@ test-doc: $(VENV)/.testenv ## Run doctests
.PHONY: mypy
mypy: $(VENV)/.testenv ## Run mypy checks
# https://github.com/python/mypy
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/ dbgpt/client/
# rag depends on core and storage, so we not need to check it again.
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
@ -107,4 +107,4 @@ upload: package ## Upload the package to PyPI
.PHONY: help
help: ## Display this help screen
@echo "Available commands:"
@grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort
@grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort

View File

@ -0,0 +1 @@
"""This module is the client of the dbgpt package."""

View File

@ -1,8 +1,10 @@
"""App Client API."""
from dbgpt.client.client import Client
async def get_app(client: Client, app_id: str):
"""Get an app.
Args:
client (Client): The dbgpt client.
app_id (str): The app id.
@ -12,6 +14,7 @@ async def get_app(client: Client, app_id: str):
async def list_app(client: Client):
"""List apps.
Args:
client (Client): The dbgpt client.
"""

View File

@ -1,3 +1,4 @@
"""This module contains the client for the DB-GPT API."""
import json
from typing import Any, AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
@ -16,7 +17,8 @@ class ClientException(Exception):
"""ClientException is raised when an error occurs in the client."""
def __init__(self, status=None, reason=None, http_resp=None):
"""
"""Initialize the ClientException.
Args:
status: Optional[int], the HTTP status code.
reason: Optional[str], the reason for the exception.
@ -35,7 +37,7 @@ class ClientException(Exception):
self.headers = None
def __str__(self):
"""Custom error messages for exception"""
"""Return the error message."""
error_message = "({0})\n" "Reason: {1}\n".format(self.status, self.reason)
if self.headers:
error_message += "HTTP response headers: {0}\n".format(self.headers)
@ -46,19 +48,28 @@ class ClientException(Exception):
return error_message
class Client(object):
"""Client API."""
class Client:
"""The client for the DB-GPT API."""
def __init__(
self,
api_base: Optional[str] = "http://localhost:5000",
api_base: str = "http://localhost:5000",
api_key: Optional[str] = None,
version: Optional[str] = "v2",
version: str = "v2",
timeout: Optional[httpx._types.TimeoutTypes] = 120,
):
"""
"""Create the client.
Args:
api_base: Optional[str], a full URL for the DB-GPT API. Defaults to the http://localhost:5000.
api_key: Optional[str], The dbgpt api key to use for authentication. Defaults to None.
timeout: Optional[httpx._types.TimeoutTypes]: The timeout to use. Defaults to None.
api_base: Optional[str], a full URL for the DB-GPT API.
Defaults to the `http://localhost:5000`.
api_key: Optional[str], The dbgpt api key to use for authentication.
Defaults to None.
timeout: Optional[httpx._types.TimeoutTypes]: The timeout to use.
Defaults to None.
In most cases, pass in a float number to specify the timeout in seconds.
Returns:
None
@ -75,7 +86,7 @@ class Client(object):
client = Client(api_base=DBGPT_API_BASE, api_key=DBGPT_API_KEY)
client.chat(model="chatgpt_proxyllm", messages="Hello?")
"""
if is_valid_url(api_base):
if api_base and is_valid_url(api_base):
self._api_url = api_base.rstrip("/")
else:
raise ValueError(f"api url {api_base} does not exist or is not accessible.")
@ -105,18 +116,24 @@ class Client(object):
) -> ChatCompletionResponse:
"""
Chat Completion.
Args:
model: str, The model name.
messages: Union[str, List[str]], The user input messages.
temperature: Optional[float], What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
max_new_tokens: Optional[int], The maximum number of tokens that can be generated in the chat completion.
temperature: Optional[float], What sampling temperature to use,between 0
and 2. Higher values like 0.8 will make the output more random,
while lower values like 0.2 will make it more focused and deterministic.
max_new_tokens: Optional[int].The maximum number of tokens that can be
generated in the chat completion.
chat_mode: Optional[str], The chat mode.
chat_param: Optional[str], The chat param of chat mode.
conv_uid: Optional[str], The conversation id of the model inference.
user_name: Optional[str], The user name of the model inference.
sys_code: Optional[str], The system code of the model inference.
span_id: Optional[str], The span id of the model inference.
incremental: bool, Used to control whether the content is returned incrementally or in full each time. If this parameter is not provided, the default is full return.
incremental: bool, Used to control whether the content is returned
incrementally or in full each time. If this parameter is not provided,
the default is full return.
enable_vis: bool, Response content whether to output vis label.
Returns:
ChatCompletionResponse: The chat completion response.
@ -173,18 +190,24 @@ class Client(object):
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
"""
Chat Stream Completion.
Args:
model: str, The model name.
messages: Union[str, List[str]], The user input messages.
temperature: Optional[float], What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
max_new_tokens: Optional[int], The maximum number of tokens that can be generated in the chat completion.
temperature: Optional[float], What sampling temperature to use, between 0
and 2.Higher values like 0.8 will make the output more random, while lower
values like 0.2 will make it more focused and deterministic.
max_new_tokens: Optional[int], The maximum number of tokens that can be
generated in the chat completion.
chat_mode: Optional[str], The chat mode.
chat_param: Optional[str], The chat param of chat mode.
conv_uid: Optional[str], The conversation id of the model inference.
user_name: Optional[str], The user name of the model inference.
sys_code: Optional[str], The system code of the model inference.
span_id: Optional[str], The span id of the model inference.
incremental: bool, Used to control whether the content is returned incrementally or in full each time. If this parameter is not provided, the default is full return.
incremental: bool, Used to control whether the content is returned
incrementally or in full each time. If this parameter is not provided,
the default is full return.
enable_vis: bool, Response content whether to output vis label.
Returns:
ChatCompletionStreamResponse: The chat completion response.
@ -240,11 +263,12 @@ class Client(object):
error = await response.aread()
yield json.loads(error)
except Exception as e:
yield f"data:[SERVER_ERROR]{str(e)}\n\n"
async def get(self, path: str, *args):
"""
Get method.
"""Get method.
Args:
path: str, The path to get.
args: Any, The arguments to pass to the get method.
@ -256,11 +280,12 @@ class Client(object):
)
return response
finally:
await self._http_client.aclose()
async def post(self, path: str, args):
"""
Post method.
"""Post method.
Args:
path: str, The path to post.
args: Any, The arguments to pass to the post
@ -274,8 +299,8 @@ class Client(object):
await self._http_client.aclose()
async def post_param(self, path: str, args):
"""
Post method.
"""Post method.
Args:
path: str, The path to post.
args: Any, The arguments to pass to the post
@ -289,8 +314,8 @@ class Client(object):
await self._http_client.aclose()
async def patch(self, path: str, *args):
"""
Patch method.
"""Patch method.
Args:
path: str, The path to patch.
args: Any, The arguments to pass to the patch.
@ -298,8 +323,8 @@ class Client(object):
return self._http_client.patch(self._api_url + CLIENT_SERVE_PATH + path, *args)
async def put(self, path: str, args):
"""
Put method.
"""Put method.
Args:
path: str, The path to put.
args: Any, The arguments to pass to the put.
@ -312,8 +337,8 @@ class Client(object):
await self._http_client.aclose()
async def delete(self, path: str, *args):
"""
Delete method.
"""Delete method.
Args:
path: str, The path to delete.
args: Any, The arguments to pass to the delete.
@ -326,8 +351,8 @@ class Client(object):
await self._http_client.aclose()
async def head(self, path: str, *args):
"""
Head method.
"""Head method.
Args:
path: str, The path to head.
args: Any, The arguments to pass to the head
@ -336,8 +361,8 @@ class Client(object):
def is_valid_url(api_url: Any) -> bool:
"""
Check if the given URL is valid.
"""Check if the given URL is valid.
Args:
api_url: Any, The URL to check.
Returns:

View File

@ -1,9 +1,11 @@
"""this module contains the flow client functions."""
from dbgpt.client.client import Client
from dbgpt.core.awel.flow.flow_factory import FlowPanel
async def create_flow(client: Client, flow: FlowPanel):
"""Create a new flow.
Args:
client (Client): The dbgpt client.
flow (FlowPanel): The flow panel.
@ -13,6 +15,7 @@ async def create_flow(client: Client, flow: FlowPanel):
async def update_flow(client: Client, flow: FlowPanel):
"""Update a flow.
Args:
client (Client): The dbgpt client.
flow (FlowPanel): The flow panel.
@ -23,6 +26,7 @@ async def update_flow(client: Client, flow: FlowPanel):
async def delete_flow(client: Client, flow_id: str):
"""
Delete a flow.
Args:
client (Client): The dbgpt client.
flow_id (str): The flow id.
@ -33,6 +37,7 @@ async def delete_flow(client: Client, flow_id: str):
async def get_flow(client: Client, flow_id: str):
"""
Get a flow.
Args:
client (Client): The dbgpt client.
flow_id (str): The flow id.
@ -43,6 +48,7 @@ async def get_flow(client: Client, flow_id: str):
async def list_flow(client: Client):
"""
List flows.
Args:
client (Client): The dbgpt client.
"""

View File

@ -1,3 +1,4 @@
"""Knowledge API client."""
import json
from dbgpt.client.client import Client
@ -6,6 +7,7 @@ from dbgpt.client.schemas import DocumentModel, SpaceModel, SyncModel
async def create_space(client: Client, app_model: SpaceModel):
"""Create a new space.
Args:
client (Client): The dbgpt client.
app_model (SpaceModel): The app model.
@ -15,6 +17,7 @@ async def create_space(client: Client, app_model: SpaceModel):
async def update_space(client: Client, app_model: SpaceModel):
"""Update a document.
Args:
client (Client): The dbgpt client.
app_model (SpaceModel): The app model.
@ -24,6 +27,7 @@ async def update_space(client: Client, app_model: SpaceModel):
async def delete_space(client: Client, space_id: str):
"""Delete a space.
Args:
client (Client): The dbgpt client.
app_id (str): The app id.
@ -33,6 +37,7 @@ async def delete_space(client: Client, space_id: str):
async def get_space(client: Client, space_id: str):
"""Get a document.
Args:
client (Client): The dbgpt client.
app_id (str): The app id.
@ -42,6 +47,7 @@ async def get_space(client: Client, space_id: str):
async def list_space(client: Client):
"""List apps.
Args:
client (Client): The dbgpt client.
"""
@ -50,6 +56,7 @@ async def list_space(client: Client):
async def create_document(client: Client, doc_model: DocumentModel):
"""Create a new space.
Args:
client (Client): The dbgpt client.
doc_model (SpaceModel): The document model.
@ -59,6 +66,7 @@ async def create_document(client: Client, doc_model: DocumentModel):
async def delete_document(client: Client, document_id: str):
"""Delete a document.
Args:
client (Client): The dbgpt client.
app_id (str): The app id.
@ -68,6 +76,7 @@ async def delete_document(client: Client, document_id: str):
async def get_document(client: Client, document_id: str):
"""Get a document.
Args:
client (Client): The dbgpt client.
app_id (str): The app id.
@ -77,6 +86,7 @@ async def get_document(client: Client, document_id: str):
async def list_document(client: Client):
"""List documents.
Args:
client (Client): The dbgpt client.
"""
@ -84,7 +94,8 @@ async def list_document(client: Client):
async def sync_document(client: Client, sync_model: SyncModel):
"""sync document.
"""Sync document.
Args:
client (Client): The dbgpt client.
"""

View File

@ -1,6 +1,7 @@
"""this module contains the schemas for the dbgpt client."""
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union
from fastapi import File, UploadFile
from pydantic import BaseModel, Field
@ -8,6 +9,8 @@ from pydantic import BaseModel, Field
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.rag.chunk_manager import ChunkParameters
"""Chat completion request body"""
class ChatCompletionRequestBody(BaseModel):
"""ChatCompletion LLM http request body."""
@ -54,15 +57,20 @@ class ChatCompletionRequestBody(BaseModel):
)
incremental: bool = Field(
default=True,
description="Used to control whether the content is returned incrementally or in full each time. If this parameter is not provided, the default is full return.",
description="Used to control whether the content is returned incrementally "
"or in full each time. "
"If this parameter is not provided, the default is full return.",
)
enable_vis: str = Field(
default=True, description="response content whether to output vis label"
)
"""Chat completion response"""
class ChatMode(Enum):
"""Chat mode"""
"""Chat mode."""
CHAT_NORMAL = "chat_normal"
CHAT_APP = "chat_app"
@ -70,34 +78,30 @@ class ChatMode(Enum):
CHAT_KNOWLEDGE = "chat_knowledge"
class SpaceModel(BaseModel):
"""name: knowledge space name"""
"""vector_type: vector type"""
id: int = Field(None, description="The space id")
name: str = Field(None, description="The space name")
"""vector_type: vector type"""
vector_type: str = Field(None, description="The vector type")
"""desc: description"""
desc: str = Field(None, description="The description")
"""owner: owner"""
owner: str = Field(None, description="The owner")
"""Agent model"""
class AppDetailModel(BaseModel):
app_code: Optional[str] = Field(None, title="app code")
app_name: Optional[str] = Field(None, title="app name")
agent_name: Optional[str] = Field(None, title="agent name")
node_id: Optional[str] = Field(None, title="node id")
resources: Optional[list[AgentResource]] = Field(None, title="resources")
prompt_template: Optional[str] = Field(None, title="prompt template")
llm_strategy: Optional[str] = Field(None, title="llm strategy")
llm_strategy_value: Optional[str] = Field(None, title="llm strategy value")
"""App detail model."""
app_code: Optional[str] = Field(None, description="app code")
app_name: Optional[str] = Field(None, description="app name")
agent_name: Optional[str] = Field(None, description="agent name")
node_id: Optional[str] = Field(None, description="node id")
resources: Optional[list[AgentResource]] = Field(None, description="resources")
prompt_template: Optional[str] = Field(None, description="prompt template")
llm_strategy: Optional[str] = Field(None, description="llm strategy")
llm_strategy_value: Optional[str] = Field(None, description="llm strategy value")
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
"""Awel team model"""
class AwelTeamModel(BaseModel):
"""Awel team model."""
dag_id: str = Field(
...,
description="The unique id of dag",
@ -148,6 +152,8 @@ class AwelTeamModel(BaseModel):
class AppModel(BaseModel):
"""App model."""
app_code: Optional[str] = Field(None, title="app code")
app_name: Optional[str] = Field(None, title="app name")
app_describe: Optional[str] = Field(None, title="app describe")
@ -166,6 +172,8 @@ class AppModel(BaseModel):
class SpaceModel(BaseModel):
"""Space model."""
name: str = Field(
default=None,
description="knowledge space name",
@ -185,6 +193,8 @@ class SpaceModel(BaseModel):
class DocumentModel(BaseModel):
"""Document model."""
id: int = Field(None, description="The doc id")
doc_name: str = Field(None, description="doc name")
"""doc_type: document type"""
@ -200,7 +210,7 @@ class DocumentModel(BaseModel):
class SyncModel(BaseModel):
"""Sync model"""
"""Sync model."""
"""doc_id: doc id"""
doc_id: str = Field(None, description="The doc id")
@ -211,6 +221,6 @@ class SyncModel(BaseModel):
"""model_name: model name"""
model_name: Optional[str] = Field(None, description="model name")
"""chunk_parameters: chunk parameters
"""chunk_parameters: chunk parameters
"""
chunk_parameters: ChunkParameters = Field(None, description="chunk parameters")

View File

@ -23,7 +23,6 @@ Client: Simple App CRUD example
async def main():
# initialize client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)

View File

@ -55,7 +55,6 @@ Client: Simple Chat example
async def main():
# initialize client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)

View File

@ -36,7 +36,6 @@ Client: Simple Flow CRUD example
async def main():
# initialize client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)

View File

@ -68,7 +68,6 @@ from dbgpt.client.knowledge import list_space
async def main():
# initialize client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)