mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 21:21:08 +00:00
feat: Add dbgpt client and add api v2
This commit is contained in:
65
dbgpt/serve/agent/app/endpoints.py
Normal file
65
dbgpt/serve/agent/app/endpoints.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
|
||||
from dbgpt.serve.agent.db.gpts_app import (
|
||||
GptsApp,
|
||||
GptsAppCollectionDao,
|
||||
GptsAppDao,
|
||||
GptsAppQuery,
|
||||
)
|
||||
from dbgpt.serve.core import Result
|
||||
|
||||
router = APIRouter()
|
||||
gpts_dao = GptsAppDao()
|
||||
collection_dao = GptsAppCollectionDao()
|
||||
|
||||
|
||||
@router.get("/v2/serve/apps")
async def app_list(
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    is_collected: Optional[str] = Query(
        default=None, description="whether the app is collected"
    ),
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
):
    """List GPTs apps with pagination.

    NOTE(review): ``user_name`` and ``sys_code`` are accepted but never passed
    into ``GptsAppQuery`` — confirm whether filtering by them is intended.
    """
    try:
        query = GptsAppQuery(
            page_no=page, page_size=page_size, is_collected=is_collected
        )
        return Result.succ(gpts_dao.app_list(query, True))
    except Exception as ex:
        return Result.failed(err_code="E000X", msg=f"query app error: {ex}")
|
||||
|
||||
|
||||
@router.get("/v2/serve/apps/{app_id}")
async def app_detail(app_id: str):
    """Return the detail of a single GPTs app by its id."""
    try:
        detail = gpts_dao.app_detail(app_id)
        return Result.succ(detail)
    except Exception as ex:
        return Result.failed(err_code="E000X", msg=f"query app error: {ex}")
|
||||
|
||||
|
||||
@router.put("/v2/serve/apps/{app_id}")
async def app_update(app_id: str, gpts_app: GptsApp):
    """Update an existing GPTs app with the data in the request body.

    NOTE(review): the ``app_id`` path parameter is not used — the target app
    is identified solely by the request body passed to ``gpts_dao.edit``.
    Confirm the body id is authoritative, or enforce consistency here.
    """
    try:
        return Result.succ(gpts_dao.edit(gpts_app))
    except Exception as ex:
        return Result.failed(err_code="E000X", msg=f"edit app error: {ex}")
|
||||
|
||||
|
||||
@router.post("/v2/serve/apps")
async def app_create(gpts_app: GptsApp):
    """Create a new GPTs app from the request body."""
    try:
        return Result.succ(gpts_dao.create(gpts_app))
    except Exception as ex:
        # Error message previously said "edit app error" (copy-paste bug).
        return Result.failed(err_code="E000X", msg=f"create app error: {ex}")
|
||||
|
||||
|
||||
@router.delete("/v2/serve/apps/{app_id}")
async def app_delete(
    app_id: str,
    # Without an explicit default, FastAPI treats an Optional[str] parameter
    # as a *required* query parameter; give both an optional default.
    user_code: Optional[str] = Query(default=None, description="user code"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
):
    """Delete a GPTs app by its id, optionally scoped to a user/system."""
    try:
        gpts_dao.delete(app_id, user_code, sys_code)
        return Result.succ([])
    except Exception as ex:
        return Result.failed(err_code="E000X", msg=f"delete app error: {ex}")
|
@@ -2,7 +2,7 @@ import uuid
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
@@ -45,6 +45,7 @@ def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
request: Request = None,
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
@@ -63,6 +64,9 @@ async def check_api_key(
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if request.url.path.startswith(f"/api/v1"):
|
||||
return None
|
||||
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
|
@@ -16,7 +16,16 @@ class BaseServeConfig(BaseParameters):
|
||||
config (AppConfig): Application configuration
|
||||
config_prefix (str): Configuration prefix
|
||||
"""
|
||||
global_prefix = "dbgpt.app.global."
|
||||
global_dict = config.get_all_by_prefix(global_prefix)
|
||||
config_dict = config.get_all_by_prefix(config_prefix)
|
||||
# remove prefix
|
||||
config_dict = {k[len(config_prefix) :]: v for k, v in config_dict.items()}
|
||||
config_dict = {
|
||||
k[len(config_prefix) :]: v
|
||||
for k, v in config_dict.items()
|
||||
if k.startswith(config_prefix)
|
||||
}
|
||||
for k, v in global_dict.items():
|
||||
if k not in config_dict and k[len(global_prefix) :] in cls().__dict__:
|
||||
config_dict[k[len(global_prefix) :]] = v
|
||||
return cls(**config_dict)
|
||||
|
@@ -69,11 +69,11 @@ async def validation_exception_handler(
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
res = Result.failed(
|
||||
msg=exc.detail,
|
||||
err_code="E0002",
|
||||
msg=str(exc.detail),
|
||||
err_code=str(exc.status_code),
|
||||
)
|
||||
logger.error(f"http_exception_handler catch HTTPException: {res}")
|
||||
return JSONResponse(status_code=400, content=res.dict())
|
||||
return JSONResponse(status_code=exc.status_code, content=res.dict())
|
||||
|
||||
|
||||
async def common_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
|
@@ -18,7 +18,7 @@ class BaseServe(BaseComponent, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp,
|
||||
api_prefix: str,
|
||||
api_prefix: str | List[str],
|
||||
api_tags: List[str],
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from functools import cache
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
@@ -9,7 +9,7 @@ from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
||||
from dbgpt.serve.core import Result
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
|
||||
@@ -45,6 +45,7 @@ def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
request: Request = None,
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
@@ -63,6 +64,10 @@ async def check_api_key(
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if request.url.path.startswith(f"/api/v1"):
|
||||
return None
|
||||
|
||||
# for api_version in serve.serve_versions():
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
|
@@ -27,11 +27,13 @@ class Serve(BaseServe):
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp,
|
||||
api_prefix: Optional[str] = f"/api/v1/serve/awel",
|
||||
api_prefix: Optional[List[str]] = None,
|
||||
api_tags: Optional[List[str]] = None,
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
):
|
||||
if api_prefix is None:
|
||||
api_prefix = [f"/api/v1/serve/awel", "/api/v2/serve/awel"]
|
||||
if api_tags is None:
|
||||
api_tags = [SERVE_APP_NAME_HUMP]
|
||||
super().__init__(
|
||||
@@ -43,9 +45,10 @@ class Serve(BaseServe):
|
||||
if self._app_has_initiated:
|
||||
return
|
||||
self._system_app = system_app
|
||||
self._system_app.app.include_router(
|
||||
router, prefix=self._api_prefix, tags=self._api_tags
|
||||
)
|
||||
for prefix in self._api_prefix:
|
||||
self._system_app.app.include_router(
|
||||
router, prefix=prefix, tags=self._api_tags
|
||||
)
|
||||
init_endpoints(self._system_app)
|
||||
self._app_has_initiated = True
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
@@ -44,6 +44,7 @@ def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
request: Request = None,
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
@@ -62,6 +63,9 @@ async def check_api_key(
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if request.url.path.startswith(f"/api/v1"):
|
||||
return None
|
||||
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
|
300
dbgpt/serve/rag/api/endpoints.py
Normal file
300
dbgpt/serve/rag/api/endpoints.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result
|
||||
from dbgpt.serve.rag.api.schemas import (
|
||||
DocumentServeRequest,
|
||||
DocumentServeResponse,
|
||||
KnowledgeSyncRequest,
|
||||
SpaceServeRequest,
|
||||
SpaceServeResponse,
|
||||
)
|
||||
from dbgpt.serve.rag.config import SERVE_SERVICE_COMPONENT_NAME
|
||||
from dbgpt.serve.rag.service.service import Service
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Add your API endpoints here
|
||||
|
||||
global_system_app: Optional[SystemApp] = None
|
||||
|
||||
|
||||
def get_service() -> Service:
|
||||
"""Get the service instance"""
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@cache
|
||||
def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
"""Parse the string api keys to a list
|
||||
|
||||
Args:
|
||||
api_keys (str): The string api keys
|
||||
|
||||
Returns:
|
||||
List[str]: The list of api keys
|
||||
"""
|
||||
if not api_keys:
|
||||
return []
|
||||
return [key.strip() for key in api_keys.split(",")]
|
||||
|
||||
|
||||
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Check the api key

    If the api key is not set, allow all.

    Your can pass the token in you request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    """
    configured_keys = service.config.api_keys
    if not configured_keys:
        # api_keys not set; allow all
        return None
    valid_keys = _parse_api_keys(configured_keys)
    if auth is not None and auth.credentials in valid_keys:
        return auth.credentials
    # Missing or unknown bearer token -> reject with an OpenAI-style body.
    raise HTTPException(
        status_code=401,
        detail={
            "error": {
                "message": "",
                "type": "invalid_request_error",
                "param": None,
                "code": "invalid_api_key",
            }
        },
    )
|
||||
|
||||
|
||||
@router.get("/health", dependencies=[Depends(check_api_key)])
async def health():
    """Report service liveness."""
    return dict(status="ok")
|
||||
|
||||
|
||||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Endpoint to verify api-key authentication works."""
    return dict(status="ok")
|
||||
|
||||
|
||||
@router.post("/spaces", dependencies=[Depends(check_api_key)])
|
||||
async def create(
|
||||
request: SpaceServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result:
|
||||
"""Create a new Space entity
|
||||
|
||||
Args:
|
||||
request (SpaceServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.create_space(request))
|
||||
|
||||
|
||||
@router.put("/spaces", dependencies=[Depends(check_api_key)])
|
||||
async def update(
|
||||
request: SpaceServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result:
|
||||
"""Update a Space entity
|
||||
|
||||
Args:
|
||||
request (SpaceServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.update_space(request))
|
||||
|
||||
|
||||
@router.delete(
    "/spaces/{space_id}",
    response_model=Result[None],
    dependencies=[Depends(check_api_key)],
)
async def delete(
    space_id: str, service: Service = Depends(get_service)
) -> Result[None]:
    """Delete a Space entity by its id.

    Args:
        space_id (str): The space id
        service (Service): The service
    Returns:
        Result[None]: The response
    """
    deleted = service.delete(space_id)
    return Result.succ(deleted)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/spaces/{space_id}",
|
||||
dependencies=[Depends(check_api_key)],
|
||||
response_model=Result[List],
|
||||
)
|
||||
async def query(
|
||||
space_id: str, service: Service = Depends(get_service)
|
||||
) -> Result[List[SpaceServeResponse]]:
|
||||
"""Query Space entities
|
||||
|
||||
Args:
|
||||
request (SpaceServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
List[ServeResponse]: The response
|
||||
"""
|
||||
request = {"id": space_id}
|
||||
return Result.succ(service.get(request))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/spaces",
|
||||
dependencies=[Depends(check_api_key)],
|
||||
response_model=Result[PaginationResult[SpaceServeResponse]],
|
||||
)
|
||||
async def query_page(
|
||||
page: int = Query(default=1, description="current page"),
|
||||
page_size: int = Query(default=20, description="page size"),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result[PaginationResult[SpaceServeResponse]]:
|
||||
"""Query Space entities
|
||||
|
||||
Args:
|
||||
page (int): The page number
|
||||
page_size (int): The page size
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.get_list_by_page({}, page, page_size))
|
||||
|
||||
|
||||
@router.post("/documents", dependencies=[Depends(check_api_key)])
async def create_document(
    doc_name: str = Form(...),
    doc_type: str = Form(...),
    space_id: str = Form(...),
    content: Optional[str] = Form(None),
    doc_file: Optional[UploadFile] = File(None),
    service: Service = Depends(get_service),
) -> Result:
    """Create a new Document entity in a knowledge space.

    Args:
        doc_name (str): The document name
        doc_type (str): The document type
        space_id (str): Id of the owning knowledge space
        content (Optional[str]): Inline content (for non-file documents)
        doc_file (Optional[UploadFile]): Uploaded file (for file documents)
        service (Service): The service
    Returns:
        Result: The id of the created document wrapped in a Result
    """
    # Repackage the multipart form fields into the serve request model.
    request = DocumentServeRequest(
        doc_name=doc_name,
        doc_type=doc_type,
        content=content,
        doc_file=doc_file,
        space_id=space_id,
    )
    return Result.succ(await service.create_document(request))
|
||||
|
||||
|
||||
@router.get(
    "/documents/{document_id}",
    dependencies=[Depends(check_api_key)],
    response_model=Result[List],
)
async def query_document(
    document_id: str, service: Service = Depends(get_service)
) -> Result[List[DocumentServeResponse]]:
    """Query a Document entity by its id.

    Renamed from ``query`` to stop shadowing the space-level ``query``
    endpoint function defined earlier in this module.

    Args:
        document_id (str): The document id
        service (Service): The service
    Returns:
        List[DocumentServeResponse]: The response
    """
    request = {"id": document_id}
    return Result.succ(service.get_document(request))
|
||||
|
||||
|
||||
@router.get(
    "/documents",
    dependencies=[Depends(check_api_key)],
    # Was Result[PaginationResult[SpaceServeResponse]] — a copy-paste from the
    # spaces listing; the endpoint returns document records.
    response_model=Result[PaginationResult[DocumentServeResponse]],
)
async def query_document_page(
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
    service: Service = Depends(get_service),
) -> Result[PaginationResult[DocumentServeResponse]]:
    """Query Document entities with pagination.

    Renamed from ``query_page`` to stop shadowing the space-level
    ``query_page`` endpoint function defined earlier in this module.

    Args:
        page (int): The page number
        page_size (int): The page size
        service (Service): The service
    Returns:
        ServerResponse: The response
    """
    return Result.succ(service.get_document_list({}, page, page_size))
|
||||
|
||||
|
||||
@router.post("/documents/sync", dependencies=[Depends(check_api_key)])
|
||||
async def sync_documents(
|
||||
requests: List[KnowledgeSyncRequest], service: Service = Depends(get_service)
|
||||
) -> Result:
|
||||
"""Create a new Document entity
|
||||
|
||||
Args:
|
||||
request (SpaceServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.sync_document(requests))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/documents/{document_id}",
|
||||
dependencies=[Depends(check_api_key)],
|
||||
response_model=Result[None],
|
||||
)
|
||||
async def delete_document(
|
||||
document_id: str, service: Service = Depends(get_service)
|
||||
) -> Result[None]:
|
||||
"""Delete a Space entity
|
||||
|
||||
Args:
|
||||
request (SpaceServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.delete_document(document_id))
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp) -> None:
    """Initialize the endpoints

    Registers the rag Service component and stores the system app so the
    module-level ``get_service`` dependency can resolve it later.
    """
    global global_system_app
    system_app.register(Service)
    global_system_app = system_app
|
93
dbgpt/serve/rag/api/schemas.py
Normal file
93
dbgpt/serve/rag/api/schemas.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import File, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
class SpaceServeRequest(BaseModel):
    """Request model for creating or updating a knowledge space."""

    # Bare-string "attribute docstrings" replaced with comments: in Python
    # they are plain statements, not documentation.
    # id: space primary key (None when creating)
    id: Optional[int] = Field(None, description="The space id")
    # name: knowledge space name
    name: str = Field(None, description="The space name")
    # vector_type: vector store type
    vector_type: str = Field(None, description="The vector type")
    # desc: space description
    desc: str = Field(None, description="The description")
    # owner: space owner
    owner: str = Field(None, description="The owner")
|
||||
|
||||
|
||||
class DocumentServeRequest(BaseModel):
    """Request model for creating a knowledge document."""

    # id: document primary key (None when creating), hence Optional
    id: Optional[int] = Field(None, description="The doc id")
    # doc_name: document name
    doc_name: str = Field(None, description="doc name")
    # doc_type: document type
    doc_type: str = Field(None, description="The doc type")
    # content: inline content or resolved file path; the create endpoint
    # passes Form(None), so this must be optional
    content: Optional[str] = Field(None, description="content")
    # doc_file: uploaded file. Was `UploadFile = File(...)` (required), which
    # conflicted with the endpoint passing File(None) for non-file documents.
    doc_file: Optional[UploadFile] = None
    # doc_source: document source
    doc_source: Optional[str] = None
    # space_id: id of the owning knowledge space
    space_id: Optional[str] = None
|
||||
|
||||
|
||||
class DocumentServeResponse(BaseModel):
    """Response model for a knowledge document."""

    # id: document primary key
    id: int = Field(None, description="The doc id")
    # doc_name: document name (description previously said "doc type")
    doc_name: str = Field(None, description="doc name")
    # doc_type: document type (description previously said "The doc content")
    doc_type: str = Field(None, description="The doc type")
    # content: document content
    content: str = Field(None, description="content")
    # vector_ids: ids of the vectors built from this document
    vector_ids: str = Field(None, description="vector ids")
    # doc_source: document source
    doc_source: str = None
    # space: owning knowledge space name
    space: str = None
|
||||
|
||||
|
||||
class KnowledgeSyncRequest(BaseModel):
    """Request model for syncing a document into the vector store."""

    # doc_id: id of the document to sync (old comment said "doc ids", plural,
    # but the field is a single id)
    doc_id: int = Field(None, description="The doc id")
    # space_id: id of the owning knowledge space
    space_id: str = Field(None, description="space id")
    # model_name: embedding model name
    model_name: Optional[str] = Field(None, description="model name")
    # chunk_parameters: how the document is split into chunks
    chunk_parameters: ChunkParameters = Field(None, description="chunk parameters")
|
||||
|
||||
|
||||
class SpaceServeResponse(BaseModel):
    """Response model for a knowledge space."""

    # id: space primary key
    id: int = Field(None, description="The space id")
    # name: knowledge space name
    name: str = Field(None, description="The space name")
    # vector_type: vector store type
    vector_type: str = Field(None, description="The vector type")
    # desc: space description
    desc: str = Field(None, description="The description")
    # context: argument context
    context: str = Field(None, description="The context")
    # owner: space owner
    owner: str = Field(None, description="The owner")
    # sys_code: system code
    sys_code: str = Field(None, description="The sys code")

    # TODO define your own fields here
    class Config:
        title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
|
28
dbgpt/serve/rag/config.py
Normal file
28
dbgpt/serve/rag/config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.serve.core import BaseServeConfig
|
||||
|
||||
APP_NAME = "rag"
|
||||
SERVE_APP_NAME = "dbgpt_rag"
|
||||
SERVE_APP_NAME_HUMP = "dbgpt_rag"
|
||||
SERVE_CONFIG_KEY_PREFIX = "dbgpt_rag"
|
||||
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
|
||||
|
||||
|
||||
@dataclass
class ServeConfig(BaseServeConfig):
    """Parameters for the serve command"""

    # Comma-separated api keys; None means authentication is disabled.
    api_keys: Optional[str] = field(
        default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
    )

    # NOTE(review): the "for prompt" help texts below look copy-pasted from
    # the prompt serve module — confirm the wording for the rag serve.
    default_user: Optional[str] = field(
        default=None,
        metadata={"help": "Default user name for prompt"},
    )
    default_sys_code: Optional[str] = field(
        default=None,
        metadata={"help": "Default system code for prompt"},
    )
|
1
dbgpt/serve/rag/dependencies.py
Normal file
1
dbgpt/serve/rag/dependencies.py
Normal file
@@ -0,0 +1 @@
|
||||
# Define your dependencies here
|
62
dbgpt/serve/rag/serve.py
Normal file
62
dbgpt/serve/rag/serve.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import BaseServe
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
|
||||
from .api.endpoints import init_endpoints, router
|
||||
from .config import (
|
||||
APP_NAME,
|
||||
SERVE_APP_NAME,
|
||||
SERVE_APP_NAME_HUMP,
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Serve(BaseServe):
    """Serve component for DB-GPT"""

    name = SERVE_APP_NAME

    def __init__(
        self,
        system_app: SystemApp,
        # f-prefix removed: the default had no placeholders.
        api_prefix: Optional[str] = "/api/v2/serve/knowledge",
        api_tags: Optional[List[str]] = None,
        db_url_or_db: Union[str, URL, DatabaseManager] = None,
        try_create_tables: Optional[bool] = False,
    ):
        """Create the rag serve component.

        Args:
            system_app (SystemApp): The system app
            api_prefix (Optional[str]): Prefix under which the router is mounted
            api_tags (Optional[List[str]]): OpenAPI tags for the endpoints
            db_url_or_db (Union[str, URL, DatabaseManager]): Database url or
                an existing DatabaseManager
            try_create_tables (Optional[bool]): Whether to try creating tables
        """
        if api_tags is None:
            api_tags = [SERVE_APP_NAME_HUMP]
        super().__init__(
            system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
        )
        self._db_manager: Optional[DatabaseManager] = None

    def init_app(self, system_app: SystemApp):
        """Mount the router and initialize the endpoints (idempotent)."""
        if self._app_has_initiated:
            return
        self._system_app = system_app
        self._system_app.app.include_router(
            router, prefix=self._api_prefix, tags=self._api_tags
        )
        init_endpoints(self._system_app)
        self._app_has_initiated = True

    def on_init(self):
        """Called when init the application.

        You can do some initialization here. You can't get other components here because they may be not initialized yet
        """
        # import your own module here to ensure the module is loaded before the application starts
        from .models.models import KnowledgeSpaceEntity  # noqa: F401

    def before_start(self):
        """Called before the start of the application."""
        # TODO: Your code here
        self._db_manager = self.create_or_get_db_manager()
|
0
dbgpt/serve/rag/service/__init__.py
Normal file
0
dbgpt/serve/rag/service/__init__.py
Normal file
522
dbgpt/serve/rag/service/service.py
Normal file
522
dbgpt/serve/rag/service/service.py
Normal file
@@ -0,0 +1,522 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
|
||||
from dbgpt.app.knowledge.document_db import (
|
||||
KnowledgeDocumentDao,
|
||||
KnowledgeDocumentEntity,
|
||||
)
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.configs.model_config import (
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
)
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag.embedding import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge import ChunkStrategy, KnowledgeFactory, KnowledgeType
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.util.dbgpts.loader import DBGPTsLoader
|
||||
from dbgpt.util.executor_utils import ExecutorFactory
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
|
||||
from ..api.schemas import (
|
||||
DocumentServeRequest,
|
||||
DocumentServeResponse,
|
||||
KnowledgeSyncRequest,
|
||||
SpaceServeRequest,
|
||||
SpaceServeResponse,
|
||||
)
|
||||
from ..assembler.embedding import EmbeddingAssembler
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class SyncStatus(Enum):
|
||||
TODO = "TODO"
|
||||
FAILED = "FAILED"
|
||||
RUNNING = "RUNNING"
|
||||
FINISHED = "FINISHED"
|
||||
|
||||
|
||||
class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeResponse]):
|
||||
"""The service class for Flow"""
|
||||
|
||||
name = SERVE_SERVICE_COMPONENT_NAME
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp,
|
||||
dao: Optional[KnowledgeSpaceDao] = None,
|
||||
document_dao: Optional[KnowledgeDocumentDao] = None,
|
||||
chunk_dao: Optional[DocumentChunkDao] = None,
|
||||
):
|
||||
self._system_app = None
|
||||
self._dao: KnowledgeSpaceDao = dao
|
||||
self._document_dao: KnowledgeDocumentDao = document_dao
|
||||
self._chunk_dao: DocumentChunkDao = chunk_dao
|
||||
self._dag_manager: Optional[DAGManager] = None
|
||||
self._dbgpts_loader: Optional[DBGPTsLoader] = None
|
||||
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp) -> None:
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._dao = self._dao or KnowledgeSpaceDao()
|
||||
self._document_dao = self._document_dao or KnowledgeDocumentDao()
|
||||
self._chunk_dao = self._chunk_dao or DocumentChunkDao()
|
||||
self._system_app = system_app
|
||||
|
||||
def before_start(self):
|
||||
"""Execute before the application starts"""
|
||||
|
||||
def after_start(self):
|
||||
"""Execute after the application starts"""
|
||||
|
||||
@property
|
||||
def dao(
|
||||
self,
|
||||
) -> BaseDao[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeResponse]:
|
||||
"""Returns the internal DAO."""
|
||||
return self._dao
|
||||
|
||||
@property
|
||||
def config(self) -> ServeConfig:
|
||||
"""Returns the internal ServeConfig."""
|
||||
return self._serve_config
|
||||
|
||||
def create_space(self, request: SpaceServeRequest) -> SpaceServeResponse:
|
||||
"""Create a new Space entity
|
||||
|
||||
Args:
|
||||
request (KnowledgeSpaceRequest): The request
|
||||
|
||||
Returns:
|
||||
SpaceServeResponse: The response
|
||||
"""
|
||||
space = self.get(request)
|
||||
if space is not None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"space name:{request.name} have already named",
|
||||
)
|
||||
return self._dao.create_knowledge_space(request)
|
||||
|
||||
def update_space(self, request: SpaceServeRequest) -> SpaceServeResponse:
|
||||
"""Create a new Space entity
|
||||
|
||||
Args:
|
||||
request (KnowledgeSpaceRequest): The request
|
||||
|
||||
Returns:
|
||||
SpaceServeResponse: The response
|
||||
"""
|
||||
spaces = self._dao.get_knowledge_space(
|
||||
KnowledgeSpaceEntity(id=request.id, name=request.name)
|
||||
)
|
||||
if len(spaces) == 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"no space name named {request.name}",
|
||||
)
|
||||
space = spaces[0]
|
||||
query_request = {"id": space.id}
|
||||
update_obj = self._dao.update(query_request, update_request=request)
|
||||
return update_obj
|
||||
|
||||
async def create_document(
|
||||
self, request: DocumentServeRequest
|
||||
) -> SpaceServeResponse:
|
||||
"""Create a new document entity
|
||||
|
||||
Args:
|
||||
request (KnowledgeSpaceRequest): The request
|
||||
|
||||
Returns:
|
||||
SpaceServeResponse: The response
|
||||
"""
|
||||
space = self.get({"id": request.space_id})
|
||||
if space is None:
|
||||
raise Exception(f"space id:{request.space_id} not found")
|
||||
query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space.name)
|
||||
documents = self._document_dao.get_knowledge_documents(query)
|
||||
if len(documents) > 0:
|
||||
raise Exception(f"document name:{request.doc_name} have already named")
|
||||
if request.doc_file and request.doc_type == KnowledgeType.DOCUMENT.name:
|
||||
doc_file = request.doc_file
|
||||
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name)):
|
||||
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name))
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(
|
||||
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name)
|
||||
)
|
||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||
tmp.write(await request.doc_file.read())
|
||||
shutil.move(
|
||||
tmp_path,
|
||||
os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name, doc_file.filename),
|
||||
)
|
||||
request.content = os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, space.name, doc_file.filename
|
||||
)
|
||||
document = KnowledgeDocumentEntity(
|
||||
doc_name=request.doc_name,
|
||||
doc_type=request.doc_type,
|
||||
space=space.name,
|
||||
chunk_size=0,
|
||||
status=SyncStatus.TODO.name,
|
||||
last_sync=datetime.now(),
|
||||
content=request.content,
|
||||
result="",
|
||||
)
|
||||
doc_id = self._document_dao.create_knowledge_document(document)
|
||||
if doc_id is None:
|
||||
raise Exception(f"create document failed, {request.doc_name}")
|
||||
return doc_id
|
||||
|
||||
def sync_document(self, requests: List[KnowledgeSyncRequest]) -> List:
    """Sync a batch of documents into the vector store.

    Each request is validated (the document must exist and must not already
    be RUNNING or FINISHED), its chunk parameters are filled in from the
    space context (or global config) when the strategy is not
    CHUNK_BY_SIZE, and the document sync is kicked off.

    Args:
        requests (List[KnowledgeSyncRequest]): The sync requests.

    Returns:
        List: The ids of the documents scheduled for sync.

    Raises:
        Exception: If a document is missing or already syncing/synced.
    """
    doc_ids = []
    for sync_request in requests:
        space_id = sync_request.space_id
        docs = self._document_dao.documents_by_ids([sync_request.doc_id])
        if len(docs) == 0:
            raise Exception(
                f"there is no document for doc_id: {sync_request.doc_id}"
            )
        doc = docs[0]
        # A document that is mid-sync or already synced must not be re-queued.
        if (
            doc.status == SyncStatus.RUNNING.name
            or doc.status == SyncStatus.FINISHED.name
        ):
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )
        chunk_parameters = sync_request.chunk_parameters
        if chunk_parameters.chunk_strategy != ChunkStrategy.CHUNK_BY_SIZE.name:
            # Fall back to the space-level (or global default) chunk config.
            space_context = self.get_space_context(space_id)
            chunk_parameters.chunk_size = (
                CFG.KNOWLEDGE_CHUNK_SIZE
                if space_context is None
                else int(space_context["embedding"]["chunk_size"])
            )
            chunk_parameters.chunk_overlap = (
                CFG.KNOWLEDGE_CHUNK_OVERLAP
                if space_context is None
                else int(space_context["embedding"]["chunk_overlap"])
            )
        self._sync_knowledge_document(space_id, doc, chunk_parameters)
        doc_ids.append(doc.id)
    return doc_ids
|
||||
|
||||
def get(self, request: QUERY_SPEC) -> Optional[SpaceServeResponse]:
    """Fetch a single knowledge-space entity matching *request*.

    Args:
        request (QUERY_SPEC): The query specification used to locate the
            space (e.g. ``{"id": space_id}``).

    Returns:
        Optional[SpaceServeResponse]: The matching space, or None if no
        space matches.
    """
    return self._dao.get_one(request)
|
||||
|
||||
def get_document(self, request: QUERY_SPEC) -> Optional[SpaceServeResponse]:
    """Fetch a single knowledge-document entity matching *request*.

    Args:
        request (QUERY_SPEC): The query specification used to locate the
            document (e.g. ``{"id": document_id}``).

    Returns:
        Optional[SpaceServeResponse]: The matching document, or None if no
        document matches.
    """
    return self._document_dao.get_one(request)
|
||||
|
||||
def delete(self, space_id: str) -> Optional[SpaceServeResponse]:
    """Delete a knowledge space together with its vectors, chunks and documents.

    Args:
        space_id (str): The id of the space to delete.

    Returns:
        Optional[SpaceServeResponse]: The deleted space entity.

    Raises:
        HTTPException: 400 if no space exists for *space_id*.
    """
    query_request = {"id": space_id}
    space = self.get(query_request)
    if space is None:
        raise HTTPException(status_code=400, detail=f"Space {space_id} not found")
    vector_config = VectorStoreConfig(name=space.name)
    connector = VectorStoreConnector(
        vector_store_type=CFG.VECTOR_STORE_TYPE,
        vector_store_config=vector_config,
    )
    # Drop the whole vector collection for this space first.
    connector.delete_vector_name(space.name)
    doc_query = KnowledgeDocumentEntity(space=space.name)
    # Then remove the chunk rows of every document in the space ...
    for document in self._document_dao.get_documents(doc_query):
        self._chunk_dao.raw_delete(document.id)
    # ... the document rows themselves ...
    self._document_dao.raw_delete(doc_query)
    # ... and finally the space record.
    self._dao.delete(query_request)
    return space
|
||||
|
||||
def delete_document(self, document_id: str) -> Optional[DocumentServeResponse]:
    """Delete a single document together with its vectors and chunks.

    Args:
        document_id (str): The id of the document to delete.

    Returns:
        Optional[DocumentServeResponse]: The deleted document entity.

    Raises:
        Exception: If no document exists for *document_id*.
    """
    query_request = {"id": document_id}
    document = self._document_dao.get_one(query_request)
    if document is None:
        raise Exception(f"document {document_id} not found")
    vector_ids = document.vector_ids
    if vector_ids is not None:
        config = VectorStoreConfig(name=document.space)
        vector_store_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=config,
        )
        # delete vectors by their stored ids
        vector_store_connector.delete_by_ids(vector_ids)
    # delete chunks
    self._chunk_dao.raw_delete(document.id)
    # delete document
    self._document_dao.raw_delete(document)
    return document
|
||||
|
||||
def get_list(self, request: SpaceServeRequest) -> List[SpaceServeResponse]:
    """List knowledge-space entities matching *request*.

    Args:
        request (SpaceServeRequest): The filter for the listing.

    Returns:
        List[SpaceServeResponse]: All matching spaces.
    """
    # NOTE(review): this uses ``self.dao`` while sibling methods use
    # ``self._dao`` -- presumably both resolve to the same DAO; confirm.
    return self.dao.get_list(request)
|
||||
|
||||
def get_list_by_page(
    self, request: QUERY_SPEC, page: int, page_size: int
) -> PaginationResult[SpaceServeResponse]:
    """Return one page of knowledge-space entities.

    Args:
        request (QUERY_SPEC): The filter for the listing.
        page (int): The page number.
        page_size (int): The number of items per page.

    Returns:
        PaginationResult[SpaceServeResponse]: The requested page of spaces.
    """
    # NOTE(review): uses ``self.dao`` like get_list, not ``self._dao``.
    return self.dao.get_list_page(request, page, page_size)
|
||||
|
||||
def get_document_list(
    self, request: QUERY_SPEC, page: int, page_size: int
) -> PaginationResult[SpaceServeResponse]:
    """Return one page of knowledge-document entities.

    Args:
        request (QUERY_SPEC): The filter for the listing.
        page (int): The page number.
        page_size (int): The number of items per page.

    Returns:
        PaginationResult[SpaceServeResponse]: The requested page of documents.
    """
    return self._document_dao.get_list_page(request, page, page_size)
|
||||
|
||||
def _batch_document_sync(
    self, space_id, sync_requests: List[KnowledgeSyncRequest]
) -> List[int]:
    """batch sync knowledge document chunk into vector store

    Args:
        - space_id: id of the knowledge space the documents belong to
        - sync_requests: List[KnowledgeSyncRequest]

    Returns:
        - List[int]: document ids scheduled for sync

    Raises:
        Exception: if a document is missing or already syncing/synced.
    """
    doc_ids = []
    for sync_request in sync_requests:
        docs = self._document_dao.documents_by_ids([sync_request.doc_id])
        if len(docs) == 0:
            raise Exception(
                f"there is no document for doc_id: {sync_request.doc_id}"
            )
        doc = docs[0]
        # A document that is mid-sync or already synced must not be re-queued.
        if (
            doc.status == SyncStatus.RUNNING.name
            or doc.status == SyncStatus.FINISHED.name
        ):
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )
        chunk_parameters = sync_request.chunk_parameters
        if chunk_parameters.chunk_strategy != ChunkStrategy.CHUNK_BY_SIZE.name:
            # Fall back to the space-level (or global default) chunk config.
            space_context = self.get_space_context(space_id)
            chunk_parameters.chunk_size = (
                CFG.KNOWLEDGE_CHUNK_SIZE
                if space_context is None
                else int(space_context["embedding"]["chunk_size"])
            )
            chunk_parameters.chunk_overlap = (
                CFG.KNOWLEDGE_CHUNK_OVERLAP
                if space_context is None
                else int(space_context["embedding"]["chunk_overlap"])
            )
        self._sync_knowledge_document(space_id, doc, chunk_parameters)
        doc_ids.append(doc.id)
    return doc_ids
|
||||
|
||||
def _sync_knowledge_document(
    self,
    space_id,
    doc: KnowledgeDocumentEntity,
    chunk_parameters: ChunkParameters,
) -> List[Chunk]:
    """sync knowledge document chunk into vector store

    Builds the embedding function and vector-store connector for the
    document's space, chunks the document via an EmbeddingAssembler,
    marks the document RUNNING, and submits the actual embedding work to
    the default executor (completion is handled by ``async_doc_embedding``).

    Args:
        space_id: id of the space the document belongs to
        doc (KnowledgeDocumentEntity): the document to sync
        chunk_parameters (ChunkParameters): chunking strategy and sizes

    Returns:
        List[Chunk]: the chunks produced for the document; embedding of
        these chunks happens asynchronously after this method returns.
    """
    embedding_factory = CFG.SYSTEM_APP.get_component(
        "embedding_factory", EmbeddingFactory
    )
    embedding_fn = embedding_factory.create(
        model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
    )
    from dbgpt.storage.vector_store.base import VectorStoreConfig

    space = self.get({"id": space_id})
    config = VectorStoreConfig(
        name=space.name,
        embedding_fn=embedding_fn,
        max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
    )
    vector_store_connector = VectorStoreConnector(
        vector_store_type=CFG.VECTOR_STORE_TYPE,
        vector_store_config=config,
    )
    knowledge = KnowledgeFactory.create(
        datasource=doc.content,
        knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
    )
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=chunk_parameters,
        vector_store_connector=vector_store_connector,
    )
    chunk_docs = assembler.get_chunks()
    # Mark RUNNING (and persist chunk count) before handing off to the
    # executor, so concurrent sync requests are rejected by the status check.
    doc.status = SyncStatus.RUNNING.name
    doc.chunk_size = len(chunk_docs)
    doc.gmt_modified = datetime.now()
    self._document_dao.update_knowledge_document(doc)
    executor = CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    ).create()
    # The embedding runs in the background; async_doc_embedding updates the
    # document to FINISHED/FAILED when it completes.
    executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
    logger.info(f"begin save document chunks, doc:{doc.doc_name}")
    return chunk_docs
|
||||
|
||||
@trace("async_doc_embedding")
|
||||
def async_doc_embedding(self, assembler, chunk_docs, doc):
|
||||
"""async document embedding into vector db
|
||||
Args:
|
||||
- client: EmbeddingEngine Client
|
||||
- chunk_docs: List[Document]
|
||||
- doc: KnowledgeDocumentEntity
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
|
||||
)
|
||||
try:
|
||||
with root_tracer.start_span(
|
||||
"app.knowledge.assembler.persist",
|
||||
metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)},
|
||||
):
|
||||
vector_ids = assembler.persist()
|
||||
doc.status = SyncStatus.FINISHED.name
|
||||
doc.result = "document embedding success"
|
||||
if vector_ids is not None:
|
||||
doc.vector_ids = ",".join(vector_ids)
|
||||
logger.info(f"async document embedding, success:{doc.doc_name}")
|
||||
# save chunk details
|
||||
chunk_entities = [
|
||||
DocumentChunkEntity(
|
||||
doc_name=doc.doc_name,
|
||||
doc_type=doc.doc_type,
|
||||
document_id=doc.id,
|
||||
content=chunk_doc.content,
|
||||
meta_info=str(chunk_doc.metadata),
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
for chunk_doc in chunk_docs
|
||||
]
|
||||
self._chunk_dao.create_documents_chunks(chunk_entities)
|
||||
except Exception as e:
|
||||
doc.status = SyncStatus.FAILED.name
|
||||
doc.result = "document embedding failed" + str(e)
|
||||
logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
|
||||
return self._document_dao.update_knowledge_document(doc)
|
||||
|
||||
def get_space_context(self, space_id):
    """Get the parsed context of a knowledge space.

    Args:
        space_id: id of the space whose context to load.

    Returns:
        The JSON-decoded context dict, or None when the space has no
        context set.

    Raises:
        Exception: If no space exists for *space_id*.
    """
    space = self.get({"id": space_id})
    if space is None:
        raise Exception(
            f"have not found {space_id} space or found more than one space called {space_id}"
        )
    if space.context is not None:
        return json.loads(space.context)
    return None
|
0
dbgpt/serve/rag/tests/__init__.py
Normal file
0
dbgpt/serve/rag/tests/__init__.py
Normal file
Reference in New Issue
Block a user