mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 07:00:15 +00:00
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
0
dbgpt/app/knowledge/__init__.py
Normal file
0
dbgpt/app/knowledge/__init__.py
Normal file
0
dbgpt/app/knowledge/_cli/__init__.py
Normal file
0
dbgpt/app/knowledge/_cli/__init__.py
Normal file
234
dbgpt/app/knowledge/_cli/knowledge_cli.py
Normal file
234
dbgpt/app/knowledge/_cli/knowledge_cli.py
Normal file
@@ -0,0 +1,234 @@
|
||||
import click
|
||||
import logging
|
||||
import os
|
||||
import functools
|
||||
|
||||
from dbgpt.configs.model_config import DATASETS_DIR
|
||||
|
||||
_DEFAULT_API_ADDRESS: str = "http://127.0.0.1:5000"
|
||||
API_ADDRESS: str = _DEFAULT_API_ADDRESS
|
||||
|
||||
logger = logging.getLogger("dbgpt_cli")
|
||||
|
||||
|
||||
@click.group("knowledge")
|
||||
@click.option(
|
||||
"--address",
|
||||
type=str,
|
||||
default=API_ADDRESS,
|
||||
required=False,
|
||||
show_default=True,
|
||||
help=(
|
||||
"Address of the Api server(If not set, try to read from environment variable: API_ADDRESS)."
|
||||
),
|
||||
)
|
||||
def knowledge_cli_group(address: str):
|
||||
"""Knowledge command line tool"""
|
||||
global API_ADDRESS
|
||||
if address == _DEFAULT_API_ADDRESS:
|
||||
address = os.getenv("API_ADDRESS", _DEFAULT_API_ADDRESS)
|
||||
API_ADDRESS = address
|
||||
|
||||
|
||||
def add_knowledge_options(func):
|
||||
@click.option(
|
||||
"--space_name",
|
||||
required=False,
|
||||
type=str,
|
||||
default="default",
|
||||
show_default=True,
|
||||
help="Your knowledge space name",
|
||||
)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@knowledge_cli_group.command()
|
||||
@add_knowledge_options
|
||||
@click.option(
|
||||
"--vector_store_type",
|
||||
required=False,
|
||||
type=str,
|
||||
default="Chroma",
|
||||
show_default=True,
|
||||
help="Vector store type.",
|
||||
)
|
||||
@click.option(
|
||||
"--local_doc_path",
|
||||
required=False,
|
||||
type=str,
|
||||
default=DATASETS_DIR,
|
||||
show_default=True,
|
||||
help="Your document directory or document file path.",
|
||||
)
|
||||
@click.option(
|
||||
"--skip_wrong_doc",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Skip wrong document.",
|
||||
)
|
||||
@click.option(
|
||||
"--overwrite",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Overwrite existing document(they has same name).",
|
||||
)
|
||||
@click.option(
|
||||
"--max_workers",
|
||||
required=False,
|
||||
type=int,
|
||||
default=None,
|
||||
help="The maximum number of threads that can be used to upload document.",
|
||||
)
|
||||
@click.option(
|
||||
"--pre_separator",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="Preseparator, this separator is used for pre-splitting before the document is "
|
||||
"actually split by the text splitter. Preseparator are not included in the vectorized text. ",
|
||||
)
|
||||
@click.option(
|
||||
"--separator",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="This is the document separator. Currently, only one separator is supported.",
|
||||
)
|
||||
@click.option(
|
||||
"--chunk_size",
|
||||
required=False,
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum size of chunks to split.",
|
||||
)
|
||||
@click.option(
|
||||
"--chunk_overlap",
|
||||
required=False,
|
||||
type=int,
|
||||
default=None,
|
||||
help="Overlap in characters between chunks.",
|
||||
)
|
||||
def load(
|
||||
space_name: str,
|
||||
vector_store_type: str,
|
||||
local_doc_path: str,
|
||||
skip_wrong_doc: bool,
|
||||
overwrite: bool,
|
||||
max_workers: int,
|
||||
pre_separator: str,
|
||||
separator: str,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
):
|
||||
"""Load your local documents to DB-GPT"""
|
||||
from dbgpt.app.knowledge._cli.knowledge_client import knowledge_init
|
||||
|
||||
knowledge_init(
|
||||
API_ADDRESS,
|
||||
space_name,
|
||||
vector_store_type,
|
||||
local_doc_path,
|
||||
skip_wrong_doc,
|
||||
overwrite,
|
||||
max_workers,
|
||||
pre_separator,
|
||||
separator,
|
||||
chunk_size,
|
||||
chunk_overlap,
|
||||
)
|
||||
|
||||
|
||||
@knowledge_cli_group.command()
|
||||
@add_knowledge_options
|
||||
@click.option(
|
||||
"--doc_name",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="The document name you want to delete. If doc_name is None, this command will delete the whole space.",
|
||||
)
|
||||
@click.option(
|
||||
"-y",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Confirm your choice",
|
||||
)
|
||||
def delete(space_name: str, doc_name: str, y: bool):
|
||||
"""Delete your knowledge space or document in space"""
|
||||
from dbgpt.app.knowledge._cli.knowledge_client import knowledge_delete
|
||||
|
||||
knowledge_delete(API_ADDRESS, space_name, doc_name, confirm=y)
|
||||
|
||||
|
||||
@knowledge_cli_group.command()
|
||||
@click.option(
|
||||
"--space_name",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Your knowledge space name. If None, list all spaces",
|
||||
)
|
||||
@click.option(
|
||||
"--doc_id",
|
||||
required=False,
|
||||
type=int,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Your document id in knowledge space. If Not None, list all chunks in current document",
|
||||
)
|
||||
@click.option(
|
||||
"--page",
|
||||
required=False,
|
||||
type=int,
|
||||
default=1,
|
||||
show_default=True,
|
||||
help="The page for every query",
|
||||
)
|
||||
@click.option(
|
||||
"--page_size",
|
||||
required=False,
|
||||
type=int,
|
||||
default=20,
|
||||
show_default=True,
|
||||
help="The page size for every query",
|
||||
)
|
||||
@click.option(
|
||||
"--show_content",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Query the document content of chunks",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
required=False,
|
||||
type=click.Choice(["text", "html", "csv", "latex", "json"]),
|
||||
default="text",
|
||||
help="The output format",
|
||||
)
|
||||
def list(
|
||||
space_name: str,
|
||||
doc_id: int,
|
||||
page: int,
|
||||
page_size: int,
|
||||
show_content: bool,
|
||||
output: str,
|
||||
):
|
||||
"""List knowledge space"""
|
||||
from dbgpt.app.knowledge._cli.knowledge_client import knowledge_list
|
||||
|
||||
knowledge_list(
|
||||
API_ADDRESS, space_name, page, page_size, doc_id, show_content, output
|
||||
)
|
395
dbgpt/app/knowledge/_cli/knowledge_client.py
Normal file
395
dbgpt/app/knowledge/_cli/knowledge_client.py
Normal file
@@ -0,0 +1,395 @@
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
import logging
|
||||
|
||||
from urllib.parse import urljoin
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.app.knowledge.request.request import (
|
||||
KnowledgeQueryRequest,
|
||||
KnowledgeDocumentRequest,
|
||||
ChunkQueryRequest,
|
||||
DocumentQueryRequest,
|
||||
)
|
||||
|
||||
from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType
|
||||
from dbgpt.app.knowledge.request.request import DocumentSyncRequest
|
||||
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
|
||||
HTTP_HEADERS = {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
logger = logging.getLogger("dbgpt_cli")
|
||||
|
||||
|
||||
class ApiClient:
|
||||
def __init__(self, api_address: str) -> None:
|
||||
self.api_address = api_address
|
||||
|
||||
def _handle_response(self, response):
|
||||
if 200 <= response.status_code <= 300:
|
||||
result = Result(**response.json())
|
||||
if not result.success:
|
||||
raise Exception(result.err_msg)
|
||||
return result.data
|
||||
else:
|
||||
raise Exception(
|
||||
f"Http request error, code: {response.status_code}, message: {response.text}"
|
||||
)
|
||||
|
||||
def _post(self, url: str, data=None):
|
||||
if not isinstance(data, dict):
|
||||
data = data.__dict__
|
||||
url = urljoin(self.api_address, url)
|
||||
logger.debug(f"Send request to {url}, data: {data}")
|
||||
response = requests.post(url, data=json.dumps(data), headers=HTTP_HEADERS)
|
||||
return self._handle_response(response)
|
||||
|
||||
|
||||
class KnowledgeApiClient(ApiClient):
|
||||
def __init__(self, api_address: str) -> None:
|
||||
super().__init__(api_address)
|
||||
|
||||
def space_add(self, request: KnowledgeSpaceRequest):
|
||||
try:
|
||||
return self._post("/knowledge/space/add", data=request)
|
||||
except Exception as e:
|
||||
if "have already named" in str(e):
|
||||
logger.warn(f"you have already named {request.name}")
|
||||
else:
|
||||
raise e
|
||||
|
||||
def space_delete(self, request: KnowledgeSpaceRequest):
|
||||
return self._post("/knowledge/space/delete", data=request)
|
||||
|
||||
def space_list(self, request: KnowledgeSpaceRequest):
|
||||
return self._post("/knowledge/space/list", data=request)
|
||||
|
||||
def document_add(self, space_name: str, request: KnowledgeDocumentRequest):
|
||||
url = f"/knowledge/{space_name}/document/add"
|
||||
return self._post(url, data=request)
|
||||
|
||||
def document_delete(self, space_name: str, request: KnowledgeDocumentRequest):
|
||||
url = f"/knowledge/{space_name}/document/delete"
|
||||
return self._post(url, data=request)
|
||||
|
||||
def document_list(self, space_name: str, query_request: DocumentQueryRequest):
|
||||
url = f"/knowledge/{space_name}/document/list"
|
||||
return self._post(url, data=query_request)
|
||||
|
||||
def document_upload(self, space_name, doc_name, doc_type, doc_file_path):
|
||||
"""Upload with multipart/form-data"""
|
||||
url = f"{self.api_address}/knowledge/{space_name}/document/upload"
|
||||
with open(doc_file_path, "rb") as f:
|
||||
files = {"doc_file": f}
|
||||
data = {"doc_name": doc_name, "doc_type": doc_type}
|
||||
response = requests.post(url, data=data, files=files)
|
||||
return self._handle_response(response)
|
||||
|
||||
def document_sync(self, space_name: str, request: DocumentSyncRequest):
|
||||
url = f"/knowledge/{space_name}/document/sync"
|
||||
return self._post(url, data=request)
|
||||
|
||||
def chunk_list(self, space_name: str, query_request: ChunkQueryRequest):
|
||||
url = f"/knowledge/{space_name}/chunk/list"
|
||||
return self._post(url, data=query_request)
|
||||
|
||||
def similar_query(self, vector_name: str, query_request: KnowledgeQueryRequest):
|
||||
url = f"/knowledge/{vector_name}/query"
|
||||
return self._post(url, data=query_request)
|
||||
|
||||
|
||||
def knowledge_init(
|
||||
api_address: str,
|
||||
space_name: str,
|
||||
vector_store_type: str,
|
||||
local_doc_path: str,
|
||||
skip_wrong_doc: bool,
|
||||
overwrite: bool,
|
||||
max_workers: int,
|
||||
pre_separator: str,
|
||||
separator: str,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
):
|
||||
client = KnowledgeApiClient(api_address)
|
||||
space = KnowledgeSpaceRequest()
|
||||
space.name = space_name
|
||||
space.desc = "DB-GPT cli"
|
||||
space.vector_type = vector_store_type
|
||||
space.owner = "DB-GPT"
|
||||
|
||||
# Create space
|
||||
logger.info(f"Create space: {space}")
|
||||
client.space_add(space)
|
||||
logger.info("Create space successfully")
|
||||
space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
|
||||
if len(space_list) != 1:
|
||||
raise Exception(f"List space {space.name} error")
|
||||
space = KnowledgeSpaceRequest(**space_list[0])
|
||||
|
||||
doc_ids = []
|
||||
|
||||
def upload(filename: str):
|
||||
try:
|
||||
logger.info(f"Begin upload document: {filename} to {space.name}")
|
||||
doc_id = None
|
||||
try:
|
||||
doc_id = client.document_upload(
|
||||
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
||||
)
|
||||
except Exception as ex:
|
||||
if overwrite and "have already named" in str(ex):
|
||||
logger.warn(
|
||||
f"Document {filename} already exist in space {space.name}, overwrite it"
|
||||
)
|
||||
client.document_delete(
|
||||
space.name, KnowledgeDocumentRequest(doc_name=filename)
|
||||
)
|
||||
doc_id = client.document_upload(
|
||||
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
||||
)
|
||||
else:
|
||||
raise ex
|
||||
sync_req = DocumentSyncRequest(doc_ids=[doc_id])
|
||||
if pre_separator:
|
||||
sync_req.pre_separator = pre_separator
|
||||
if separator:
|
||||
sync_req.separators = [separator]
|
||||
if chunk_size:
|
||||
sync_req.chunk_size = chunk_size
|
||||
if chunk_overlap:
|
||||
sync_req.chunk_overlap = chunk_overlap
|
||||
|
||||
client.document_sync(space.name, sync_req)
|
||||
return doc_id
|
||||
except Exception as e:
|
||||
if skip_wrong_doc:
|
||||
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
|
||||
else:
|
||||
raise e
|
||||
|
||||
if not os.path.exists(local_doc_path):
|
||||
raise Exception(f"{local_doc_path} not exists")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
tasks = []
|
||||
file_names = []
|
||||
if os.path.isdir(local_doc_path):
|
||||
for root, _, files in os.walk(local_doc_path, topdown=False):
|
||||
for file in files:
|
||||
file_names.append(os.path.join(root, file))
|
||||
else:
|
||||
# Single file
|
||||
file_names.append(local_doc_path)
|
||||
|
||||
[tasks.append(pool.submit(upload, filename)) for filename in file_names]
|
||||
|
||||
doc_ids = [r.result() for r in as_completed(tasks)]
|
||||
doc_ids = list(filter(lambda x: x, doc_ids))
|
||||
if not doc_ids:
|
||||
logger.warn("Warning: no document to sync")
|
||||
return
|
||||
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
|
||||
class _KnowledgeVisualizer:
|
||||
def __init__(self, api_address: str, out_format: str):
|
||||
self.client = KnowledgeApiClient(api_address)
|
||||
self.out_format = out_format
|
||||
self.out_kwargs = {}
|
||||
if out_format == "json":
|
||||
self.out_kwargs["ensure_ascii"] = False
|
||||
|
||||
def print_table(self, table):
|
||||
print(table.get_formatted_string(out_format=self.out_format, **self.out_kwargs))
|
||||
|
||||
def list_spaces(self):
|
||||
spaces = self.client.space_list(KnowledgeSpaceRequest())
|
||||
table = PrettyTable(
|
||||
["Space ID", "Space Name", "Vector Type", "Owner", "Description"],
|
||||
title="All knowledge spaces",
|
||||
)
|
||||
for sp in spaces:
|
||||
context = sp.get("context")
|
||||
table.add_row(
|
||||
[
|
||||
sp.get("id"),
|
||||
sp.get("name"),
|
||||
sp.get("vector_type"),
|
||||
sp.get("owner"),
|
||||
sp.get("desc"),
|
||||
]
|
||||
)
|
||||
self.print_table(table)
|
||||
|
||||
def list_documents(self, space_name: str, page: int, page_size: int):
|
||||
space_data = self.client.document_list(
|
||||
space_name, DocumentQueryRequest(page=page, page_size=page_size)
|
||||
)
|
||||
|
||||
space_table = PrettyTable(
|
||||
[
|
||||
"Space Name",
|
||||
"Total Documents",
|
||||
"Current Page",
|
||||
"Current Size",
|
||||
"Page Size",
|
||||
],
|
||||
title=f"Space {space_name} description",
|
||||
)
|
||||
space_table.add_row(
|
||||
[space_name, space_data["total"], page, len(space_data["data"]), page_size]
|
||||
)
|
||||
|
||||
table = PrettyTable(
|
||||
[
|
||||
"Space Name",
|
||||
"Document ID",
|
||||
"Document Name",
|
||||
"Type",
|
||||
"Chunks",
|
||||
"Last Sync",
|
||||
"Status",
|
||||
"Result",
|
||||
],
|
||||
title=f"Documents of space {space_name}",
|
||||
)
|
||||
for doc in space_data["data"]:
|
||||
table.add_row(
|
||||
[
|
||||
space_name,
|
||||
doc.get("id"),
|
||||
doc.get("doc_name"),
|
||||
doc.get("doc_type"),
|
||||
doc.get("chunk_size"),
|
||||
doc.get("last_sync"),
|
||||
doc.get("status"),
|
||||
doc.get("result"),
|
||||
]
|
||||
)
|
||||
if self.out_format == "text":
|
||||
self.print_table(space_table)
|
||||
print("")
|
||||
self.print_table(table)
|
||||
|
||||
def list_chunks(
|
||||
self,
|
||||
space_name: str,
|
||||
doc_id: int,
|
||||
page: int,
|
||||
page_size: int,
|
||||
show_content: bool,
|
||||
):
|
||||
doc_data = self.client.chunk_list(
|
||||
space_name,
|
||||
ChunkQueryRequest(document_id=doc_id, page=page, page_size=page_size),
|
||||
)
|
||||
|
||||
doc_table = PrettyTable(
|
||||
[
|
||||
"Space Name",
|
||||
"Document ID",
|
||||
"Total Chunks",
|
||||
"Current Page",
|
||||
"Current Size",
|
||||
"Page Size",
|
||||
],
|
||||
title=f"Document {doc_id} in {space_name} description",
|
||||
)
|
||||
doc_table.add_row(
|
||||
[
|
||||
space_name,
|
||||
doc_id,
|
||||
doc_data["total"],
|
||||
page,
|
||||
len(doc_data["data"]),
|
||||
page_size,
|
||||
]
|
||||
)
|
||||
|
||||
table = PrettyTable(
|
||||
["Space Name", "Document ID", "Document Name", "Content", "Meta Data"],
|
||||
title=f"chunks of document id {doc_id} in space {space_name}",
|
||||
)
|
||||
for chunk in doc_data["data"]:
|
||||
table.add_row(
|
||||
[
|
||||
space_name,
|
||||
doc_id,
|
||||
chunk.get("doc_name"),
|
||||
chunk.get("content") if show_content else "[Hidden]",
|
||||
chunk.get("meta_info"),
|
||||
]
|
||||
)
|
||||
if self.out_format == "text":
|
||||
self.print_table(doc_table)
|
||||
print("")
|
||||
self.print_table(table)
|
||||
|
||||
|
||||
def knowledge_list(
|
||||
api_address: str,
|
||||
space_name: str,
|
||||
page: int,
|
||||
page_size: int,
|
||||
doc_id: int,
|
||||
show_content: bool,
|
||||
out_format: str,
|
||||
):
|
||||
visualizer = _KnowledgeVisualizer(api_address, out_format)
|
||||
if not space_name:
|
||||
visualizer.list_spaces()
|
||||
elif not doc_id:
|
||||
visualizer.list_documents(space_name, page, page_size)
|
||||
else:
|
||||
visualizer.list_chunks(space_name, doc_id, page, page_size, show_content)
|
||||
|
||||
|
||||
def knowledge_delete(
|
||||
api_address: str, space_name: str, doc_name: str, confirm: bool = False
|
||||
):
|
||||
client = KnowledgeApiClient(api_address)
|
||||
space = KnowledgeSpaceRequest()
|
||||
space.name = space_name
|
||||
space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
|
||||
if not space_list:
|
||||
raise Exception(f"No knowledge space name {space_name}")
|
||||
|
||||
if not doc_name:
|
||||
if not confirm:
|
||||
# Confirm by user
|
||||
user_input = (
|
||||
input(
|
||||
f"Are you sure you want to delete the whole knowledge space {space_name}? Type 'yes' to confirm: "
|
||||
)
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
if user_input != "yes":
|
||||
logger.warn("Delete operation cancelled.")
|
||||
return
|
||||
client.space_delete(space)
|
||||
logger.info("Delete the whole knowledge space successfully!")
|
||||
else:
|
||||
if not confirm:
|
||||
# Confirm by user
|
||||
user_input = (
|
||||
input(
|
||||
f"Are you sure you want to delete the doucment {doc_name} in knowledge space {space_name}? Type 'yes' to confirm: "
|
||||
)
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
if user_input != "yes":
|
||||
logger.warn("Delete operation cancelled.")
|
||||
return
|
||||
client.document_delete(space_name, KnowledgeDocumentRequest(doc_name=doc_name))
|
||||
logger.info(
|
||||
f"Delete the doucment {doc_name} in knowledge space {space_name} successfully!"
|
||||
)
|
272
dbgpt/app/knowledge/api.py
Normal file
272
dbgpt/app/knowledge/api.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile, Form
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import (
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
)
|
||||
from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
|
||||
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
|
||||
from dbgpt.app.knowledge.service import KnowledgeService
|
||||
from dbgpt.app.knowledge.request.request import (
|
||||
KnowledgeQueryRequest,
|
||||
KnowledgeQueryResponse,
|
||||
KnowledgeDocumentRequest,
|
||||
DocumentSyncRequest,
|
||||
ChunkQueryRequest,
|
||||
DocumentQueryRequest,
|
||||
SpaceArgumentRequest,
|
||||
EntityExtractRequest,
|
||||
DocumentSummaryRequest,
|
||||
)
|
||||
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
from dbgpt.util.tracer import root_tracer, SpanType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
knowledge_space_service = KnowledgeService()
|
||||
|
||||
|
||||
@router.post("/knowledge/space/add")
|
||||
def space_add(request: KnowledgeSpaceRequest):
|
||||
print(f"/space/add params: {request}")
|
||||
try:
|
||||
knowledge_space_service.create_knowledge_space(request)
|
||||
return Result.succ([])
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"space add error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/space/list")
|
||||
def space_list(request: KnowledgeSpaceRequest):
|
||||
print(f"/space/list params:")
|
||||
try:
|
||||
return Result.succ(knowledge_space_service.get_knowledge_space(request))
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"space list error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/space/delete")
|
||||
def space_delete(request: KnowledgeSpaceRequest):
|
||||
print(f"/space/delete params:")
|
||||
try:
|
||||
return Result.succ(knowledge_space_service.delete_space(request.name))
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"space list error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/arguments")
|
||||
def arguments(space_name: str):
|
||||
print(f"/knowledge/space/arguments params:")
|
||||
try:
|
||||
return Result.succ(knowledge_space_service.arguments(space_name))
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"space list error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/argument/save")
|
||||
def arguments_save(space_name: str, argument_request: SpaceArgumentRequest):
|
||||
print(f"/knowledge/space/argument/save params:")
|
||||
try:
|
||||
return Result.succ(
|
||||
knowledge_space_service.argument_save(space_name, argument_request)
|
||||
)
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"space list error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/add")
|
||||
def document_add(space_name: str, request: KnowledgeDocumentRequest):
|
||||
print(f"/document/add params: {space_name}, {request}")
|
||||
try:
|
||||
return Result.succ(
|
||||
knowledge_space_service.create_knowledge_document(
|
||||
space=space_name, request=request
|
||||
)
|
||||
)
|
||||
# return Result.succ([])
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document add error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/list")
|
||||
def document_list(space_name: str, query_request: DocumentQueryRequest):
|
||||
print(f"/document/list params: {space_name}, {query_request}")
|
||||
try:
|
||||
return Result.succ(
|
||||
knowledge_space_service.get_knowledge_documents(space_name, query_request)
|
||||
)
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document list error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/delete")
|
||||
def document_delete(space_name: str, query_request: DocumentQueryRequest):
|
||||
print(f"/document/list params: {space_name}, {query_request}")
|
||||
try:
|
||||
return Result.succ(
|
||||
knowledge_space_service.delete_document(space_name, query_request.doc_name)
|
||||
)
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document list error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/upload")
|
||||
async def document_upload(
|
||||
space_name: str,
|
||||
doc_name: str = Form(...),
|
||||
doc_type: str = Form(...),
|
||||
doc_file: UploadFile = File(...),
|
||||
):
|
||||
print(f"/document/upload params: {space_name}")
|
||||
try:
|
||||
if doc_file:
|
||||
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)):
|
||||
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name))
|
||||
# We can not move temp file in windows system when we open file in context of `with`
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(
|
||||
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)
|
||||
)
|
||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||
tmp.write(await doc_file.read())
|
||||
shutil.move(
|
||||
tmp_path,
|
||||
os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
|
||||
)
|
||||
request = KnowledgeDocumentRequest()
|
||||
request.doc_name = doc_name
|
||||
request.doc_type = doc_type
|
||||
request.content = os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
|
||||
)
|
||||
space_res = knowledge_space_service.get_knowledge_space(
|
||||
KnowledgeSpaceRequest(name=space_name)
|
||||
)
|
||||
if len(space_res) == 0:
|
||||
# create default space
|
||||
if "default" != space_name:
|
||||
raise Exception(f"you have not create your knowledge space.")
|
||||
knowledge_space_service.create_knowledge_space(
|
||||
KnowledgeSpaceRequest(
|
||||
name=space_name,
|
||||
desc="first db-gpt rag application",
|
||||
owner="dbgpt",
|
||||
)
|
||||
)
|
||||
return Result.succ(
|
||||
knowledge_space_service.create_knowledge_document(
|
||||
space=space_name, request=request
|
||||
)
|
||||
)
|
||||
return Result.failed(code="E000X", msg=f"doc_file is None")
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document add error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/sync")
|
||||
def document_sync(space_name: str, request: DocumentSyncRequest):
|
||||
logger.info(f"Received params: {space_name}, {request}")
|
||||
try:
|
||||
knowledge_space_service.sync_knowledge_document(
|
||||
space_name=space_name, sync_request=request
|
||||
)
|
||||
return Result.succ([])
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document sync error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/chunk/list")
|
||||
def document_list(space_name: str, query_request: ChunkQueryRequest):
|
||||
print(f"/document/list params: {space_name}, {query_request}")
|
||||
try:
|
||||
return Result.succ(knowledge_space_service.get_document_chunks(query_request))
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document chunk list error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{vector_name}/query")
|
||||
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
|
||||
print(f"Received params: {space_name}, {query_request}")
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
client = EmbeddingEngine(
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config={"vector_store_name": space_name},
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
docs = client.similar_search(query_request.query, query_request.top_k)
|
||||
res = [
|
||||
KnowledgeQueryResponse(text=d.page_content, source=d.metadata["source"])
|
||||
for d in docs
|
||||
]
|
||||
return {"response": res}
|
||||
|
||||
|
||||
@router.post("/knowledge/document/summary")
|
||||
async def document_summary(request: DocumentSummaryRequest):
|
||||
print(f"/document/summary params: {request}")
|
||||
try:
|
||||
with root_tracer.start_span(
|
||||
"get_chat_instance", span_type=SpanType.CHAT, metadata=request
|
||||
):
|
||||
chat = await knowledge_space_service.document_summary(request=request)
|
||||
headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
if not chat.prompt_template.stream_out:
|
||||
return StreamingResponse(
|
||||
no_stream_generator(chat),
|
||||
headers=headers,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
stream_generator(chat, False, request.model_name),
|
||||
headers=headers,
|
||||
media_type="text/plain",
|
||||
)
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"document summary error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/entity/extract")
|
||||
async def entity_extract(request: EntityExtractRequest):
|
||||
logger.info(f"Received params: {request}")
|
||||
try:
|
||||
from dbgpt.app.scene import ChatScene
|
||||
from dbgpt._private.chat_util import llm_chat_response_nostream
|
||||
import uuid
|
||||
|
||||
chat_param = {
|
||||
"chat_session_id": uuid.uuid1(),
|
||||
"current_user_input": request.text,
|
||||
"select_param": "entity",
|
||||
"model_name": request.model_name,
|
||||
}
|
||||
|
||||
res = await llm_chat_response_nostream(
|
||||
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
|
||||
)
|
||||
return Result.succ(res)
|
||||
except Exception as e:
|
||||
return Result.failed(code="E000X", msg=f"entity extract error {e}")
|
141
dbgpt/app/knowledge/chunk_db.py
Normal file
141
dbgpt/app/knowledge/chunk_db.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class DocumentChunkEntity(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True)
|
||||
document_id = Column(Integer)
|
||||
doc_name = Column(String(100))
|
||||
doc_type = Column(String(100))
|
||||
content = Column(Text)
|
||||
meta_info = Column(String(500))
|
||||
gmt_created = Column(DateTime)
|
||||
gmt_modified = Column(DateTime)
|
||||
|
||||
def __repr__(self):
|
||||
return f"DocumentChunkEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', document_id='{self.document_id}', content='{self.content}', meta_info='{self.meta_info}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
|
||||
|
||||
class DocumentChunkDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_documents_chunks(self, documents: List):
|
||||
session = self.get_session()
|
||||
docs = [
|
||||
DocumentChunkEntity(
|
||||
doc_name=document.doc_name,
|
||||
doc_type=document.doc_type,
|
||||
document_id=document.document_id,
|
||||
content=document.content or "",
|
||||
meta_info=document.meta_info or "",
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
for document in documents
|
||||
]
|
||||
session.add_all(docs)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_document_chunks(
|
||||
self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None
|
||||
):
|
||||
session = self.get_session()
|
||||
document_chunks = session.query(DocumentChunkEntity)
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
if query.document_id is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.document_id == query.document_id
|
||||
)
|
||||
if query.doc_type is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.doc_type == query.doc_type
|
||||
)
|
||||
if query.content is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.content == query.content
|
||||
)
|
||||
if query.doc_name is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.doc_name == query.doc_name
|
||||
)
|
||||
if query.meta_info is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.meta_info == query.meta_info
|
||||
)
|
||||
if document_ids is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.document_id.in_(document_ids)
|
||||
)
|
||||
|
||||
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.asc())
|
||||
document_chunks = document_chunks.offset((page - 1) * page_size).limit(
|
||||
page_size
|
||||
)
|
||||
result = document_chunks.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def get_document_chunks_count(self, query: DocumentChunkEntity):
|
||||
session = self.get_session()
|
||||
document_chunks = session.query(func.count(DocumentChunkEntity.id))
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
if query.document_id is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.document_id == query.document_id
|
||||
)
|
||||
if query.doc_type is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.doc_type == query.doc_type
|
||||
)
|
||||
if query.doc_name is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.doc_name == query.doc_name
|
||||
)
|
||||
if query.meta_info is not None:
|
||||
document_chunks = document_chunks.filter(
|
||||
DocumentChunkEntity.meta_info == query.meta_info
|
||||
)
|
||||
count = document_chunks.scalar()
|
||||
session.close()
|
||||
return count
|
||||
|
||||
def delete(self, document_id: int):
|
||||
session = self.get_session()
|
||||
if document_id is None:
|
||||
raise Exception("document_id is None")
|
||||
query = DocumentChunkEntity(document_id=document_id)
|
||||
knowledge_documents = session.query(DocumentChunkEntity)
|
||||
if query.document_id is not None:
|
||||
chunks = knowledge_documents.filter(
|
||||
DocumentChunkEntity.document_id == query.document_id
|
||||
)
|
||||
chunks.delete()
|
||||
session.commit()
|
||||
session.close()
|
214
dbgpt/app/knowledge/document_db.py
Normal file
214
dbgpt/app/knowledge/document_db.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class KnowledgeDocumentEntity(Base):
|
||||
__tablename__ = "knowledge_document"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True)
|
||||
doc_name = Column(String(100))
|
||||
doc_type = Column(String(100))
|
||||
space = Column(String(100))
|
||||
chunk_size = Column(Integer)
|
||||
status = Column(String(100))
|
||||
last_sync = Column(DateTime)
|
||||
content = Column(Text)
|
||||
result = Column(Text)
|
||||
vector_ids = Column(Text)
|
||||
summary = Column(Text)
|
||||
gmt_created = Column(DateTime)
|
||||
gmt_modified = Column(DateTime)
|
||||
|
||||
def __repr__(self):
|
||||
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', summary='{self.summary}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
|
||||
|
||||
class KnowledgeDocumentDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
knowledge_document = KnowledgeDocumentEntity(
|
||||
doc_name=document.doc_name,
|
||||
doc_type=document.doc_type,
|
||||
space=document.space,
|
||||
chunk_size=0.0,
|
||||
status=document.status,
|
||||
last_sync=document.last_sync,
|
||||
content=document.content or "",
|
||||
result=document.result or "",
|
||||
vector_ids=document.vector_ids,
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
session.add(knowledge_document)
|
||||
session.commit()
|
||||
doc_id = knowledge_document.id
|
||||
session.close()
|
||||
return doc_id
|
||||
|
||||
def get_knowledge_documents(self, query, page=1, page_size=20):
|
||||
session = self.get_session()
|
||||
print(f"current session:{session}")
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.id == query.id
|
||||
)
|
||||
if query.doc_name is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_name == query.doc_name
|
||||
)
|
||||
if query.doc_type is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_type == query.doc_type
|
||||
)
|
||||
if query.space is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.space == query.space
|
||||
)
|
||||
if query.status is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.status == query.status
|
||||
)
|
||||
|
||||
knowledge_documents = knowledge_documents.order_by(
|
||||
KnowledgeDocumentEntity.id.desc()
|
||||
)
|
||||
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
|
||||
page_size
|
||||
)
|
||||
result = knowledge_documents.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def get_documents(self, query):
|
||||
session = self.get_session()
|
||||
print(f"current session:{session}")
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.id == query.id
|
||||
)
|
||||
if query.doc_name is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_name == query.doc_name
|
||||
)
|
||||
if query.doc_type is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_type == query.doc_type
|
||||
)
|
||||
if query.space is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.space == query.space
|
||||
)
|
||||
if query.status is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.status == query.status
|
||||
)
|
||||
|
||||
knowledge_documents = knowledge_documents.order_by(
|
||||
KnowledgeDocumentEntity.id.desc()
|
||||
)
|
||||
result = knowledge_documents.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def get_knowledge_documents_count_bulk(self, space_names):
|
||||
session = self.get_session()
|
||||
"""
|
||||
Perform a batch query to count the number of documents for each knowledge space.
|
||||
|
||||
Args:
|
||||
space_names: A list of knowledge space names to query for document counts.
|
||||
session: A SQLAlchemy session object.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each space name to its document count.
|
||||
"""
|
||||
counts_query = (
|
||||
session.query(
|
||||
KnowledgeDocumentEntity.space,
|
||||
func.count(KnowledgeDocumentEntity.id).label("document_count"),
|
||||
)
|
||||
.filter(KnowledgeDocumentEntity.space.in_(space_names))
|
||||
.group_by(KnowledgeDocumentEntity.space)
|
||||
)
|
||||
|
||||
results = counts_query.all()
|
||||
docs_count = {result.space: result.document_count for result in results}
|
||||
return docs_count
|
||||
|
||||
def get_knowledge_documents_count(self, query):
|
||||
session = self.get_session()
|
||||
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.id == query.id
|
||||
)
|
||||
if query.doc_name is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_name == query.doc_name
|
||||
)
|
||||
if query.doc_type is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_type == query.doc_type
|
||||
)
|
||||
if query.space is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.space == query.space
|
||||
)
|
||||
if query.status is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.status == query.status
|
||||
)
|
||||
count = knowledge_documents.scalar()
|
||||
session.close()
|
||||
return count
|
||||
|
||||
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
updated_space = session.merge(document)
|
||||
session.commit()
|
||||
return updated_space.id
|
||||
|
||||
#
|
||||
def delete(self, query: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.id == query.id
|
||||
)
|
||||
if query.doc_name is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.doc_name == query.doc_name
|
||||
)
|
||||
if query.space is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
KnowledgeDocumentEntity.space == query.space
|
||||
)
|
||||
knowledge_documents.delete()
|
||||
session.commit()
|
||||
session.close()
|
0
dbgpt/app/knowledge/request/__init__.py
Normal file
0
dbgpt/app/knowledge/request/__init__.py
Normal file
124
dbgpt/app/knowledge/request/request.py
Normal file
124
dbgpt/app/knowledge/request/request.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from fastapi import UploadFile
|
||||
|
||||
|
||||
class KnowledgeQueryRequest(BaseModel):
|
||||
"""query: knowledge query"""
|
||||
|
||||
query: str
|
||||
"""top_k: return topK documents"""
|
||||
top_k: int
|
||||
|
||||
|
||||
class KnowledgeSpaceRequest(BaseModel):
|
||||
"""name: knowledge space name"""
|
||||
|
||||
name: str = None
|
||||
"""vector_type: vector type"""
|
||||
vector_type: str = None
|
||||
"""desc: description"""
|
||||
desc: str = None
|
||||
"""owner: owner"""
|
||||
owner: str = None
|
||||
|
||||
|
||||
class KnowledgeDocumentRequest(BaseModel):
|
||||
"""doc_name: doc path"""
|
||||
|
||||
doc_name: str = None
|
||||
"""doc_type: doc type"""
|
||||
doc_type: str = None
|
||||
"""content: content"""
|
||||
content: str = None
|
||||
"""content: content"""
|
||||
source: str = None
|
||||
|
||||
"""text_chunk_size: text_chunk_size"""
|
||||
# text_chunk_size: int
|
||||
|
||||
|
||||
class DocumentQueryRequest(BaseModel):
|
||||
"""doc_name: doc path"""
|
||||
|
||||
doc_name: str = None
|
||||
"""doc_type: doc type"""
|
||||
doc_type: str = None
|
||||
"""status: status"""
|
||||
status: str = None
|
||||
"""page: page"""
|
||||
page: int = 1
|
||||
"""page_size: page size"""
|
||||
page_size: int = 20
|
||||
|
||||
|
||||
class DocumentSyncRequest(BaseModel):
|
||||
"""Sync request"""
|
||||
|
||||
"""doc_ids: doc ids"""
|
||||
doc_ids: List
|
||||
|
||||
model_name: Optional[str] = None
|
||||
|
||||
"""Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter.
|
||||
Preseparator are not included in the vectorized text.
|
||||
"""
|
||||
pre_separator: Optional[str] = None
|
||||
|
||||
"""Custom separators"""
|
||||
separators: Optional[List[str]] = None
|
||||
|
||||
"""Custom chunk size"""
|
||||
chunk_size: Optional[int] = None
|
||||
|
||||
"""Custom chunk overlap"""
|
||||
chunk_overlap: Optional[int] = None
|
||||
|
||||
|
||||
class ChunkQueryRequest(BaseModel):
|
||||
"""id: id"""
|
||||
|
||||
id: int = None
|
||||
"""document_id: doc id"""
|
||||
document_id: int = None
|
||||
"""doc_name: doc path"""
|
||||
doc_name: str = None
|
||||
"""doc_type: doc type"""
|
||||
doc_type: str = None
|
||||
"""page: page"""
|
||||
page: int = 1
|
||||
"""page_size: page size"""
|
||||
page_size: int = 20
|
||||
|
||||
|
||||
class KnowledgeQueryResponse:
|
||||
"""source: knowledge reference source"""
|
||||
|
||||
source: str
|
||||
"""score: knowledge vector query similarity score"""
|
||||
score: float = 0.0
|
||||
"""text: raw text info"""
|
||||
text: str
|
||||
|
||||
|
||||
class SpaceArgumentRequest(BaseModel):
|
||||
"""argument: argument"""
|
||||
|
||||
argument: str
|
||||
|
||||
|
||||
class DocumentSummaryRequest(BaseModel):
|
||||
"""Sync request"""
|
||||
|
||||
"""doc_ids: doc ids"""
|
||||
doc_id: int
|
||||
model_name: str
|
||||
conv_uid: str
|
||||
|
||||
|
||||
class EntityExtractRequest(BaseModel):
|
||||
"""argument: argument"""
|
||||
|
||||
text: str
|
||||
model_name: str
|
44
dbgpt/app/knowledge/request/response.py
Normal file
44
dbgpt/app/knowledge/request/response.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
|
||||
class ChunkQueryResponse(BaseModel):
|
||||
"""data: data"""
|
||||
|
||||
data: List = None
|
||||
"""summary: document summary"""
|
||||
summary: str = None
|
||||
"""total: total size"""
|
||||
total: int = None
|
||||
"""page: current page"""
|
||||
page: int = None
|
||||
|
||||
|
||||
class DocumentQueryResponse(BaseModel):
|
||||
"""data: data"""
|
||||
|
||||
data: List = None
|
||||
"""total: total size"""
|
||||
total: int = None
|
||||
"""page: current page"""
|
||||
page: int = None
|
||||
|
||||
|
||||
class SpaceQueryResponse(BaseModel):
|
||||
"""data: data"""
|
||||
|
||||
id: int = None
|
||||
name: str = None
|
||||
"""vector_type: vector type"""
|
||||
vector_type: str = None
|
||||
"""desc: description"""
|
||||
desc: str = None
|
||||
"""context: context"""
|
||||
context: str = None
|
||||
"""owner: owner"""
|
||||
owner: str = None
|
||||
gmt_created: str = None
|
||||
gmt_modified: str = None
|
||||
"""doc_count: doc_count"""
|
||||
docs: int = None
|
641
dbgpt/app/knowledge/service.py
Normal file
641
dbgpt/app/knowledge/service.py
Normal file
@@ -0,0 +1,641 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import (
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
)
|
||||
from dbgpt.component import ComponentType
|
||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
|
||||
from dbgpt.app.knowledge.chunk_db import (
|
||||
DocumentChunkEntity,
|
||||
DocumentChunkDao,
|
||||
)
|
||||
from dbgpt.app.knowledge.document_db import (
|
||||
KnowledgeDocumentDao,
|
||||
KnowledgeDocumentEntity,
|
||||
)
|
||||
from dbgpt.app.knowledge.space_db import (
|
||||
KnowledgeSpaceDao,
|
||||
KnowledgeSpaceEntity,
|
||||
)
|
||||
from dbgpt.app.knowledge.request.request import (
|
||||
KnowledgeSpaceRequest,
|
||||
KnowledgeDocumentRequest,
|
||||
DocumentQueryRequest,
|
||||
ChunkQueryRequest,
|
||||
SpaceArgumentRequest,
|
||||
DocumentSyncRequest,
|
||||
DocumentSummaryRequest,
|
||||
)
|
||||
from enum import Enum
|
||||
|
||||
from dbgpt.app.knowledge.request.response import (
|
||||
ChunkQueryResponse,
|
||||
DocumentQueryResponse,
|
||||
SpaceQueryResponse,
|
||||
)
|
||||
|
||||
knowledge_space_dao = KnowledgeSpaceDao()
|
||||
knowledge_document_dao = KnowledgeDocumentDao()
|
||||
document_chunk_dao = DocumentChunkDao()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class SyncStatus(Enum):
|
||||
TODO = "TODO"
|
||||
FAILED = "FAILED"
|
||||
RUNNING = "RUNNING"
|
||||
FINISHED = "FINISHED"
|
||||
|
||||
|
||||
# default summary max iteration call with llm.
|
||||
DEFAULT_SUMMARY_MAX_ITERATION = 5
|
||||
# default summary concurrency call with llm.
|
||||
DEFAULT_SUMMARY_CONCURRENCY_LIMIT = 3
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
"""KnowledgeService
|
||||
Knowledge Management Service:
|
||||
-knowledge_space management
|
||||
-knowledge_document management
|
||||
-embedding management
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def create_knowledge_space(self, request: KnowledgeSpaceRequest):
|
||||
"""create knowledge space
|
||||
Args:
|
||||
- request: KnowledgeSpaceRequest
|
||||
"""
|
||||
query = KnowledgeSpaceEntity(
|
||||
name=request.name,
|
||||
)
|
||||
spaces = knowledge_space_dao.get_knowledge_space(query)
|
||||
if len(spaces) > 0:
|
||||
raise Exception(f"space name:{request.name} have already named")
|
||||
knowledge_space_dao.create_knowledge_space(request)
|
||||
return True
|
||||
|
||||
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
|
||||
"""create knowledge document
|
||||
Args:
|
||||
- request: KnowledgeDocumentRequest
|
||||
"""
|
||||
query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
|
||||
documents = knowledge_document_dao.get_knowledge_documents(query)
|
||||
if len(documents) > 0:
|
||||
raise Exception(f"document name:{request.doc_name} have already named")
|
||||
document = KnowledgeDocumentEntity(
|
||||
doc_name=request.doc_name,
|
||||
doc_type=request.doc_type,
|
||||
space=space,
|
||||
chunk_size=0,
|
||||
status=SyncStatus.TODO.name,
|
||||
last_sync=datetime.now(),
|
||||
content=request.content,
|
||||
result="",
|
||||
)
|
||||
return knowledge_document_dao.create_knowledge_document(document)
|
||||
|
||||
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
|
||||
"""get knowledge space
|
||||
Args:
|
||||
- request: KnowledgeSpaceRequest
|
||||
"""
|
||||
query = KnowledgeSpaceEntity(
|
||||
name=request.name, vector_type=request.vector_type, owner=request.owner
|
||||
)
|
||||
spaces = knowledge_space_dao.get_knowledge_space(query)
|
||||
space_names = [space.name for space in spaces]
|
||||
docs_count = knowledge_document_dao.get_knowledge_documents_count_bulk(
|
||||
space_names
|
||||
)
|
||||
responses = []
|
||||
for space in spaces:
|
||||
res = SpaceQueryResponse()
|
||||
res.id = space.id
|
||||
res.name = space.name
|
||||
res.vector_type = space.vector_type
|
||||
res.desc = space.desc
|
||||
res.owner = space.owner
|
||||
res.gmt_created = space.gmt_created
|
||||
res.gmt_modified = space.gmt_modified
|
||||
res.context = space.context
|
||||
res.docs = docs_count.get(space.name, 0)
|
||||
responses.append(res)
|
||||
return responses
|
||||
|
||||
def arguments(self, space_name):
|
||||
"""show knowledge space arguments
|
||||
Args:
|
||||
- space_name: Knowledge Space Name
|
||||
"""
|
||||
query = KnowledgeSpaceEntity(name=space_name)
|
||||
spaces = knowledge_space_dao.get_knowledge_space(query)
|
||||
if len(spaces) != 1:
|
||||
raise Exception(f"there are no or more than one space called {space_name}")
|
||||
space = spaces[0]
|
||||
if space.context is None:
|
||||
context = self._build_default_context()
|
||||
else:
|
||||
context = space.context
|
||||
return json.loads(context)
|
||||
|
||||
def argument_save(self, space_name, argument_request: SpaceArgumentRequest):
|
||||
"""save argument
|
||||
Args:
|
||||
- space_name: Knowledge Space Name
|
||||
- argument_request: SpaceArgumentRequest
|
||||
"""
|
||||
query = KnowledgeSpaceEntity(name=space_name)
|
||||
spaces = knowledge_space_dao.get_knowledge_space(query)
|
||||
if len(spaces) != 1:
|
||||
raise Exception(f"there are no or more than one space called {space_name}")
|
||||
space = spaces[0]
|
||||
space.context = argument_request.argument
|
||||
return knowledge_space_dao.update_knowledge_space(space)
|
||||
|
||||
def get_knowledge_documents(self, space, request: DocumentQueryRequest):
|
||||
"""get knowledge documents
|
||||
Args:
|
||||
- space: Knowledge Space Name
|
||||
- request: DocumentQueryRequest
|
||||
"""
|
||||
query = KnowledgeDocumentEntity(
|
||||
doc_name=request.doc_name,
|
||||
doc_type=request.doc_type,
|
||||
space=space,
|
||||
status=request.status,
|
||||
)
|
||||
res = DocumentQueryResponse()
|
||||
res.data = knowledge_document_dao.get_knowledge_documents(
|
||||
query, page=request.page, page_size=request.page_size
|
||||
)
|
||||
res.total = knowledge_document_dao.get_knowledge_documents_count(query)
|
||||
res.page = request.page
|
||||
return res
|
||||
|
||||
def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
|
||||
"""sync knowledge document chunk into vector store
|
||||
Args:
|
||||
- space: Knowledge Space Name
|
||||
- sync_request: DocumentSyncRequest
|
||||
"""
|
||||
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.embedding_engine.pre_text_splitter import PreTextSplitter
|
||||
from langchain.text_splitter import (
|
||||
RecursiveCharacterTextSplitter,
|
||||
SpacyTextSplitter,
|
||||
)
|
||||
|
||||
# import langchain is very very slow!!!
|
||||
|
||||
doc_ids = sync_request.doc_ids
|
||||
self.model_name = sync_request.model_name or CFG.LLM_MODEL
|
||||
for doc_id in doc_ids:
|
||||
query = KnowledgeDocumentEntity(
|
||||
id=doc_id,
|
||||
space=space_name,
|
||||
)
|
||||
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
|
||||
if (
|
||||
doc.status == SyncStatus.RUNNING.name
|
||||
or doc.status == SyncStatus.FINISHED.name
|
||||
):
|
||||
raise Exception(
|
||||
f" doc:{doc.doc_name} status is {doc.status}, can not sync"
|
||||
)
|
||||
|
||||
space_context = self.get_space_context(space_name)
|
||||
chunk_size = (
|
||||
CFG.KNOWLEDGE_CHUNK_SIZE
|
||||
if space_context is None
|
||||
else int(space_context["embedding"]["chunk_size"])
|
||||
)
|
||||
chunk_overlap = (
|
||||
CFG.KNOWLEDGE_CHUNK_OVERLAP
|
||||
if space_context is None
|
||||
else int(space_context["embedding"]["chunk_overlap"])
|
||||
)
|
||||
if sync_request.chunk_size:
|
||||
chunk_size = sync_request.chunk_size
|
||||
if sync_request.chunk_overlap:
|
||||
chunk_overlap = sync_request.chunk_overlap
|
||||
separators = sync_request.separators or None
|
||||
if CFG.LANGUAGE == "en":
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
separators=separators,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
)
|
||||
else:
|
||||
if separators and len(separators) > 1:
|
||||
raise ValueError(
|
||||
"SpacyTextSplitter do not support multiple separators"
|
||||
)
|
||||
try:
|
||||
separator = "\n\n" if not separators else separators[0]
|
||||
text_splitter = SpacyTextSplitter(
|
||||
separator=separator,
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
except Exception:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
separators=separators,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
if sync_request.pre_separator:
|
||||
logger.info(f"Use preseparator, {sync_request.pre_separator}")
|
||||
text_splitter = PreTextSplitter(
|
||||
pre_separator=sync_request.pre_separator,
|
||||
text_splitter_impl=text_splitter,
|
||||
)
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
client = EmbeddingEngine(
|
||||
knowledge_source=doc.content,
|
||||
knowledge_type=doc.doc_type.upper(),
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config={
|
||||
"vector_store_name": space_name,
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
},
|
||||
text_splitter=text_splitter,
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
chunk_docs = client.read()
|
||||
# update document status
|
||||
doc.status = SyncStatus.RUNNING.name
|
||||
doc.chunk_size = len(chunk_docs)
|
||||
doc.gmt_modified = datetime.now()
|
||||
knowledge_document_dao.update_knowledge_document(doc)
|
||||
executor = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
|
||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||
# save chunk details
|
||||
chunk_entities = [
|
||||
DocumentChunkEntity(
|
||||
doc_name=doc.doc_name,
|
||||
doc_type=doc.doc_type,
|
||||
document_id=doc.id,
|
||||
content=chunk_doc.page_content,
|
||||
meta_info=str(chunk_doc.metadata),
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
for chunk_doc in chunk_docs
|
||||
]
|
||||
document_chunk_dao.create_documents_chunks(chunk_entities)
|
||||
|
||||
return doc.id
|
||||
|
||||
async def document_summary(self, request: DocumentSummaryRequest):
|
||||
"""get document summary
|
||||
Args:
|
||||
- request: DocumentSummaryRequest
|
||||
"""
|
||||
doc_query = KnowledgeDocumentEntity(id=request.doc_id)
|
||||
documents = knowledge_document_dao.get_documents(doc_query)
|
||||
if len(documents) != 1:
|
||||
raise Exception(f"can not found document for {request.doc_id}")
|
||||
document = documents[0]
|
||||
query = DocumentChunkEntity(
|
||||
document_id=request.doc_id,
|
||||
)
|
||||
chunks = document_chunk_dao.get_document_chunks(query, page=1, page_size=100)
|
||||
if len(chunks) == 0:
|
||||
raise Exception(f"can not found chunks for {request.doc_id}")
|
||||
from langchain.schema import Document
|
||||
|
||||
chunk_docs = [Document(page_content=chunk.content) for chunk in chunks]
|
||||
return await self.async_document_summary(
|
||||
model_name=request.model_name,
|
||||
chunk_docs=chunk_docs,
|
||||
doc=document,
|
||||
conn_uid=request.conv_uid,
|
||||
)
|
||||
|
||||
def update_knowledge_space(
|
||||
self, space_id: int, space_request: KnowledgeSpaceRequest
|
||||
):
|
||||
"""update knowledge space
|
||||
Args:
|
||||
- space_id: space id
|
||||
- space_request: KnowledgeSpaceRequest
|
||||
"""
|
||||
knowledge_space_dao.update_knowledge_space(space_id, space_request)
|
||||
|
||||
def delete_space(self, space_name: str):
|
||||
"""delete knowledge space
|
||||
Args:
|
||||
- space_name: knowledge space name
|
||||
"""
|
||||
query = KnowledgeSpaceEntity(name=space_name)
|
||||
spaces = knowledge_space_dao.get_knowledge_space(query)
|
||||
if len(spaces) == 0:
|
||||
raise Exception(f"delete error, no space name:{space_name} in database")
|
||||
space = spaces[0]
|
||||
vector_config = {}
|
||||
vector_config["vector_store_name"] = space.name
|
||||
vector_config["vector_store_type"] = CFG.VECTOR_STORE_TYPE
|
||||
vector_config["chroma_persist_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
vector_client = VectorStoreConnector(
|
||||
vector_store_type=CFG.VECTOR_STORE_TYPE, ctx=vector_config
|
||||
)
|
||||
# delete vectors
|
||||
vector_client.delete_vector_name(space.name)
|
||||
document_query = KnowledgeDocumentEntity(space=space.name)
|
||||
# delete chunks
|
||||
documents = knowledge_document_dao.get_documents(document_query)
|
||||
for document in documents:
|
||||
document_chunk_dao.delete(document.id)
|
||||
# delete documents
|
||||
knowledge_document_dao.delete(document_query)
|
||||
# delete space
|
||||
return knowledge_space_dao.delete_knowledge_space(space)
|
||||
|
||||
def delete_document(self, space_name: str, doc_name: str):
|
||||
"""delete document
|
||||
Args:
|
||||
- space_name: knowledge space name
|
||||
- doc_name: doocument name
|
||||
"""
|
||||
document_query = KnowledgeDocumentEntity(doc_name=doc_name, space=space_name)
|
||||
documents = knowledge_document_dao.get_documents(document_query)
|
||||
if len(documents) != 1:
|
||||
raise Exception(f"there are no or more than one document called {doc_name}")
|
||||
vector_ids = documents[0].vector_ids
|
||||
if vector_ids is not None:
|
||||
vector_config = {}
|
||||
vector_config["vector_store_name"] = space_name
|
||||
vector_config["vector_store_type"] = CFG.VECTOR_STORE_TYPE
|
||||
vector_config["chroma_persist_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
vector_client = VectorStoreConnector(
|
||||
vector_store_type=CFG.VECTOR_STORE_TYPE, ctx=vector_config
|
||||
)
|
||||
# delete vector by ids
|
||||
vector_client.delete_by_ids(vector_ids)
|
||||
# delete chunks
|
||||
document_chunk_dao.delete(documents[0].id)
|
||||
# delete document
|
||||
return knowledge_document_dao.delete(document_query)
|
||||
|
||||
def get_document_chunks(self, request: ChunkQueryRequest):
|
||||
"""get document chunks
|
||||
Args:
|
||||
- request: ChunkQueryRequest
|
||||
"""
|
||||
query = DocumentChunkEntity(
|
||||
id=request.id,
|
||||
document_id=request.document_id,
|
||||
doc_name=request.doc_name,
|
||||
doc_type=request.doc_type,
|
||||
)
|
||||
document_query = KnowledgeDocumentEntity(id=request.document_id)
|
||||
documents = knowledge_document_dao.get_documents(document_query)
|
||||
|
||||
res = ChunkQueryResponse()
|
||||
res.data = document_chunk_dao.get_document_chunks(
|
||||
query, page=request.page, page_size=request.page_size
|
||||
)
|
||||
res.summary = documents[0].summary
|
||||
res.total = document_chunk_dao.get_document_chunks_count(query)
|
||||
res.page = request.page
|
||||
return res
|
||||
|
||||
def async_knowledge_graph(self, chunk_docs, doc):
|
||||
"""async document extract triplets and save into graph db
|
||||
Args:
|
||||
- chunk_docs: List[Document]
|
||||
- doc: KnowledgeDocumentEntity
|
||||
"""
|
||||
logger.info(
|
||||
f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
|
||||
)
|
||||
try:
|
||||
from dbgpt.rag.graph_engine.graph_factory import RAGGraphFactory
|
||||
|
||||
rag_engine = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
|
||||
).create()
|
||||
rag_engine.knowledge_graph(chunk_docs)
|
||||
doc.status = SyncStatus.FINISHED.name
|
||||
doc.result = "document build graph success"
|
||||
except Exception as e:
|
||||
doc.status = SyncStatus.FAILED.name
|
||||
doc.result = "document build graph failed" + str(e)
|
||||
logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
|
||||
return knowledge_document_dao.update_knowledge_document(doc)
|
||||
|
||||
async def async_document_summary(self, model_name, chunk_docs, doc, conn_uid):
|
||||
"""async document extract summary
|
||||
Args:
|
||||
- model_name: str
|
||||
- chunk_docs: List[Document]
|
||||
- doc: KnowledgeDocumentEntity
|
||||
"""
|
||||
texts = [doc.page_content for doc in chunk_docs]
|
||||
from dbgpt.util.prompt_util import PromptHelper
|
||||
|
||||
prompt_helper = PromptHelper()
|
||||
from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
|
||||
|
||||
texts = prompt_helper.repack(prompt_template=prompt.template, text_chunks=texts)
|
||||
logger.info(
|
||||
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
|
||||
)
|
||||
space_context = self.get_space_context(doc.space)
|
||||
if space_context and space_context.get("summary"):
|
||||
summary = await self._mapreduce_extract_summary(
|
||||
docs=texts,
|
||||
model_name=model_name,
|
||||
max_iteration=int(space_context["summary"]["max_iteration"]),
|
||||
concurrency_limit=int(space_context["summary"]["concurrency_limit"]),
|
||||
)
|
||||
else:
|
||||
summary = await self._mapreduce_extract_summary(
|
||||
docs=texts, model_name=model_name
|
||||
)
|
||||
return await self._llm_extract_summary(summary, conn_uid, model_name)
|
||||
|
||||
def async_doc_embedding(self, client, chunk_docs, doc):
|
||||
"""async document embedding into vector db
|
||||
Args:
|
||||
- client: EmbeddingEngine Client
|
||||
- chunk_docs: List[Document]
|
||||
- doc: KnowledgeDocumentEntity
|
||||
"""
|
||||
logger.info(
|
||||
f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
|
||||
)
|
||||
try:
|
||||
vector_ids = client.knowledge_embedding_batch(chunk_docs)
|
||||
doc.status = SyncStatus.FINISHED.name
|
||||
doc.result = "document embedding success"
|
||||
if vector_ids is not None:
|
||||
doc.vector_ids = ",".join(vector_ids)
|
||||
logger.info(f"async document embedding, success:{doc.doc_name}")
|
||||
except Exception as e:
|
||||
doc.status = SyncStatus.FAILED.name
|
||||
doc.result = "document embedding failed" + str(e)
|
||||
logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
|
||||
return knowledge_document_dao.update_knowledge_document(doc)
|
||||
|
||||
def _build_default_context(self):
|
||||
from dbgpt.app.scene.chat_knowledge.v1.prompt import (
|
||||
PROMPT_SCENE_DEFINE,
|
||||
_DEFAULT_TEMPLATE,
|
||||
)
|
||||
|
||||
context_template = {
|
||||
"embedding": {
|
||||
"topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||
"recall_score": CFG.KNOWLEDGE_SEARCH_RECALL_SCORE,
|
||||
"recall_type": "TopK",
|
||||
"model": EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
|
||||
"chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
|
||||
"chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
|
||||
},
|
||||
"prompt": {
|
||||
"max_token": 2000,
|
||||
"scene": PROMPT_SCENE_DEFINE,
|
||||
"template": _DEFAULT_TEMPLATE,
|
||||
},
|
||||
"summary": {
|
||||
"max_iteration": DEFAULT_SUMMARY_MAX_ITERATION,
|
||||
"concurrency_limit": DEFAULT_SUMMARY_CONCURRENCY_LIMIT,
|
||||
},
|
||||
}
|
||||
context_template_string = json.dumps(context_template, indent=4)
|
||||
return context_template_string
|
||||
|
||||
def get_space_context(self, space_name):
|
||||
"""get space contect
|
||||
Args:
|
||||
- space_name: space name
|
||||
"""
|
||||
request = KnowledgeSpaceRequest()
|
||||
request.name = space_name
|
||||
spaces = self.get_knowledge_space(request)
|
||||
if len(spaces) != 1:
|
||||
raise Exception(
|
||||
f"have not found {space_name} space or found more than one space called {space_name}"
|
||||
)
|
||||
space = spaces[0]
|
||||
if space.context is not None:
|
||||
return json.loads(spaces[0].context)
|
||||
return None
|
||||
|
||||
async def _llm_extract_summary(
|
||||
self, doc: str, conn_uid: str, model_name: str = None
|
||||
):
|
||||
"""Extract triplets from text by llm
|
||||
Args:
|
||||
doc: Document
|
||||
conn_uid: str,chat conversation id
|
||||
model_name: str, model name
|
||||
Returns:
|
||||
chat: BaseChat, refine summary chat.
|
||||
"""
|
||||
from dbgpt.app.scene import ChatScene
|
||||
|
||||
chat_param = {
|
||||
"chat_session_id": conn_uid,
|
||||
"current_user_input": "",
|
||||
"select_param": doc,
|
||||
"model_name": model_name,
|
||||
"model_cache_enable": False,
|
||||
}
|
||||
executor = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
from dbgpt.app.openapi.api_v1.api_v1 import CHAT_FACTORY
|
||||
|
||||
chat = await blocking_func_to_async(
|
||||
executor,
|
||||
CHAT_FACTORY.get_implementation,
|
||||
ChatScene.ExtractRefineSummary.value(),
|
||||
**{"chat_param": chat_param},
|
||||
)
|
||||
return chat
|
||||
|
||||
async def _mapreduce_extract_summary(
|
||||
self,
|
||||
docs,
|
||||
model_name: str = None,
|
||||
max_iteration: int = 5,
|
||||
concurrency_limit: int = 3,
|
||||
):
|
||||
"""Extract summary by mapreduce mode
|
||||
map -> multi async call llm to generate summary
|
||||
reduce -> merge the summaries by map process
|
||||
Args:
|
||||
docs:List[str]
|
||||
model_name:model name str
|
||||
max_iteration:max iteration will call llm to summary
|
||||
concurrency_limit:the max concurrency threads to call llm
|
||||
Returns:
|
||||
Document: refine summary context document.
|
||||
"""
|
||||
from dbgpt.app.scene import ChatScene
|
||||
from dbgpt._private.chat_util import llm_chat_response_nostream
|
||||
import uuid
|
||||
|
||||
tasks = []
|
||||
if len(docs) == 1:
|
||||
return docs[0]
|
||||
else:
|
||||
max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
|
||||
for doc in docs[0:max_iteration]:
|
||||
chat_param = {
|
||||
"chat_session_id": uuid.uuid1(),
|
||||
"current_user_input": "",
|
||||
"select_param": doc,
|
||||
"model_name": model_name,
|
||||
"model_cache_enable": True,
|
||||
}
|
||||
tasks.append(
|
||||
llm_chat_response_nostream(
|
||||
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
|
||||
)
|
||||
)
|
||||
from dbgpt._private.chat_util import run_async_tasks
|
||||
|
||||
summary_iters = await run_async_tasks(
|
||||
tasks=tasks, concurrency_limit=concurrency_limit
|
||||
)
|
||||
summary_iters = list(
|
||||
filter(
|
||||
lambda content: "LLMServer Generate Error" not in content,
|
||||
summary_iters,
|
||||
)
|
||||
)
|
||||
from dbgpt.util.prompt_util import PromptHelper
|
||||
from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
|
||||
|
||||
prompt_helper = PromptHelper()
|
||||
summary_iters = prompt_helper.repack(
|
||||
prompt_template=prompt.template, text_chunks=summary_iters
|
||||
)
|
||||
return await self._mapreduce_extract_summary(
|
||||
summary_iters, model_name, max_iteration, concurrency_limit
|
||||
)
|
111
dbgpt/app/knowledge/space_db.py
Normal file
111
dbgpt/app/knowledge/space_db.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class KnowledgeSpaceEntity(Base):
|
||||
__tablename__ = "knowledge_space"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
"mysql_collate": "utf8mb4_unicode_ci",
|
||||
}
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(100))
|
||||
vector_type = Column(String(100))
|
||||
desc = Column(String(100))
|
||||
owner = Column(String(100))
|
||||
context = Column(Text)
|
||||
gmt_created = Column(DateTime)
|
||||
gmt_modified = Column(DateTime)
|
||||
|
||||
def __repr__(self):
|
||||
return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}' context='{self.context}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
|
||||
|
||||
class KnowledgeSpaceDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
||||
session = self.get_session()
|
||||
knowledge_space = KnowledgeSpaceEntity(
|
||||
name=space.name,
|
||||
vector_type=CFG.VECTOR_STORE_TYPE,
|
||||
desc=space.desc,
|
||||
owner=space.owner,
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
session.add(knowledge_space)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
||||
if query.id is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.id == query.id
|
||||
)
|
||||
if query.name is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.name == query.name
|
||||
)
|
||||
if query.vector_type is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.vector_type == query.vector_type
|
||||
)
|
||||
if query.desc is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.desc == query.desc
|
||||
)
|
||||
if query.owner is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.owner == query.owner
|
||||
)
|
||||
if query.gmt_created is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.gmt_created == query.gmt_created
|
||||
)
|
||||
if query.gmt_modified is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
|
||||
)
|
||||
|
||||
knowledge_spaces = knowledge_spaces.order_by(
|
||||
KnowledgeSpaceEntity.gmt_created.desc()
|
||||
)
|
||||
result = knowledge_spaces.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
session.merge(space)
|
||||
session.commit()
|
||||
session.close()
|
||||
return True
|
||||
|
||||
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
if space:
|
||||
session.delete(space)
|
||||
session.commit()
|
||||
session.close()
|
Reference in New Issue
Block a user