feat: Client support chatdata (#1343)

This commit is contained in:
Aries-ckt
2024-03-28 09:04:28 +08:00
committed by GitHub
parent f144fc3d36
commit dffd235bfb
21 changed files with 1089 additions and 12 deletions

View File

View File

@@ -0,0 +1,193 @@
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.serve.datasource.api.schemas import (
DatasourceServeRequest,
DatasourceServeResponse,
)
from dbgpt.serve.datasource.config import SERVE_SERVICE_COMPONENT_NAME
from dbgpt.serve.datasource.service.service import Service
from dbgpt.util import PaginationResult
router = APIRouter()
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key}
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health", dependencies=[Depends(check_api_key)])
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post("/datasources", dependencies=[Depends(check_api_key)])
async def create(
request: DatasourceServeRequest, service: Service = Depends(get_service)
) -> Result:
"""Create a new Space entity
Args:
request (DatasourceServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.create(request))
@router.put("/datasources", dependencies=[Depends(check_api_key)])
async def update(
request: DatasourceServeRequest, service: Service = Depends(get_service)
) -> Result:
"""Update a Space entity
Args:
request (DatasourceServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
@router.delete(
"/datasources/{datasource_id}",
response_model=Result[None],
dependencies=[Depends(check_api_key)],
)
async def delete(
datasource_id: str, service: Service = Depends(get_service)
) -> Result[None]:
"""Delete a Space entity
Args:
request (DatasourceServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.delete(datasource_id))
@router.get(
"/datasources/{datasource_id}",
dependencies=[Depends(check_api_key)],
response_model=Result[List],
)
async def query(
datasource_id: str, service: Service = Depends(get_service)
) -> Result[List[DatasourceServeResponse]]:
"""Query Space entities
Args:
request (DatasourceServeRequest): The request
service (Service): The service
Returns:
List[ServeResponse]: The response
"""
return Result.succ(service.get(datasource_id))
@router.get(
"/datasources",
dependencies=[Depends(check_api_key)],
response_model=Result[PaginationResult[DatasourceServeResponse]],
)
async def query_page(
page: int = Query(default=1, description="current page"),
page_size: int = Query(default=20, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[DatasourceServeResponse]]:
"""Query Space entities
Args:
page (int): The page number
page_size (int): The page size
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.list())
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
system_app.register(Service)
global_system_app = system_app

View File

@@ -0,0 +1,41 @@
from typing import Optional
from pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class DatasourceServeRequest(BaseModel):
"""name: knowledge space name"""
"""vector_type: vector type"""
id: Optional[int] = Field(None, description="The datasource id")
db_type: str = Field(..., description="Database type, e.g. sqlite, mysql, etc.")
db_name: str = Field(..., description="Database name.")
db_path: str = Field("", description="File path for file-based database.")
db_host: str = Field("", description="Database host.")
db_port: int = Field(0, description="Database port.")
db_user: str = Field("", description="Database user.")
db_pwd: str = Field("", description="Database password.")
comment: str = Field("", description="Comment for the database.")
class DatasourceServeResponse(BaseModel):
"""Flow response model"""
"""name: knowledge space name"""
"""vector_type: vector type"""
id: int = Field(None, description="The datasource id")
db_type: str = Field(..., description="Database type, e.g. sqlite, mysql, etc.")
db_name: str = Field(..., description="Database name.")
db_path: str = Field("", description="File path for file-based database.")
db_host: str = Field("", description="Database host.")
db_port: int = Field(0, description="Database port.")
db_user: str = Field("", description="Database user.")
db_pwd: str = Field("", description="Database password.")
comment: str = Field("", description="Comment for the database.")
# TODO define your own fields here
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"