feat: Client support chatdata (#1343)

This commit is contained in:
Aries-ckt
2024-03-28 09:04:28 +08:00
committed by GitHub
parent f144fc3d36
commit dffd235bfb
21 changed files with 1089 additions and 12 deletions

View File

View File

@@ -0,0 +1,193 @@
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.serve.datasource.api.schemas import (
DatasourceServeRequest,
DatasourceServeResponse,
)
from dbgpt.serve.datasource.config import SERVE_SERVICE_COMPONENT_NAME
from dbgpt.serve.datasource.service.service import Service
from dbgpt.util import PaginationResult
# Router on which every datasource endpoint in this module is declared.
router = APIRouter()
# Add your API endpoints here
# Application handle read by get_service(); it stays None until
# init_endpoints() assigns the running SystemApp.
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
    """Resolve the datasource service from the module-level system app.

    Returns:
        Service: The component registered under
            ``SERVE_SERVICE_COMPONENT_NAME``.
    """
    system_app = global_system_app
    return system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
# Bearer-token dependency. auto_error=False so that a missing Authorization
# header reaches check_api_key() as None instead of being rejected up front.
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Validate the bearer token of the incoming request.

    When the service has no API keys configured, every request is allowed.
    Otherwise the request must carry one of the configured keys as a bearer
    token. You can pass the token in your request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    Returns:
        Optional[str]: The accepted token, or ``None`` when no API keys are
            configured.

    Raises:
        HTTPException: 401 when keys are configured and the token is missing
            or not among them.
    """
    configured_keys = service.config.api_keys
    if not configured_keys:
        # api_keys not set; allow all
        return None

    accepted = _parse_api_keys(configured_keys)
    token = auth.credentials if auth is not None else None
    if token is None or token not in accepted:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token
@router.get("/health", dependencies=[Depends(check_api_key)])
async def health():
    """Report that the datasource API is reachable."""
    payload = {"status": "ok"}
    return payload
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Return a fixed OK payload once the API-key check has passed."""
    payload = {"status": "ok"}
    return payload
@router.post("/datasources", dependencies=[Depends(check_api_key)])
async def create(
    request: DatasourceServeRequest, service: Service = Depends(get_service)
) -> Result:
    """Create a new datasource.

    Args:
        request (DatasourceServeRequest): The datasource to create
        service (Service): The datasource service

    Returns:
        Result: Success envelope wrapping the value returned by
            ``service.create``
    """
    created = service.create(request)
    return Result.succ(created)
@router.put("/datasources", dependencies=[Depends(check_api_key)])
async def update(
    request: DatasourceServeRequest, service: Service = Depends(get_service)
) -> Result:
    """Update an existing datasource.

    Args:
        request (DatasourceServeRequest): The new datasource settings
        service (Service): The datasource service

    Returns:
        Result: Success envelope wrapping the value returned by
            ``service.update``
    """
    updated = service.update(request)
    return Result.succ(updated)
@router.delete(
    "/datasources/{datasource_id}",
    response_model=Result[None],
    dependencies=[Depends(check_api_key)],
)
async def delete(
    datasource_id: str, service: Service = Depends(get_service)
) -> Result[None]:
    """Delete a datasource.

    Args:
        datasource_id (str): The id of the datasource to delete
        service (Service): The datasource service

    Returns:
        Result[None]: Success envelope without a data payload
    """
    # The endpoint is declared as Result[None], so the entity returned by
    # service.delete() must not be placed in the response data.
    service.delete(datasource_id)
    return Result.succ(None)
@router.get(
    "/datasources/{datasource_id}",
    dependencies=[Depends(check_api_key)],
    response_model=Result[DatasourceServeResponse],
)
async def query(
    datasource_id: str, service: Service = Depends(get_service)
) -> Result[DatasourceServeResponse]:
    """Fetch a single datasource by id.

    Args:
        datasource_id (str): The id of the datasource to fetch
        service (Service): The datasource service

    Returns:
        Result[DatasourceServeResponse]: Success envelope wrapping the value
            returned by ``service.get`` (a single entity, not a list)
    """
    # service.get() returns one optional entity, so the response model is the
    # entity itself rather than a list of entities.
    return Result.succ(service.get(datasource_id))
@router.get(
    "/datasources",
    dependencies=[Depends(check_api_key)],
    response_model=Result[List[DatasourceServeResponse]],
)
async def query_page(
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
    service: Service = Depends(get_service),
) -> Result[List[DatasourceServeResponse]]:
    """List all datasources.

    Args:
        page (int): The page number (accepted for compatibility; not applied)
        page_size (int): The page size (accepted for compatibility; not
            applied)
        service (Service): The datasource service

    Returns:
        Result[List[DatasourceServeResponse]]: Success envelope wrapping the
            full list returned by ``service.list``
    """
    # service.list() returns a plain, unpaginated list, so the declared
    # response model is a list rather than a PaginationResult.
    # NOTE(review): page / page_size are currently ignored — wire them into
    # the service if real pagination is required.
    return Result.succ(service.list())
def init_endpoints(system_app: SystemApp) -> None:
    """Register the datasource service and publish the system app.

    Args:
        system_app (SystemApp): The application the endpoints belong to
    """
    global global_system_app

    # Register first, then expose the app that get_service() reads.
    system_app.register(Service)
    global_system_app = system_app

View File

@@ -0,0 +1,41 @@
from typing import Optional
from pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class DatasourceServeRequest(BaseModel):
    """Request model describing a datasource (database connection)."""

    # Absent (None) unless the caller supplies an id.
    id: Optional[int] = Field(None, description="The datasource id")
    db_type: str = Field(..., description="Database type, e.g. sqlite, mysql, etc.")
    db_name: str = Field(..., description="Database name.")
    db_path: str = Field("", description="File path for file-based database.")
    db_host: str = Field("", description="Database host.")
    db_port: int = Field(0, description="Database port.")
    db_user: str = Field("", description="Database user.")
    db_pwd: str = Field("", description="Database password.")
    comment: str = Field("", description="Comment for the database.")
class DatasourceServeResponse(BaseModel):
    """Response model describing a datasource (database connection)."""

    # The default is None, so the annotation must allow None as well.
    id: Optional[int] = Field(None, description="The datasource id")
    db_type: str = Field(..., description="Database type, e.g. sqlite, mysql, etc.")
    db_name: str = Field(..., description="Database name.")
    db_path: str = Field("", description="File path for file-based database.")
    db_host: str = Field("", description="Database host.")
    db_port: int = Field(0, description="Database port.")
    db_user: str = Field("", description="Database user.")
    # NOTE(review): the password is part of the response payload — consider
    # masking or omitting it before returning it to API clients.
    db_pwd: str = Field("", description="Database password.")
    comment: str = Field("", description="Comment for the database.")

    # TODO define your own fields here
    class Config:
        title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"

View File

@@ -0,0 +1,28 @@
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
# Short name of this serve application.
APP_NAME = "datasource"
# Name under which the Serve component is registered.
SERVE_APP_NAME = "dbgpt_datasource"
# Used as the default API tag and in the response schema title.
# NOTE(review): "HUMP" suggests a CamelCase variant, but the value is identical
# to SERVE_APP_NAME — confirm this is intended.
SERVE_APP_NAME_HUMP = "dbgpt_datasource"
# Prefix passed to ServeConfig.from_app_config when loading the configuration.
SERVE_CONFIG_KEY_PREFIX = "dbgpt_datasource"
# Component name under which the datasource Service is looked up.
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
@dataclass
class ServeConfig(BaseServeConfig):
    """Configuration options of the datasource serve application."""

    # Comma-separated API keys; when unset, requests are not authenticated.
    api_keys: Optional[str] = field(
        default=None,
        metadata=dict(help="API keys for the endpoint, if None, allow all"),
    )
    default_user: Optional[str] = field(
        default=None,
        metadata=dict(help="Default user name for prompt"),
    )
    default_sys_code: Optional[str] = field(
        default=None,
        metadata=dict(help="Default system code for prompt"),
    )

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

View File

@@ -0,0 +1,60 @@
import logging
from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
SERVE_CONFIG_KEY_PREFIX,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Serve(BaseServe):
    """Serve component that mounts the datasource API on the application."""

    name = SERVE_APP_NAME

    def __init__(
        self,
        system_app: SystemApp,
        api_prefix: Optional[str] = "/api/v2/serve",
        api_tags: Optional[List[str]] = None,
        db_url_or_db: Union[str, URL, DatabaseManager] = None,
        try_create_tables: Optional[bool] = False,
    ):
        """Create the component.

        Args:
            system_app (SystemApp): The application this component belongs to
            api_prefix (Optional[str]): URL prefix of the mounted router
            api_tags (Optional[List[str]]): API tags of the mounted router;
                defaults to ``[SERVE_APP_NAME_HUMP]``
            db_url_or_db: Database URL or manager forwarded to the base class
            try_create_tables (Optional[bool]): Forwarded to the base class
        """
        tags = [SERVE_APP_NAME_HUMP] if api_tags is None else api_tags
        super().__init__(
            system_app, api_prefix, tags, db_url_or_db, try_create_tables
        )
        # Filled in by before_start().
        self._db_manager: Optional[DatabaseManager] = None

    def init_app(self, system_app: SystemApp):
        """Mount the router and initialize the endpoints; runs only once."""
        if not self._app_has_initiated:
            self._system_app = system_app
            self._system_app.app.include_router(
                router, prefix=self._api_prefix, tags=self._api_tags
            )
            init_endpoints(self._system_app)
            self._app_has_initiated = True

    def on_init(self):
        """Called when init the application.

        You can do some initialization here. You can't get other components
        here because they may be not initialized yet
        """

    def before_start(self):
        """Called before the start of the application."""
        # TODO: Your code here
        self._db_manager = self.create_or_get_db_manager()

View File

@@ -0,0 +1,181 @@
import logging
from typing import List, Optional
from fastapi import HTTPException
from dbgpt._private.config import Config
from dbgpt.component import ComponentType, SystemApp
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.datasource.db_conn_info import DBConfig
from dbgpt.datasource.manages.connect_config_db import (
ConnectConfigDao,
ConnectConfigEntity,
)
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.schema import DBType
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.executor_utils import ExecutorFactory
from ..api.schemas import DatasourceServeRequest, DatasourceServeResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): presumably the shared DB-GPT configuration object; this module
# reads CFG.local_db_manager and CFG.VECTOR_STORE_TYPE from it.
CFG = Config()
class Service(
    BaseService[ConnectConfigEntity, DatasourceServeRequest, DatasourceServeResponse]
):
    """Service layer for managing datasources (database connections)."""

    name = SERVE_SERVICE_COMPONENT_NAME

    def __init__(
        self,
        system_app: SystemApp,
        dao: Optional[ConnectConfigDao] = None,
    ):
        """Create the service.

        Args:
            system_app (SystemApp): The system app
            dao (Optional[ConnectConfigDao]): DAO to use; when omitted a
                ``ConnectConfigDao`` is created in ``init_app``
        """
        # Attributes are assigned before calling the base constructor.
        # NOTE(review): presumably because the base class may call init_app()
        # from its constructor — confirm against BaseService.
        self._system_app = None
        self._dao: Optional[ConnectConfigDao] = dao
        self._dag_manager: Optional[DAGManager] = None
        self._db_summary_client = None
        self._vector_connector = None

        super().__init__(system_app)

    def init_app(self, system_app: SystemApp) -> None:
        """Initialize the service

        Args:
            system_app (SystemApp): The system app
        """
        self._serve_config = ServeConfig.from_app_config(
            system_app.config, SERVE_CONFIG_KEY_PREFIX
        )
        self._dao = self._dao or ConnectConfigDao()
        self._system_app = system_app

    def before_start(self):
        """Execute before the application starts"""
        # Imported here rather than at module level, as in the original code.
        from dbgpt.rag.summary.db_summary_client import DBSummaryClient

        self._db_summary_client = DBSummaryClient(self._system_app)

    def after_start(self):
        """Execute after the application starts"""

    @property
    def dao(
        self,
    ) -> BaseDao[ConnectConfigEntity, DatasourceServeRequest, DatasourceServeResponse]:
        """Returns the internal DAO."""
        return self._dao

    @property
    def config(self) -> ServeConfig:
        """Returns the internal ServeConfig."""
        return self._serve_config

    def create(self, request: DatasourceServeRequest) -> DatasourceServeResponse:
        """Create a new Datasource entity

        After the record is stored, a database-summary embedding job is
        submitted to the default executor.

        Args:
            request (DatasourceServeRequest): The request

        Returns:
            DatasourceServeResponse: The response

        Raises:
            HTTPException: 400 when the name already exists or the database
                type is not supported
            ValueError: When storing the datasource or scheduling the
                embedding job fails
        """
        datasource = self._dao.get_by_names(request.db_name)
        if datasource:
            raise HTTPException(
                status_code=400,
                detail=f"datasource name:{request.db_name} already exists",
            )
        try:
            db_type = DBType.of_db_type(request.db_type)
            if not db_type:
                raise HTTPException(
                    status_code=400, detail=f"Unsupported Db Type, {request.db_type}"
                )
            res = self._dao.create(request)

            # async embedding
            executor = self._system_app.get_component(
                ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
            ).create()  # type: ignore
            executor.submit(
                self._db_summary_client.db_summary_embedding,
                request.db_name,
                request.db_type,
            )
        except HTTPException:
            # Keep the 400 raised above instead of wrapping it in ValueError.
            raise
        except Exception as e:
            raise ValueError("Add db connect info error!" + str(e)) from e
        return res

    def update(self, request: DatasourceServeRequest) -> DatasourceServeResponse:
        """Update an existing Datasource entity

        Args:
            request (DatasourceServeRequest): The request

        Returns:
            DatasourceServeResponse: The response

        Raises:
            HTTPException: 400 when the datasource does not exist or the
                update fails
        """
        datasources = self._dao.get_by_names(request.db_name)
        if datasources is None:
            raise HTTPException(
                status_code=400,
                detail=f"there is no datasource name:{request.db_name} exists",
            )
        db_config = DBConfig(**request.dict())
        # The edit goes through the local db manager rather than the DAO.
        if CFG.local_db_manager.edit_db(db_config):
            return DatasourceServeResponse(**db_config.dict())
        else:
            raise HTTPException(
                status_code=400,
                detail=f"update datasource name:{request.db_name} failed",
            )

    def get(self, datasource_id: str) -> Optional[DatasourceServeResponse]:
        """Get a Datasource entity by id

        Args:
            datasource_id (str): The datasource_id

        Returns:
            Optional[DatasourceServeResponse]: The value returned by the DAO
                lookup for that id
        """
        return self._dao.get_one({"id": datasource_id})

    def delete(self, datasource_id: str) -> Optional[DatasourceServeResponse]:
        """Delete a Datasource entity and its profile vector store

        Args:
            datasource_id (str): The datasource_id

        Returns:
            Optional[DatasourceServeResponse]: The data after deletion, or
                ``None`` when no datasource with that id exists
        """
        db_config = self._dao.get_one({"id": datasource_id})
        if db_config is None:
            # Nothing to delete; bail out before touching db_config.db_name.
            return None

        vector_name = db_config.db_name + "_profile"
        vector_store_config = VectorStoreConfig(name=vector_name)
        self._vector_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=vector_store_config,
        )
        self._vector_connector.delete_vector_name(vector_name)
        self._dao.delete({"id": datasource_id})
        return db_config

    def list(self) -> List[DatasourceServeResponse]:
        """List the Datasource entities.

        Returns:
            List[DatasourceServeResponse]: The list of responses
        """
        db_list = CFG.local_db_manager.get_db_list()
        return [DatasourceServeResponse(**db) for db in db_list]

View File