mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-21 09:28:39 +00:00
chore: Merge main codes
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from functools import cache
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from dbgpt.component import AppConfig
|
||||
from dbgpt.component import AppConfig, SystemApp
|
||||
from dbgpt.util import BaseParameters, RegisterParameters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
@@ -44,3 +45,103 @@ class BaseServeConfig(BaseParameters, RegisterParameters):
|
||||
if k not in config_dict and k[len(global_prefix) :] in cls().__dict__:
|
||||
config_dict[k[len(global_prefix) :]] = v
|
||||
return cls(**config_dict)
|
||||
|
||||
|
||||
@dataclass
class BaseGPTsAppMemoryConfig(BaseParameters, RegisterParameters):
    """Common base class for the GPTs app memory configurations.

    Concrete memory configurations in this module subclass it and override
    ``__type__`` with their own identifier.
    """

    # Plain class attribute (no annotation, so it is not a dataclass field).
    # NOTE(review): presumably the type discriminator consumed by
    # RegisterParameters -- confirm against that class.
    __type__ = "___memory_placeholder___"
|
||||
|
||||
|
||||
@dataclass
class BufferWindowGPTsAppMemoryConfig(BaseGPTsAppMemoryConfig):
    """Buffer window memory configuration.

    This configuration is used to control the buffer window memory: how many
    rounds are kept from the start and from the end. Both values default to 0.
    """

    # Identifier of this memory configuration (class attribute, not a field).
    __type__ = "window"

    keep_start_rounds: int = field(
        metadata={"help": _("The number of start rounds to keep in memory")},
        default=0,
    )
    keep_end_rounds: int = field(
        metadata={"help": _("The number of end rounds to keep in memory")},
        default=0,
    )
|
||||
|
||||
|
||||
@dataclass
class TokenBufferGPTsAppMemoryConfig(BaseGPTsAppMemoryConfig):
    """Token buffer memory configuration.

    This configuration is used to control the token buffer memory through a
    single upper bound on the number of tokens.
    """

    # Identifier of this memory configuration (class attribute, not a field).
    __type__ = "token"

    max_token_limit: int = field(
        metadata={"help": _("The max token limit. Default is 100k")},
        default=100 * 1024,  # 102400 tokens
    )
|
||||
|
||||
|
||||
@dataclass
class GPTsAppCommonConfig(BaseParameters, RegisterParameters):
    """Configuration fields shared by GPTs apps.

    Every field is optional and defaults to ``None``. It carries the LLM
    generation parameters, the app name and an optional memory configuration.
    """

    # Plain class attribute (no annotation, so it is not a dataclass field).
    # NOTE(review): presumably tells RegisterParameters which field holds the
    # type discriminator -- confirm against that class.
    __type_field__ = "name"

    # --- LLM generation parameters -------------------------------------
    top_k: Optional[int] = field(
        metadata={"help": _("The top k for LLM generation")},
        default=None,
    )
    top_p: Optional[float] = field(
        metadata={"help": _("The top p for LLM generation")},
        default=None,
    )
    temperature: Optional[float] = field(
        metadata={"help": _("The temperature for LLM generation")},
        default=None,
    )
    max_new_tokens: Optional[int] = field(
        metadata={"help": _("The max new tokens for LLM generation")},
        default=None,
    )

    # --- App identity and memory ---------------------------------------
    name: Optional[str] = field(
        metadata={"help": _("The name of your app")},
        default=None,
    )
    memory: Optional[BaseGPTsAppMemoryConfig] = field(
        metadata={"help": _("The memory configuration")},
        default=None,
    )
|
||||
|
||||
|
||||
@dataclass
class GPTsAppConfig(GPTsAppCommonConfig, BaseParameters):
    """GPTs application configuration.

    For global configuration, you can set the parameters here. Per-app
    configurations are listed in ``configs``.
    """

    # Overrides the inherited optional ``name`` with a concrete default.
    name: str = "app"

    # Starts as a fresh empty list for every instance.
    configs: List[GPTsAppCommonConfig] = field(
        metadata={"help": _("The configs for specific app")},
        default_factory=list,
    )
|
||||
|
||||
|
||||
@cache
def parse_config(
    system_app: SystemApp,
    config_name: str,
    type_class: Optional[Type[GPTsAppCommonConfig]],
):
    """Resolve the app configuration to use for ``config_name``.

    Resolution order:

    1. The first entry of the global app config's ``configs`` whose ``name``
       equals ``config_name``.
    2. Otherwise, when ``type_class`` is given, a ``type_class`` instance
       built from the global app config's dict (extra fields ignored).
    3. Otherwise the global app config itself.

    Results are memoized by ``functools.cache`` on the argument triple, so
    all arguments must be hashable and repeated calls return the same object.

    Args:
        system_app: The system app whose config holds the ``"app_config"``
            entry.
        config_name: The name of the specific app config to look for.
        type_class: Optional config class used for the fallback conversion.

    Returns:
        The matching specific config, the converted global config, or the
        global config, in that order of preference.
    """
    # NOTE(review): imported inside the function, presumably to avoid an
    # import cycle with dbgpt_app -- confirm.
    from dbgpt_app.config import ApplicationConfig

    # Global config for the chat scene
    global_config = system_app.config.get_typed("app_config", ApplicationConfig).app

    # The first specific config with a matching name wins.
    override = next(
        (item for item in global_config.configs if item.name == config_name),
        None,
    )
    if override is not None:
        return override

    if type_class is None:
        return global_config
    return type_class.from_dict(global_config.to_dict(), ignore_extra_fields=True)
|
||||
|
@@ -114,7 +114,8 @@ async def create(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.create(request))
|
||||
res = await blocking_func_to_async(global_system_app, service.create, request)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -134,7 +135,8 @@ async def update(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.update(request))
|
||||
res = await blocking_func_to_async(global_system_app, service.update, request)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -153,7 +155,7 @@ async def delete(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
service.delete(datasource_id)
|
||||
await blocking_func_to_async(global_system_app, service.delete, datasource_id)
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
@@ -173,7 +175,8 @@ async def query(
|
||||
Returns:
|
||||
List[ServeResponse]: The response
|
||||
"""
|
||||
return Result.succ(service.get(datasource_id))
|
||||
res = await blocking_func_to_async(global_system_app, service.get, datasource_id)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -194,7 +197,9 @@ async def query_page(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
res = service.get_list(db_type=db_type)
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app, service.get_list, db_type=db_type
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@@ -207,7 +212,8 @@ async def get_datasource_types(
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result[ResourceTypes]:
|
||||
"""Get the datasource types."""
|
||||
return Result.succ(service.datasource_types())
|
||||
res = await blocking_func_to_async(global_system_app, service.datasource_types)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
@@ -233,12 +233,8 @@ class Service(
|
||||
DatasourceServeResponse: The data after deletion
|
||||
"""
|
||||
db_config = self._dao.get_one({"id": datasource_id})
|
||||
vector_name = db_config.db_name + "_profile"
|
||||
vector_connector = self.storage_manager.create_vector_store(
|
||||
index_name=vector_name
|
||||
)
|
||||
vector_connector.delete_vector_name(vector_name)
|
||||
if db_config:
|
||||
self._db_summary_client.delete_db_profile(db_config.db_name)
|
||||
self._dao.delete({"id": datasource_id})
|
||||
return db_config
|
||||
|
||||
@@ -301,12 +297,18 @@ class Service(
|
||||
bool: The refresh result
|
||||
"""
|
||||
db_config = self._dao.get_one({"id": datasource_id})
|
||||
vector_name = db_config.db_name + "_profile"
|
||||
vector_connector = self.storage_manager.create_vector_store(
|
||||
index_name=vector_name
|
||||
)
|
||||
vector_connector.delete_vector_name(vector_name)
|
||||
self._db_summary_client.db_summary_embedding(
|
||||
db_config.db_name, db_config.db_type
|
||||
if not db_config:
|
||||
raise HTTPException(status_code=404, detail="datasource not found")
|
||||
|
||||
self._db_summary_client.delete_db_profile(db_config.db_name)
|
||||
|
||||
# async embedding
|
||||
executor = self._system_app.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create() # type: ignore
|
||||
executor.submit(
|
||||
self._db_summary_client.db_summary_embedding,
|
||||
db_config.db_name,
|
||||
db_config.db_type,
|
||||
)
|
||||
return True
|
||||
|
@@ -14,7 +14,7 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.util import PaginationResult
|
||||
from dbgpt_ext.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt_serve.core import Result
|
||||
from dbgpt_serve.core import Result, blocking_func_to_async
|
||||
from dbgpt_serve.rag.api.schemas import (
|
||||
DocumentServeRequest,
|
||||
DocumentServeResponse,
|
||||
@@ -155,7 +155,9 @@ async def delete(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.delete(space_id))
|
||||
# TODO: Delete the files in the space
|
||||
res = await blocking_func_to_async(global_system_app, service.delete, space_id)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -248,7 +250,10 @@ async def create_document(
|
||||
doc_file=doc_file,
|
||||
space_id=space_id,
|
||||
)
|
||||
return Result.succ(await service.create_document(request))
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app, service.create_document, request
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -369,7 +374,11 @@ async def delete_document(
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.delete_document(document_id))
|
||||
# TODO: Delete the files of the document
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app, service.delete_document, document_id
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
|
||||
|
@@ -2,8 +2,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional, cast
|
||||
@@ -13,9 +11,10 @@ from fastapi import HTTPException
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.configs import TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE
|
||||
from dbgpt.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
KNOWLEDGE_CACHE_ROOT_PATH,
|
||||
)
|
||||
from dbgpt.core import Chunk, LLMClient
|
||||
from dbgpt.core.interface.file import _SCHEMA, FileStorageClient
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
|
||||
@@ -30,7 +29,7 @@ from dbgpt_app.knowledge.request.request import BusinessFieldType
|
||||
from dbgpt_ext.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt_ext.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt_ext.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt_serve.core import BaseService
|
||||
from dbgpt_serve.core import BaseService, blocking_func_to_async
|
||||
|
||||
from ..api.schemas import (
|
||||
ChunkServeRequest,
|
||||
@@ -117,6 +116,10 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
).create()
|
||||
return DefaultLLMClient(worker_manager, True)
|
||||
|
||||
def get_fs(self) -> FileStorageClient:
    """Return the FileStorageClient looked up from this service's system app."""
    fs_client = FileStorageClient.get_instance(self.system_app)
    return fs_client
|
||||
|
||||
def create_space(self, request: SpaceServeRequest) -> SpaceServeResponse:
|
||||
"""Create a new Space entity
|
||||
|
||||
@@ -155,7 +158,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
update_obj = self._dao.update_knowledge_space(self._dao.from_request(request))
|
||||
return update_obj
|
||||
|
||||
async def create_document(self, request: DocumentServeRequest) -> str:
|
||||
def create_document(self, request: DocumentServeRequest) -> str:
|
||||
"""Create a new document entity
|
||||
|
||||
Args:
|
||||
@@ -173,20 +176,21 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
raise Exception(f"document name:{request.doc_name} have already named")
|
||||
if request.doc_file and request.doc_type == KnowledgeType.DOCUMENT.name:
|
||||
doc_file = request.doc_file
|
||||
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name)):
|
||||
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name))
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(
|
||||
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name)
|
||||
)
|
||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||
tmp.write(await request.doc_file.read())
|
||||
shutil.move(
|
||||
tmp_path,
|
||||
os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name, doc_file.filename),
|
||||
)
|
||||
request.content = os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, space.name, doc_file.filename
|
||||
safe_filename = os.path.basename(doc_file.filename)
|
||||
custom_metadata = {
|
||||
"space_name": space.name,
|
||||
"doc_name": doc_file.filename,
|
||||
"doc_type": request.doc_type,
|
||||
}
|
||||
bucket = "dbgpt_knowledge_file"
|
||||
file_uri = self.get_fs().save_file(
|
||||
bucket,
|
||||
safe_filename,
|
||||
doc_file.file,
|
||||
storage_type="distributed",
|
||||
custom_metadata=custom_metadata,
|
||||
)
|
||||
request.content = file_uri
|
||||
document = KnowledgeDocumentEntity(
|
||||
doc_name=request.doc_name,
|
||||
doc_type=request.doc_type,
|
||||
@@ -490,28 +494,58 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
storage_connector = self.storage_manager.get_storage_connector(
|
||||
space.name, space.vector_type
|
||||
)
|
||||
knowledge_content = doc.content
|
||||
if (
|
||||
doc.doc_type == KnowledgeType.DOCUMENT.value
|
||||
and knowledge_content.startswith(_SCHEMA)
|
||||
):
|
||||
logger.info(
|
||||
f"Download file from file storage, doc: {doc.doc_name}, file url: "
|
||||
f"{doc.content}"
|
||||
)
|
||||
local_file_path, file_meta = await blocking_func_to_async(
|
||||
self.system_app,
|
||||
self.get_fs().download_file,
|
||||
knowledge_content,
|
||||
dest_dir=KNOWLEDGE_CACHE_ROOT_PATH,
|
||||
)
|
||||
logger.info(f"Downloaded file to {local_file_path}")
|
||||
knowledge_content = local_file_path
|
||||
knowledge = None
|
||||
if not space.domain_type or (
|
||||
space.domain_type.lower() == BusinessFieldType.NORMAL.value.lower()
|
||||
):
|
||||
knowledge = KnowledgeFactory.create(
|
||||
datasource=doc.content,
|
||||
datasource=knowledge_content,
|
||||
knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
|
||||
)
|
||||
doc.status = SyncStatus.RUNNING.name
|
||||
|
||||
doc.gmt_modified = datetime.now()
|
||||
self._document_dao.update_knowledge_document(doc)
|
||||
await blocking_func_to_async(
|
||||
self.system_app, self._document_dao.update_knowledge_document, doc
|
||||
)
|
||||
asyncio.create_task(
|
||||
self.async_doc_process(
|
||||
knowledge, chunk_parameters, storage_connector, doc, space
|
||||
knowledge,
|
||||
chunk_parameters,
|
||||
storage_connector,
|
||||
doc,
|
||||
space,
|
||||
knowledge_content,
|
||||
)
|
||||
)
|
||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||
|
||||
@trace("async_doc_process")
|
||||
async def async_doc_process(
|
||||
self, knowledge, chunk_parameters, storage_connector, doc, space
|
||||
self,
|
||||
knowledge,
|
||||
chunk_parameters,
|
||||
storage_connector,
|
||||
doc,
|
||||
space,
|
||||
knowledge_content: str,
|
||||
):
|
||||
"""async document process into storage
|
||||
Args:
|
||||
@@ -539,7 +573,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
f" and value: {space.domain_type}, dag: {dags[0]}"
|
||||
)
|
||||
db_name, chunk_docs = await end_task.call(
|
||||
{"file_path": doc.content, "space": doc.space}
|
||||
{"file_path": knowledge_content, "space": doc.space}
|
||||
)
|
||||
doc.chunk_size = len(chunk_docs)
|
||||
vector_ids = [chunk.chunk_id for chunk in chunk_docs]
|
||||
|
@@ -115,7 +115,7 @@ async def test_create_document(service):
|
||||
service._document_dao.get_knowledge_documents = Mock(return_value=[])
|
||||
service._document_dao.create_knowledge_document = Mock(return_value="2")
|
||||
|
||||
response = await service.create_document(request)
|
||||
response = service.create_document(request)
|
||||
assert response == "2"
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user