feat: Client support chatdata (#1343)

This commit is contained in:
Aries-ckt
2024-03-28 09:04:28 +08:00
committed by GitHub
parent f144fc3d36
commit dffd235bfb
21 changed files with 1089 additions and 12 deletions

View File

@@ -45,6 +45,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# Register serve app
system_app.register(FlowServe)
# ################################ Rag Serve Register Begin ######################################
from dbgpt.serve.rag.serve import (
SERVE_CONFIG_KEY_PREFIX as RAG_SERVE_CONFIG_KEY_PREFIX,
)
@@ -52,4 +54,14 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# Register serve app
system_app.register(RagServe)
# ################################ Datasource Serve Register Begin ######################################
from dbgpt.serve.datasource.serve import (
SERVE_CONFIG_KEY_PREFIX as DATASOURCE_SERVE_CONFIG_KEY_PREFIX,
)
from dbgpt.serve.datasource.serve import Serve as DatasourceServe
# Register serve app
system_app.register(DatasourceServe)
# ################################ AWEL Flow Serve Register End ########################################

View File

@@ -127,6 +127,7 @@ async def chat_completions(
request.chat_mode is None
or request.chat_mode == ChatMode.CHAT_NORMAL.value
or request.chat_mode == ChatMode.CHAT_KNOWLEDGE.value
or request.chat_mode == ChatMode.CHAT_DATA.value
):
with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=request.dict()
@@ -146,7 +147,7 @@ async def chat_completions(
status_code=400,
detail={
"error": {
"message": "chat mode now only support chat_normal, chat_app, chat_flow, chat_knowledge",
"message": "chat mode now only support chat_normal, chat_app, chat_flow, chat_knowledge, chat_data",
"type": "invalid_request_error",
"param": None,
"code": "invalid_chat_mode",
@@ -169,7 +170,8 @@ async def get_chat_instance(dialogue: ChatCompletionRequestBody = Body()) -> Bas
dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
)
dialogue.conv_uid = conv_vo.conv_uid
if dialogue.chat_mode == "chat_data":
dialogue.chat_mode = ChatScene.ChatWithDbExecute.value()
if not ChatScene.is_valid_mode(dialogue.chat_mode):
raise StopAsyncIteration(f"Unsupported Chat Mode,{dialogue.chat_mode}!")
@@ -201,7 +203,7 @@ async def no_stream_wrapper(
"""
with root_tracer.start_span("no_stream_generator"):
response = await chat.nostream_call()
msg = response.replace("\ufffd", "")
msg = response.replace("\ufffd", "").replace("&quot;", '"')
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=msg),

119
dbgpt/client/datasource.py Normal file
View File

@@ -0,0 +1,119 @@
"""this module contains the datasource client functions."""
from typing import List
from dbgpt.core.schema.api import Result
from .client import Client, ClientException
from .schema import DatasourceModel
async def create_datasource(
    client: Client, datasource: DatasourceModel
) -> DatasourceModel:
    """Create a new datasource.

    Args:
        client (Client): The dbgpt client.
        datasource (DatasourceModel): The datasource model.

    Returns:
        DatasourceModel: The datasource model built from the response data.

    Raises:
        ClientException: If the request failed.
    """
    try:
        # Creation must be sent as POST: a GET on "/datasources" is the listing
        # request (see list_datasource) and cannot create anything.
        # NOTE(review): assumes Client.post takes (path, payload) like Client.put.
        res = await client.post("/datasources", datasource.dict())
        result: Result = res.json()
        if result["success"]:
            return DatasourceModel(**result["data"])
        else:
            raise ClientException(status=result["err_code"], reason=result)
    except Exception as e:
        # Keep the original error attached as the cause for easier debugging.
        raise ClientException(f"Failed to create datasource: {e}") from e
async def update_datasource(
    client: Client, datasource: DatasourceModel
) -> DatasourceModel:
    """Send a datasource to the server as an update.

    Args:
        client (Client): The dbgpt client.
        datasource (DatasourceModel): The datasource model to send.

    Returns:
        DatasourceModel: The model built from the ``data`` field of the reply.

    Raises:
        ClientException: If the request failed or the reply was unsuccessful.
    """
    try:
        response = await client.put("/datasources", datasource.dict())
        payload: Result = response.json()
        # Guard clause: an unsuccessful reply is turned into an error first.
        if not payload["success"]:
            raise ClientException(status=payload["err_code"], reason=payload)
        return DatasourceModel(**payload["data"])
    except Exception as e:
        raise ClientException(f"Failed to update datasource: {e}")
async def delete_datasource(client: Client, datasource_id: str) -> DatasourceModel:
    """Delete the datasource with the given id.

    Args:
        client (Client): The dbgpt client.
        datasource_id (str): The datasource id.

    Returns:
        DatasourceModel: The model built from the ``data`` field of the reply.

    Raises:
        ClientException: If the request failed or the reply was unsuccessful.
    """
    try:
        path = "/datasources/" + datasource_id
        response = await client.delete(path)
        payload: Result = response.json()
        # Guard clause: an unsuccessful reply is turned into an error first.
        if not payload["success"]:
            raise ClientException(status=payload["err_code"], reason=payload)
        return DatasourceModel(**payload["data"])
    except Exception as e:
        raise ClientException(f"Failed to delete datasource: {e}")
async def get_datasource(client: Client, datasource_id: str) -> DatasourceModel:
    """Fetch the datasource with the given id.

    Args:
        client (Client): The dbgpt client.
        datasource_id (str): The datasource id.

    Returns:
        DatasourceModel: The model built from the ``data`` field of the reply.

    Raises:
        ClientException: If the request failed or the reply was unsuccessful.
    """
    try:
        path = "/datasources/" + datasource_id
        response = await client.get(path)
        payload: Result = response.json()
        # Guard clause: an unsuccessful reply is turned into an error first.
        if not payload["success"]:
            raise ClientException(status=payload["err_code"], reason=payload)
        return DatasourceModel(**payload["data"])
    except Exception as e:
        raise ClientException(f"Failed to get datasource: {e}")
async def list_datasource(client: Client) -> List[DatasourceModel]:
    """Fetch all datasources.

    Args:
        client (Client): The dbgpt client.

    Returns:
        List[DatasourceModel]: One model per item in the reply's ``data`` field.

    Raises:
        ClientException: If the request failed or the reply was unsuccessful.
    """
    try:
        response = await client.get("/datasources")
        payload: Result = response.json()
        # Guard clause: an unsuccessful reply is turned into an error first.
        if not payload["success"]:
            raise ClientException(status=payload["err_code"], reason=payload)
        models = []
        for item in payload["data"]:
            models.append(DatasourceModel(**item))
        return models
    except Exception as e:
        raise ClientException(f"Failed to list datasource: {e}")

View File

@@ -72,6 +72,7 @@ class ChatMode(Enum):
CHAT_APP = "chat_app"
CHAT_AWEL_FLOW = "chat_flow"
CHAT_KNOWLEDGE = "chat_knowledge"
CHAT_DATA = "chat_data"
class AwelTeamModel(BaseModel):
@@ -278,3 +279,17 @@ class SyncModel(BaseModel):
"""chunk_parameters: chunk parameters
"""
chunk_parameters: ChunkParameters = Field(None, description="chunk parameters")
class DatasourceModel(BaseModel):
    """Datasource model."""

    # Optional identifier; defaults to None when not supplied.
    id: Optional[int] = Field(None, description="The datasource id")
    # db_type and db_name are the only required fields (declared with `...`).
    db_type: str = Field(..., description="Database type, e.g. sqlite, mysql, etc.")
    db_name: str = Field(..., description="Database name.")
    # The remaining connection details default to an empty string (or 0 for
    # the port) when omitted.
    db_path: str = Field("", description="File path for file-based database.")
    db_host: str = Field("", description="Database host.")
    db_port: int = Field(0, description="Database port.")
    db_user: str = Field("", description="Database user.")
    # NOTE(review): the password is carried as a plain string field — confirm
    # this is acceptable wherever the model is logged or returned.
    db_pwd: str = Field("", description="Database password.")
    comment: str = Field("", description="Comment for the database.")

View File

@@ -1,10 +1,14 @@
"""DB Model for connect_config."""
import logging
from typing import Optional
from typing import Any, Dict, Optional, Union
from sqlalchemy import Column, Index, Integer, String, Text, UniqueConstraint, text
from dbgpt.serve.datasource.api.schemas import (
DatasourceServeRequest,
DatasourceServeResponse,
)
from dbgpt.storage.metadata import BaseDao, Model
logger = logging.getLogger(__name__)
@@ -218,3 +222,62 @@ class ConnectConfigDao(BaseDao):
session.commit()
session.close()
return True
def from_request(
    self, request: Union[DatasourceServeRequest, Dict[str, Any]]
) -> ConnectConfigEntity:
    """Build a ``ConnectConfigEntity`` from a request.

    Args:
        request (Union[DatasourceServeRequest, Dict[str, Any]]): Either a
            request model (converted with ``.dict()``) or a plain mapping of
            entity fields, which is used as-is.

    Returns:
        ConnectConfigEntity: The entity constructed from the request fields.
    """
    if isinstance(request, DatasourceServeRequest):
        fields = request.dict()
    else:
        fields = request
    return ConnectConfigEntity(**fields)
def to_request(self, entity: ConnectConfigEntity) -> DatasourceServeRequest:
    """Copy the connection fields of an entity into a request model.

    Args:
        entity (ConnectConfigEntity): The entity to read from.

    Returns:
        DatasourceServeRequest: A request carrying the entity's field values.
    """
    # Fields are read in the same order in which they are passed on.
    field_names = (
        "id",
        "db_type",
        "db_name",
        "db_path",
        "db_host",
        "db_port",
        "db_user",
        "db_pwd",
        "comment",
    )
    values = {name: getattr(entity, name) for name in field_names}
    return DatasourceServeRequest(**values)
def to_response(self, entity: ConnectConfigEntity) -> DatasourceServeResponse:
    """Copy the connection fields of an entity into a response model.

    Args:
        entity (ConnectConfigEntity): The entity to read from.

    Returns:
        DatasourceServeResponse: A response carrying the entity's field values.
    """
    # Fields are read in the same order in which they are passed on.
    field_names = (
        "id",
        "db_type",
        "db_name",
        "db_path",
        "db_host",
        "db_port",
        "db_user",
        "db_pwd",
        "comment",
    )
    values = {name: getattr(entity, name) for name in field_names}
    return DatasourceServeResponse(**values)

View File

View File

@@ -0,0 +1,193 @@
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.serve.datasource.api.schemas import (
DatasourceServeRequest,
DatasourceServeResponse,
)
from dbgpt.serve.datasource.config import SERVE_SERVICE_COMPONENT_NAME
from dbgpt.serve.datasource.service.service import Service
from dbgpt.util import PaginationResult
# Router on which the datasource endpoints in this module are declared.
router = APIRouter()
# Add your API endpoints here
# System app used by get_service(); stays None until init_endpoints() sets it.
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
    """Resolve the datasource ``Service`` component from the global system app."""
    app = global_system_app
    return app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
# Bearer credential dependency used by check_api_key. It is created with
# auto_error=False; check_api_key handles the missing-credentials (None) case.
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Validate the bearer token of a request against the configured api keys.

    If no api keys are configured, every request is allowed.
    You can pass the token in your request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    Returns:
        Optional[str]: The accepted token, or None when no api keys are set.

    Raises:
        HTTPException: With status 401 when api keys are configured and the
            token is missing or not one of them.
    """
    configured = service.config.api_keys
    if not configured:
        # api_keys not set; allow all
        return None

    allowed = _parse_api_keys(configured)
    token = auth.credentials if auth is not None else None
    if auth is None or token not in allowed:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token
@router.get("/health", dependencies=[Depends(check_api_key)])
async def health():
    """Report that the service is reachable (guarded by the api key check)."""
    status = {"status": "ok"}
    return status
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Return a fixed payload once the api key check has passed."""
    status = {"status": "ok"}
    return status
@router.post("/datasources", dependencies=[Depends(check_api_key)])
async def create(
    request: DatasourceServeRequest, service: Service = Depends(get_service)
) -> Result:
    """Create a new datasource.

    Args:
        request (DatasourceServeRequest): The datasource to create.
        service (Service): The datasource service.

    Returns:
        Result: A success result wrapping what ``service.create`` returned.
    """
    created = service.create(request)
    return Result.succ(created)
@router.put("/datasources", dependencies=[Depends(check_api_key)])
async def update(
    request: DatasourceServeRequest, service: Service = Depends(get_service)
) -> Result:
    """Update an existing datasource.

    Args:
        request (DatasourceServeRequest): The datasource fields to apply.
        service (Service): The datasource service.

    Returns:
        Result: A success result wrapping what ``service.update`` returned.
    """
    updated = service.update(request)
    return Result.succ(updated)
@router.delete(
    "/datasources/{datasource_id}",
    response_model=Result[None],
    dependencies=[Depends(check_api_key)],
)
async def delete(
    datasource_id: str, service: Service = Depends(get_service)
) -> Result[None]:
    """Delete the datasource with the given id.

    Args:
        datasource_id (str): The id taken from the request path.
        service (Service): The datasource service.

    Returns:
        Result: A success result wrapping what ``service.delete`` returned.
    """
    deleted = service.delete(datasource_id)
    return Result.succ(deleted)
@router.get(
    "/datasources/{datasource_id}",
    dependencies=[Depends(check_api_key)],
    response_model=Result[List],
)
async def query(
    datasource_id: str, service: Service = Depends(get_service)
) -> Result[List[DatasourceServeResponse]]:
    """Look up the datasource with the given id.

    Args:
        datasource_id (str): The id taken from the request path.
        service (Service): The datasource service.

    Returns:
        Result: A success result wrapping what ``service.get`` returned.
    """
    # NOTE(review): the route declares list-typed results, while a lookup by id
    # presumably yields a single datasource — confirm the intended response shape.
    found = service.get(datasource_id)
    return Result.succ(found)
@router.get(
    "/datasources",
    dependencies=[Depends(check_api_key)],
    response_model=Result[PaginationResult[DatasourceServeResponse]],
)
async def query_page(
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
    service: Service = Depends(get_service),
) -> Result[PaginationResult[DatasourceServeResponse]]:
    """List the datasources.

    Args:
        page (int): The page number.
        page_size (int): The page size.
        service (Service): The datasource service.

    Returns:
        Result: A success result wrapping what ``service.list`` returned.
    """
    # NOTE(review): `page` and `page_size` are accepted but not passed on, and
    # the declared response type is a PaginationResult — confirm whether
    # pagination is meant to be applied here.
    listed = service.list()
    return Result.succ(listed)
def init_endpoints(system_app: SystemApp) -> None:
    """Initialize the endpoints.

    Registers the datasource ``Service`` on ``system_app`` and then stores the
    app in the module-level ``global_system_app`` that ``get_service`` reads.

    Args:
        system_app (SystemApp): The system app to register the service on.
    """
    global global_system_app
    system_app.register(Service)
    global_system_app = system_app

View File

@@ -0,0 +1,41 @@
from typing import Optional
from pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class DatasourceServeRequest(BaseModel):
    """Request model describing a datasource connection."""

    # Optional identifier; defaults to None when not supplied.
    id: Optional[int] = Field(None, description="The datasource id")
    # db_type and db_name are the only required fields.
    db_type: str = Field(..., description="Database type, e.g. sqlite, mysql, etc.")
    db_name: str = Field(..., description="Database name.")
    # Remaining connection details default to "" (or 0 for the port).
    db_path: str = Field("", description="File path for file-based database.")
    db_host: str = Field("", description="Database host.")
    db_port: int = Field(0, description="Database port.")
    db_user: str = Field("", description="Database user.")
    db_pwd: str = Field("", description="Database password.")
    comment: str = Field("", description="Comment for the database.")
class DatasourceServeResponse(BaseModel):
    """Response model describing a datasource connection."""

    # The default is None, so the annotation must allow None as well.
    id: Optional[int] = Field(None, description="The datasource id")
    # db_type and db_name are the only required fields.
    db_type: str = Field(..., description="Database type, e.g. sqlite, mysql, etc.")
    db_name: str = Field(..., description="Database name.")
    # Remaining connection details default to "" (or 0 for the port).
    db_path: str = Field("", description="File path for file-based database.")
    db_host: str = Field("", description="Database host.")
    db_port: int = Field(0, description="Database port.")
    db_user: str = Field("", description="Database user.")
    # NOTE(review): the password is part of the response model — confirm that
    # returning it to API clients is intended.
    db_pwd: str = Field("", description="Database password.")
    comment: str = Field("", description="Comment for the database.")

    class Config:
        title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"

View File

@@ -0,0 +1,28 @@
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
# Naming constants for the datasource serve module.
APP_NAME = "datasource"
SERVE_APP_NAME = "dbgpt_datasource"
SERVE_APP_NAME_HUMP = "dbgpt_datasource"
SERVE_CONFIG_KEY_PREFIX = "dbgpt_datasource"
# Evaluates to "dbgpt_datasource_service".
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
@dataclass
class ServeConfig(BaseServeConfig):
    """Parameters for the serve command"""

    # API keys accepted by the endpoints; None (the default) allows all requests.
    api_keys: Optional[str] = field(
        default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
    )
    # NOTE(review): the help texts below mention "prompt"; they look carried
    # over from another serve module — confirm their meaning for datasources.
    default_user: Optional[str] = field(
        default=None,
        metadata={"help": "Default user name for prompt"},
    )
    default_sys_code: Optional[str] = field(
        default=None,
        metadata={"help": "Default system code for prompt"},
    )

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

View File

@@ -0,0 +1,60 @@
import logging
from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
SERVE_CONFIG_KEY_PREFIX,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Serve(BaseServe):
    """Serve component that mounts the datasource API on the application."""

    name = SERVE_APP_NAME

    def __init__(
        self,
        system_app: SystemApp,
        api_prefix: Optional[str] = "/api/v2/serve",
        api_tags: Optional[List[str]] = None,
        db_url_or_db: Optional[Union[str, URL, DatabaseManager]] = None,
        try_create_tables: Optional[bool] = False,
    ):
        """Create the serve component.

        Args:
            system_app (SystemApp): The system app this component belongs to.
            api_prefix (Optional[str]): Prefix under which the router is mounted.
            api_tags (Optional[List[str]]): Tags for the router; defaults to
                ``[SERVE_APP_NAME_HUMP]`` when None.
            db_url_or_db (Optional[Union[str, URL, DatabaseManager]]): Database
                url or manager, passed through to ``BaseServe``.
            try_create_tables (Optional[bool]): Passed through to ``BaseServe``.
        """
        if api_tags is None:
            api_tags = [SERVE_APP_NAME_HUMP]
        super().__init__(
            system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
        )
        # Filled in by before_start().
        self._db_manager: Optional[DatabaseManager] = None

    def init_app(self, system_app: SystemApp):
        """Mount the datasource router on the app and initialize its endpoints.

        Does nothing when the component has already been initialized.

        Args:
            system_app (SystemApp): The system app providing the web app.
        """
        if self._app_has_initiated:
            return
        self._system_app = system_app
        self._system_app.app.include_router(
            router, prefix=self._api_prefix, tags=self._api_tags
        )
        init_endpoints(self._system_app)
        self._app_has_initiated = True

    def on_init(self):
        """Called when init the application.

        You can do some initialization here. You can't get other components here
        because they may be not initialized yet
        """

    def before_start(self):
        """Called before the start of the application."""
        self._db_manager = self.create_or_get_db_manager()

View File

@@ -0,0 +1,181 @@
import logging
from typing import List, Optional
from fastapi import HTTPException
from dbgpt._private.config import Config
from dbgpt.component import ComponentType, SystemApp
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.datasource.db_conn_info import DBConfig
from dbgpt.datasource.manages.connect_config_db import (
ConnectConfigDao,
ConnectConfigEntity,
)
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.schema import DBType
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.executor_utils import ExecutorFactory
from ..api.schemas import DatasourceServeRequest, DatasourceServeResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Global configuration object; the Service below reads `local_db_manager` and
# `VECTOR_STORE_TYPE` from it.
CFG = Config()
class Service(
    BaseService[ConnectConfigEntity, DatasourceServeRequest, DatasourceServeResponse]
):
    """The service class for datasources."""

    name = SERVE_SERVICE_COMPONENT_NAME

    def __init__(
        self,
        system_app: SystemApp,
        dao: Optional[ConnectConfigDao] = None,
    ):
        """Create the service.

        Args:
            system_app (SystemApp): The system app, passed to the base class.
            dao (Optional[ConnectConfigDao]): DAO to use; when None a
                ``ConnectConfigDao`` is created in ``init_app``.
        """
        self._system_app = None
        self._dao: ConnectConfigDao = dao
        self._dag_manager: Optional[DAGManager] = None
        self._db_summary_client = None
        self._vector_connector = None

        super().__init__(system_app)

    def init_app(self, system_app: SystemApp) -> None:
        """Initialize the service

        Args:
            system_app (SystemApp): The system app
        """
        self._serve_config = ServeConfig.from_app_config(
            system_app.config, SERVE_CONFIG_KEY_PREFIX
        )
        self._dao = self._dao or ConnectConfigDao()
        self._system_app = system_app

    def before_start(self):
        """Execute before the application starts"""
        # Imported here rather than at module level, as in the original code.
        from dbgpt.rag.summary.db_summary_client import DBSummaryClient

        self._db_summary_client = DBSummaryClient(self._system_app)

    def after_start(self):
        """Execute after the application starts"""

    @property
    def dao(
        self,
    ) -> BaseDao[ConnectConfigEntity, DatasourceServeRequest, DatasourceServeResponse]:
        """Returns the internal DAO."""
        return self._dao

    @property
    def config(self) -> ServeConfig:
        """Returns the internal ServeConfig."""
        return self._serve_config

    def create(self, request: DatasourceServeRequest) -> DatasourceServeResponse:
        """Create a new Datasource entity

        After the entity is stored, a summary embedding job for the database is
        submitted to the default executor.

        Args:
            request (DatasourceServeRequest): The request

        Returns:
            DatasourceServeResponse: The response

        Raises:
            HTTPException: With status 400 when a datasource with the same name
                already exists or the database type is not supported.
            ValueError: When storing the datasource or submitting the embedding
                job fails.
        """
        datasource = self._dao.get_by_names(request.db_name)
        if datasource:
            raise HTTPException(
                status_code=400,
                detail=f"datasource name:{request.db_name} already exists",
            )
        try:
            db_type = DBType.of_db_type(request.db_type)
            if not db_type:
                raise HTTPException(
                    status_code=400, detail=f"Unsupported Db Type, {request.db_type}"
                )
            res = self._dao.create(request)

            # async embedding
            executor = self._system_app.get_component(
                ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
            ).create()  # type: ignore
            executor.submit(
                self._db_summary_client.db_summary_embedding,
                request.db_name,
                request.db_type,
            )
        except HTTPException:
            # Let the 400 raised above reach the caller unchanged instead of
            # being rewrapped as a ValueError by the generic handler below.
            raise
        except Exception as e:
            raise ValueError("Add db connect info error!" + str(e)) from e
        return res

    def update(self, request: DatasourceServeRequest) -> DatasourceServeResponse:
        """Update an existing Datasource entity

        Args:
            request (DatasourceServeRequest): The request

        Returns:
            DatasourceServeResponse: The response

        Raises:
            HTTPException: With status 400 when no datasource with the requested
                name exists or the update fails.
        """
        datasources = self._dao.get_by_names(request.db_name)
        if datasources is None:
            raise HTTPException(
                status_code=400,
                detail=f"there is no datasource name:{request.db_name} exists",
            )
        db_config = DBConfig(**request.dict())
        if CFG.local_db_manager.edit_db(db_config):
            return DatasourceServeResponse(**db_config.dict())
        else:
            raise HTTPException(
                status_code=400,
                detail=f"update datasource name:{request.db_name} failed",
            )

    def get(self, datasource_id: str) -> Optional[DatasourceServeResponse]:
        """Get a Datasource entity

        Args:
            datasource_id (str): The datasource_id

        Returns:
            Optional[DatasourceServeResponse]: The result of the DAO lookup by id
        """
        return self._dao.get_one({"id": datasource_id})

    def delete(self, datasource_id: str) -> Optional[DatasourceServeResponse]:
        """Delete a Datasource entity

        Also deletes the vector store named ``<db_name>_profile``.

        Args:
            datasource_id (str): The datasource_id

        Returns:
            Optional[DatasourceServeResponse]: The data that was deleted, or None
                when no datasource with that id exists.
        """
        db_config = self._dao.get_one({"id": datasource_id})
        if not db_config:
            # Nothing to delete; bail out before touching `db_config.db_name`,
            # which would fail on a missing datasource.
            return None
        vector_name = db_config.db_name + "_profile"
        vector_store_config = VectorStoreConfig(name=vector_name)
        self._vector_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=vector_store_config,
        )
        self._vector_connector.delete_vector_name(vector_name)
        self._dao.delete({"id": datasource_id})
        return db_config

    def list(self) -> List[DatasourceServeResponse]:
        """List the Datasource entities.

        Returns:
            List[DatasourceServeResponse]: One response per entry returned by
                the local db manager.
        """
        db_list = CFG.local_db_manager.get_db_list()
        return [DatasourceServeResponse(**db) for db in db_list]

View File