refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>

This commit is contained in:
parent 90775aad50
commit 9ad70a2961

Makefile (2 changed lines)
@@ -17,7 +17,7 @@ setup: ## Set up the Python development environment
	$(VENV_BIN)/pip install -r requirements/lint-requirements.txt

testenv: setup ## Set up the Python test environment
	$(VENV_BIN)/pip install -e ".[simple_framework]"
	$(VENV_BIN)/pip install -e ".[default]"

.PHONY: fmt
fmt: setup ## Format Python code
@@ -30,7 +30,7 @@ def initialize_components(
    system_app.register_instance(controller)

    # Register global default RAGGraphFactory
    # from dbgpt.graph_engine.graph_factory import DefaultRAGGraphFactory
    # from dbgpt.graph.graph_factory import DefaultRAGGraphFactory

    # system_app.register(DefaultRAGGraphFactory)
@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from typing import Any, Type, TYPE_CHECKING
from dbgpt.component import ComponentType, SystemApp
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

if TYPE_CHECKING:
    from langchain.embeddings.base import Embeddings
@@ -14,10 +14,10 @@ from dbgpt.app.knowledge.request.request import (
    DocumentQueryRequest,
)

from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType
from dbgpt.app.knowledge.request.request import DocumentSyncRequest

from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.rag.knowledge.base import KnowledgeType

HTTP_HEADERS = {"Content-Type": "application/json"}
@@ -2,6 +2,7 @@ import os
import shutil
import tempfile
import logging
from typing import List

from fastapi import APIRouter, File, UploadFile, Form
@@ -13,10 +14,10 @@ from dbgpt.configs.model_config import (
from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator

from dbgpt.app.openapi.api_view_model import Result
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.app.knowledge.request.request import (
    KnowledgeQueryRequest,
    KnowledgeQueryResponse,
@@ -27,9 +28,14 @@ from dbgpt.app.knowledge.request.request import (
    SpaceArgumentRequest,
    EntityExtractRequest,
    DocumentSummaryRequest,
    KnowledgeSyncRequest,
)

from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.rag.knowledge.base import ChunkStrategy
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.tracer import root_tracer, SpanType

logger = logging.getLogger(__name__)
@@ -103,6 +109,39 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
        return Result.failed(code="E000X", msg=f"document add error {e}")


@router.get("/knowledge/document/chunkstrategies")
def chunk_strategies():
    """Get chunk strategies"""
    print(f"/document/chunkstrategies:")
    try:
        return Result.succ(
            [
                {
                    "strategy": strategy.name,
                    "name": strategy.value[2],
                    "description": strategy.value[3],
                    "parameters": strategy.value[1],
                    "suffix": [
                        knowledge.document_type().value
                        for knowledge in KnowledgeFactory.subclasses()
                        if strategy in knowledge.support_chunk_strategy()
                        and knowledge.document_type() is not None
                    ],
                    "type": set(
                        [
                            knowledge.type().value
                            for knowledge in KnowledgeFactory.subclasses()
                            if strategy in knowledge.support_chunk_strategy()
                        ]
                    ),
                }
                for strategy in ChunkStrategy
            ]
        )
    except Exception as e:
        return Result.failed(code="E000X", msg=f"chunk strategies error {e}")


@router.post("/knowledge/{space_name}/document/list")
def document_list(space_name: str, query_request: DocumentQueryRequest):
    print(f"/document/list params: {space_name}, {query_request}")
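The chunk_strategies endpoint added above exposes, for every ChunkStrategy member, its name, description, parameters, the file suffixes that support it, and the knowledge types it applies to. A minimal sketch of calling it from a client; the base URL and route prefix are assumptions (they depend on how this router is mounted), and the response is assumed to follow the usual Result envelope with the payload under "data":

import requests  # any HTTP client works; requests is an assumption, not part of this diff

BASE = "http://localhost:5000"  # hypothetical host/port

resp = requests.get(f"{BASE}/knowledge/document/chunkstrategies").json()
# Each entry mirrors the dict built in the endpoint above.
for item in resp.get("data", []):
    print(item["strategy"], item["name"], item["suffix"], item["type"])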
@@ -189,6 +228,18 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
        return Result.failed(code="E000X", msg=f"document sync error {e}")


@router.post("/knowledge/{space_name}/document/sync_batch")
def batch_document_sync(space_name: str, request: List[KnowledgeSyncRequest]):
    logger.info(f"Received params: {space_name}, {request}")
    try:
        doc_ids = knowledge_space_service.batch_document_sync(
            space_name=space_name, sync_requests=request
        )
        return Result.succ({"tasks": doc_ids})
    except Exception as e:
        return Result.failed(code="E000X", msg=f"document sync error {e}")


@router.post("/knowledge/{space_name}/chunk/list")
def document_list(space_name: str, query_request: ChunkQueryRequest):
    print(f"/document/list params: {space_name}, {query_request}")
@@ -204,15 +255,23 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
    embedding_factory = CFG.SYSTEM_APP.get_component(
        "embedding_factory", EmbeddingFactory
    )
    client = EmbeddingEngine(
        model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
        vector_store_config={"vector_store_name": space_name},
        embedding_factory=embedding_factory,
    config = VectorStoreConfig(
        name=space_name,
        embedding_fn=embedding_factory.create(
            EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        ),
    )
    docs = client.similar_search(query_request.query, query_request.top_k)
    vector_store_connector = VectorStoreConnector(
        vector_store_type=CFG.VECTOR_STORE_TYPE,
        vector_store_config=config,
    )
    retriever = EmbeddingRetriever(
        top_k=query_request.top_k, vector_store_connector=vector_store_connector
    )
    chunks = retriever.retrieve(query_request.query)
    res = [
        KnowledgeQueryResponse(text=d.page_content, source=d.metadata["source"])
        for d in docs
        KnowledgeQueryResponse(text=d.content, source=d.metadata["source"])
        for d in chunks
    ]
    return {"response": res}
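Taken together, the similar_query hunk above swaps the old EmbeddingEngine client for the VectorStoreConfig / VectorStoreConnector / EmbeddingRetriever trio. A minimal standalone sketch of that retrieval path, limited to constructors and calls that appear in this diff; the embedding function and the "Chroma" store type are stand-ins for what the app derives from EmbeddingFactory and CFG.VECTOR_STORE_TYPE:

from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector


def search_space(space_name, query, embedding_fn, top_k=5):
    # embedding_fn would normally come from EmbeddingFactory.create(...).
    config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
    connector = VectorStoreConnector(
        vector_store_type="Chroma",  # assumption; the service reads CFG.VECTOR_STORE_TYPE
        vector_store_config=config,
    )
    retriever = EmbeddingRetriever(top_k=top_k, vector_store_connector=connector)
    # retrieve() returns Chunk objects; per this diff they expose .content and .metadata.
    return [(c.content, c.metadata.get("source")) for c in retriever.retrieve(query)]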
@@ -254,7 +313,7 @@ async def entity_extract(request: EntityExtractRequest):
    logger.info(f"Received params: {request}")
    try:
        from dbgpt.app.scene import ChatScene
        from dbgpt._private.chat_util import llm_chat_response_nostream
        from dbgpt.util.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import List

from sqlalchemy import Column, String, DateTime, Integer, Text, func
@@ -51,6 +52,12 @@ class KnowledgeDocumentDao(BaseDao):
        return doc_id

    def get_knowledge_documents(self, query, page=1, page_size=20):
        """Get a list of documents that match the given query.
        Args:
            query: A KnowledgeDocumentEntity object containing the query parameters.
            page: The page number to return.
            page_size: The number of documents to return per page.
        """
        session = self.get_raw_session()
        print(f"current session:{session}")
        knowledge_documents = session.query(KnowledgeDocumentEntity)
@@ -85,6 +92,23 @@ class KnowledgeDocumentDao(BaseDao):
        session.close()
        return result

    def documents_by_ids(self, ids) -> List[KnowledgeDocumentEntity]:
        """Get a list of documents by their IDs.
        Args:
            ids: A list of document IDs.
        Returns:
            A list of KnowledgeDocumentEntity objects.
        """
        session = self.get_raw_session()
        print(f"current session:{session}")
        knowledge_documents = session.query(KnowledgeDocumentEntity)
        knowledge_documents = knowledge_documents.filter(
            KnowledgeDocumentEntity.id.in_(ids)
        )
        result = knowledge_documents.all()
        session.close()
        return result

    def get_documents(self, query):
        session = self.get_raw_session()
        print(f"current session:{session}")
@@ -3,6 +3,8 @@ from typing import List, Optional
from dbgpt._private.pydantic import BaseModel
from fastapi import UploadFile

from dbgpt.rag.chunk_manager import ChunkParameters


class KnowledgeQueryRequest(BaseModel):
    """query: knowledge query"""
@@ -43,6 +45,8 @@ class DocumentQueryRequest(BaseModel):
    """doc_name: doc path"""

    doc_name: str = None
    """doc_ids: doc ids"""
    doc_ids: Optional[List] = None
    """doc_type: doc type"""
    doc_type: str = None
    """status: status"""
@@ -76,6 +80,20 @@ class DocumentSyncRequest(BaseModel):
    chunk_overlap: Optional[int] = None


class KnowledgeSyncRequest(BaseModel):
    """Sync request"""

    """doc_ids: doc ids"""
    doc_id: int

    """model_name: model name"""
    model_name: Optional[str] = None

    """chunk_parameters: chunk parameters
    """
    chunk_parameters: ChunkParameters


class ChunkQueryRequest(BaseModel):
    """id: id"""
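For reference, a sketch of the request body accepted by the new /knowledge/{space_name}/document/sync_batch endpoint, assembled from the KnowledgeSyncRequest fields above. The host, route prefix, and the concrete numbers are illustrative assumptions, not values taken from this diff:

import requests  # assumed HTTP client

url = "http://localhost:5000/knowledge/my_space/document/sync_batch"  # hypothetical
payload = [
    {
        "doc_id": 42,  # id of a previously added document
        "model_name": None,  # optional; the service falls back to CFG.LLM_MODEL
        "chunk_parameters": {
            # "Automatic" lets the service fill chunk_size/chunk_overlap from the
            # space context or CFG defaults, as batch_document_sync does later in this diff.
            "chunk_strategy": "Automatic",
        },
    }
]
resp = requests.post(url, json=payload).json()
# On success the endpoint returns Result.succ({"tasks": [doc_id, ...]}).
print(resp)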
@@ -1,13 +1,26 @@
import json
import logging
from datetime import datetime
from typing import List

from dbgpt.model import DefaultLLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.text_splitter.text_splitter import (
    RecursiveCharacterTextSplitter,
    SpacyTextSplitter,
)
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

from dbgpt._private.config import Config
from dbgpt.configs.model_config import (
    EMBEDDING_MODEL_CONFIG,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from dbgpt.component import ComponentType
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@@ -32,6 +45,7 @@ from dbgpt.app.knowledge.request.request import (
    SpaceArgumentRequest,
    DocumentSyncRequest,
    DocumentSummaryRequest,
    KnowledgeSyncRequest,
)
from enum import Enum
@@ -106,7 +120,10 @@ class KnowledgeService:
            content=request.content,
            result="",
        )
        return knowledge_document_dao.create_knowledge_document(document)
        doc_id = knowledge_document_dao.create_knowledge_document(document)
        if doc_id is None:
            raise Exception(f"create document failed, {request.doc_name}")
        return doc_id

    def get_knowledge_space(self, request: KnowledgeSpaceRequest):
        """get knowledge space
@@ -171,36 +188,75 @@ class KnowledgeService:
        Args:
           - space: Knowledge Space Name
           - request: DocumentQueryRequest
        Returns:
           - res DocumentQueryResponse
        """
        query = KnowledgeDocumentEntity(
            doc_name=request.doc_name,
            doc_type=request.doc_type,
            space=space,
            status=request.status,
        )
        res = DocumentQueryResponse()
        res.data = knowledge_document_dao.get_knowledge_documents(
            query, page=request.page, page_size=request.page_size
        )
        res.total = knowledge_document_dao.get_knowledge_documents_count(query)
        res.page = request.page
        if request.doc_ids and len(request.doc_ids) > 0:
            res.data = knowledge_document_dao.documents_by_ids(request.doc_ids)
        else:
            query = KnowledgeDocumentEntity(
                doc_name=request.doc_name,
                doc_type=request.doc_type,
                space=space,
                status=request.status,
            )
            res.data = knowledge_document_dao.get_knowledge_documents(
                query, page=request.page, page_size=request.page_size
            )
            res.total = knowledge_document_dao.get_knowledge_documents_count(query)
            res.page = request.page
        return res

    def batch_document_sync(
        self, space_name, sync_requests: List[KnowledgeSyncRequest]
    ) -> List[int]:
        """batch sync knowledge document chunk into vector store
        Args:
            - space: Knowledge Space Name
            - sync_requests: List[KnowledgeSyncRequest]
        Returns:
            - List[int]: document ids
        """
        doc_ids = []
        for sync_request in sync_requests:
            docs = knowledge_document_dao.documents_by_ids([sync_request.doc_id])
            if len(docs) == 0:
                raise Exception(
                    f"there are document called, doc_id: {sync_request.doc_id}"
                )
            doc = docs[0]
            if (
                doc.status == SyncStatus.RUNNING.name
                or doc.status == SyncStatus.FINISHED.name
            ):
                raise Exception(
                    f" doc:{doc.doc_name} status is {doc.status}, can not sync"
                )
            chunk_parameters = sync_request.chunk_parameters
            if "Automatic" == chunk_parameters.chunk_strategy:
                space_context = self.get_space_context(space_name)
                chunk_parameters.chunk_size = (
                    CFG.KNOWLEDGE_CHUNK_SIZE
                    if space_context is None
                    else int(space_context["embedding"]["chunk_size"])
                )
                chunk_parameters.chunk_overlap = (
                    CFG.KNOWLEDGE_CHUNK_OVERLAP
                    if space_context is None
                    else int(space_context["embedding"]["chunk_overlap"])
                )
            self._sync_knowledge_document(space_name, doc, chunk_parameters)
            doc_ids.append(doc.id)
        return doc_ids

    def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
        """sync knowledge document chunk into vector store
        Args:
            - space: Knowledge Space Name
            - sync_request: DocumentSyncRequest
        """
        from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
        from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
        from dbgpt.rag.embedding_engine.pre_text_splitter import PreTextSplitter
        from langchain.text_splitter import (
            RecursiveCharacterTextSplitter,
            SpacyTextSplitter,
        )

        # import langchain is very very slow!!!
        from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter

        doc_ids = sync_request.doc_ids
        self.model_name = sync_request.model_name or CFG.LLM_MODEL
@@ -234,6 +290,11 @@ class KnowledgeService:
        if sync_request.chunk_overlap:
            chunk_overlap = sync_request.chunk_overlap
        separators = sync_request.separators or None
        from dbgpt.rag.chunk_manager import ChunkParameters

        chunk_parameters = ChunkParameters(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if CFG.LANGUAGE == "en":
            text_splitter = RecursiveCharacterTextSplitter(
                separators=separators,
@@ -244,7 +305,7 @@
        else:
            if separators and len(separators) > 1:
                raise ValueError(
                    "SpacyTextSplitter do not support multiple separators"
                    "SpacyTextSplitter do not support multipsle separators"
                )
            try:
                separator = "\n\n" if not separators else separators[0]
@@ -266,48 +327,51 @@ class KnowledgeService:
                pre_separator=sync_request.pre_separator,
                text_splitter_impl=text_splitter,
            )
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
        client = EmbeddingEngine(
            knowledge_source=doc.content,
            knowledge_type=doc.doc_type.upper(),
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config={
                "vector_store_name": space_name,
                "vector_store_type": CFG.VECTOR_STORE_TYPE,
            },
            text_splitter=text_splitter,
            embedding_factory=embedding_factory,
        )
        chunk_docs = client.read()
        # update document status
        doc.status = SyncStatus.RUNNING.name
        doc.chunk_size = len(chunk_docs)
        doc.gmt_modified = datetime.now()
        knowledge_document_dao.update_knowledge_document(doc)
        executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
        executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
        logger.info(f"begin save document chunks, doc:{doc.doc_name}")
        # save chunk details
        chunk_entities = [
            DocumentChunkEntity(
                doc_name=doc.doc_name,
                doc_type=doc.doc_type,
                document_id=doc.id,
                content=chunk_doc.page_content,
                meta_info=str(chunk_doc.metadata),
                gmt_created=datetime.now(),
                gmt_modified=datetime.now(),
            )
            for chunk_doc in chunk_docs
        ]
        document_chunk_dao.create_documents_chunks(chunk_entities)

        chunk_parameters.text_splitter = text_splitter
        self._sync_knowledge_document(space_name, doc, chunk_parameters)
        return doc.id

    def _sync_knowledge_document(
        self,
        space_name,
        doc: KnowledgeDocumentEntity,
        chunk_parameters: ChunkParameters,
    ) -> List[Chunk]:
        """sync knowledge document chunk into vector store"""
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
        embedding_fn = embedding_factory.create(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        from dbgpt.storage.vector_store.base import VectorStoreConfig

        config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
        vector_store_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=config,
        )
        knowledge = KnowledgeFactory.create(
            datasource=doc.content,
            knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
        )
        assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            vector_store_connector=vector_store_connector,
        )
        chunk_docs = assembler.get_chunks()
        doc.status = SyncStatus.RUNNING.name
        doc.chunk_size = len(chunk_docs)
        doc.gmt_modified = datetime.now()
        knowledge_document_dao.update_knowledge_document(doc)
        executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
        executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
        logger.info(f"begin save document chunks, doc:{doc.doc_name}")
        return chunk_docs

    async def document_summary(self, request: DocumentSummaryRequest):
        """get document summary
        Args:
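The _sync_knowledge_document method above is the core of the refactor: KnowledgeFactory, ChunkParameters and EmbeddingAssembler replace the old EmbeddingEngine read/embed path. A minimal end-to-end ingestion sketch using only APIs shown in this diff; the embedding function, store type and chunk sizes are illustrative assumptions:

from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector


def ingest(space_name, datasource, doc_type, embedding_fn, vector_store_type="Chroma"):
    # embedding_fn would normally be EmbeddingFactory.create(...); vector_store_type
    # mirrors CFG.VECTOR_STORE_TYPE. Both are passed in here only for illustration.
    config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
    connector = VectorStoreConnector(
        vector_store_type=vector_store_type, vector_store_config=config
    )
    knowledge = KnowledgeFactory.create(
        datasource=datasource, knowledge_type=KnowledgeType.get_by_value(doc_type)
    )
    chunk_parameters = ChunkParameters(chunk_size=500, chunk_overlap=50)  # example values
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=chunk_parameters,
        vector_store_connector=connector,
    )
    chunks = assembler.get_chunks()   # chunking happens up front
    vector_ids = assembler.persist()  # embeddings are written to the vector store
    return len(chunks), vector_ids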
@@ -318,20 +382,46 @@ class KnowledgeService:
        if len(documents) != 1:
            raise Exception(f"can not found document for {request.doc_id}")
        document = documents[0]
        query = DocumentChunkEntity(
            document_id=request.doc_id,
        )
        chunks = document_chunk_dao.get_document_chunks(query, page=1, page_size=100)
        if len(chunks) == 0:
            raise Exception(f"can not found chunks for {request.doc_id}")
        from langchain.schema import Document
        from dbgpt.model.cluster import WorkerManagerFactory

        chunk_docs = [Document(page_content=chunk.content) for chunk in chunks]
        return await self.async_document_summary(
        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        chunk_parameters = ChunkParameters(
            chunk_strategy="CHUNK_BY_SIZE",
            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
            chunk_overlap=CFG.KNOWLEDGE_CHUNK_OVERLAP,
        )
        chunk_entities = document_chunk_dao.get_document_chunks(
            DocumentChunkEntity(document_id=document.id)
        )
        if (
            document.status not in [SyncStatus.RUNNING.name]
            and len(chunk_entities) == 0
        ):
            self._sync_knowledge_document(
                space_name=document.space,
                doc=document,
                chunk_parameters=chunk_parameters,
            )
        knowledge = KnowledgeFactory.create(
            datasource=document.content,
            knowledge_type=KnowledgeType.get_by_value(document.doc_type),
        )
        assembler = SummaryAssembler(
            knowledge=knowledge,
            model_name=request.model_name,
            chunk_docs=chunk_docs,
            doc=document,
            conn_uid=request.conv_uid,
            llm_client=DefaultLLMClient(worker_manager=worker_manager),
            language=CFG.LANGUAGE,
            chunk_parameters=chunk_parameters,
        )
        summary = await assembler.generate_summary()

        if len(assembler.get_chunks()) == 0:
            raise Exception(f"can not found chunks for {request.doc_id}")

        return await self._llm_extract_summary(
            summary, request.conv_uid, request.model_name
        )

    def update_knowledge_space(
@@ -354,15 +444,13 @@ class KnowledgeService:
        if len(spaces) == 0:
            raise Exception(f"delete error, no space name:{space_name} in database")
        space = spaces[0]
        vector_config = {}
        vector_config["vector_store_name"] = space.name
        vector_config["vector_store_type"] = CFG.VECTOR_STORE_TYPE
        vector_config["chroma_persist_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
        vector_client = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE, ctx=vector_config
        config = VectorStoreConfig(name=space.name)
        vector_store_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=config,
        )
        # delete vectors
        vector_client.delete_vector_name(space.name)
        vector_store_connector.delete_vector_name(space.name)
        document_query = KnowledgeDocumentEntity(space=space.name)
        # delete chunks
        documents = knowledge_document_dao.get_documents(document_query)
@@ -385,15 +473,13 @@ class KnowledgeService:
            raise Exception(f"there are no or more than one document called {doc_name}")
        vector_ids = documents[0].vector_ids
        if vector_ids is not None:
            vector_config = {}
            vector_config["vector_store_name"] = space_name
            vector_config["vector_store_type"] = CFG.VECTOR_STORE_TYPE
            vector_config["chroma_persist_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
            vector_client = VectorStoreConnector(
                vector_store_type=CFG.VECTOR_STORE_TYPE, ctx=vector_config
            config = VectorStoreConfig(name=space_name)
            vector_store_connector = VectorStoreConnector(
                vector_store_type=CFG.VECTOR_STORE_TYPE,
                vector_store_config=config,
            )
            # delete vector by ids
            vector_client.delete_by_ids(vector_ids)
            vector_store_connector.delete_by_ids(vector_ids)
        # delete chunks
        document_chunk_dao.raw_delete(documents[0].id)
        # delete document
@@ -432,7 +518,7 @@ class KnowledgeService:
            f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
        )
        try:
            from dbgpt.rag.graph_engine.graph_factory import RAGGraphFactory
            from dbgpt.rag.graph.graph_factory import RAGGraphFactory

            rag_engine = CFG.SYSTEM_APP.get_component(
                ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
@@ -446,54 +532,38 @@ class KnowledgeService:
            logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
        return knowledge_document_dao.update_knowledge_document(doc)

    async def async_document_summary(self, model_name, chunk_docs, doc, conn_uid):
        """async document extract summary
        Args:
            - model_name: str
            - chunk_docs: List[Document]
            - doc: KnowledgeDocumentEntity
        """
        texts = [doc.page_content for doc in chunk_docs]
        from dbgpt.util.prompt_util import PromptHelper

        prompt_helper = PromptHelper()
        from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt

        texts = prompt_helper.repack(prompt_template=prompt.template, text_chunks=texts)
        logger.info(
            f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
        )
        space_context = self.get_space_context(doc.space)
        if space_context and space_context.get("summary"):
            summary = await self._mapreduce_extract_summary(
                docs=texts,
                model_name=model_name,
                max_iteration=int(space_context["summary"]["max_iteration"]),
                concurrency_limit=int(space_context["summary"]["concurrency_limit"]),
            )
        else:
            summary = await self._mapreduce_extract_summary(
                docs=texts, model_name=model_name
            )
        return await self._llm_extract_summary(summary, conn_uid, model_name)

    def async_doc_embedding(self, client, chunk_docs, doc):
    def async_doc_embedding(self, assembler, chunk_docs, doc):
        """async document embedding into vector db
        Args:
            - client: EmbeddingEngine Client
            - chunk_docs: List[Document]
            - doc: KnowledgeDocumentEntity
        """

        logger.info(
            f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
            f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
        )
        try:
            vector_ids = client.knowledge_embedding_batch(chunk_docs)
            vector_ids = assembler.persist()
            doc.status = SyncStatus.FINISHED.name
            doc.result = "document embedding success"
            if vector_ids is not None:
                doc.vector_ids = ",".join(vector_ids)
            logger.info(f"async document embedding, success:{doc.doc_name}")
            # save chunk details
            chunk_entities = [
                DocumentChunkEntity(
                    doc_name=doc.doc_name,
                    doc_type=doc.doc_type,
                    document_id=doc.id,
                    content=chunk_doc.content,
                    meta_info=str(chunk_doc.metadata),
                    gmt_created=datetime.now(),
                    gmt_modified=datetime.now(),
                )
                for chunk_doc in chunk_docs
            ]
            document_chunk_dao.create_documents_chunks(chunk_entities)
        except Exception as e:
            doc.status = SyncStatus.FAILED.name
            doc.result = "document embedding failed" + str(e)
@@ -577,65 +647,3 @@ class KnowledgeService:
            **{"chat_param": chat_param},
        )
        return chat

    async def _mapreduce_extract_summary(
        self,
        docs,
        model_name: str = None,
        max_iteration: int = 5,
        concurrency_limit: int = 3,
    ):
        """Extract summary by mapreduce mode
        map -> multi async call llm to generate summary
        reduce -> merge the summaries by map process
        Args:
            docs:List[str]
            model_name:model name str
            max_iteration:max iteration will call llm to summary
            concurrency_limit:the max concurrency threads to call llm
        Returns:
            Document: refine summary context document.
        """
        from dbgpt.app.scene import ChatScene
        from dbgpt._private.chat_util import llm_chat_response_nostream
        import uuid

        tasks = []
        if len(docs) == 1:
            return docs[0]
        else:
            max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
            for doc in docs[0:max_iteration]:
                chat_param = {
                    "chat_session_id": uuid.uuid1(),
                    "current_user_input": "",
                    "select_param": doc,
                    "model_name": model_name,
                    "model_cache_enable": True,
                }
                tasks.append(
                    llm_chat_response_nostream(
                        ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
                    )
                )
            from dbgpt._private.chat_util import run_async_tasks

            summary_iters = await run_async_tasks(
                tasks=tasks, concurrency_limit=concurrency_limit
            )
            summary_iters = list(
                filter(
                    lambda content: "LLMServer Generate Error" not in content,
                    summary_iters,
                )
            )
            from dbgpt.util.prompt_util import PromptHelper
            from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt

            prompt_helper = PromptHelper()
            summary_iters = prompt_helper.repack(
                prompt_template=prompt.template, text_chunks=summary_iters
            )
            return await self._mapreduce_extract_summary(
                summary_iters, model_name, max_iteration, concurrency_limit
            )
@@ -11,6 +11,7 @@ from dbgpt.component import ComponentType
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core.interface.message import OnceConversation
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
@@ -58,6 +59,9 @@ class BaseChat(ABC):
            chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
        )
        self.llm_echo = False
        self.worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        self.model_cache_enable = chat_param.get("model_cache_enable", False)

        ### load prompt template
@@ -162,6 +166,10 @@ class BaseChat(ABC):
            "BaseChat.__call_base.prompt_template.format", metadata=metadata
        ):
            current_prompt = self.prompt_template.format(**input_values)
            ### prompt context token adapt according to llm max context length
            current_prompt = await self.prompt_context_token_adapt(
                prompt=current_prompt
            )
            self.current_message.add_system_message(current_prompt)

        llm_messages = self.generate_llm_messages()
@@ -169,6 +177,7 @@ class BaseChat(ABC):
            # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
            # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
            llm_messages = list(map(lambda m: m.dict(), llm_messages))

        payload = {
            "model": self.llm_model,
            "prompt": self.generate_llm_text(),
@@ -431,6 +440,39 @@ class BaseChat(ABC):
                return message.content
        return None

    async def prompt_context_token_adapt(self, prompt) -> str:
        """prompt token adapt according to llm max context length"""
        model_metadata = await self.worker_manager.get_model_metadata(
            {"model": self.llm_model}
        )
        current_token_count = await self.worker_manager.count_token(
            {"model": self.llm_model, "prompt": prompt}
        )
        if current_token_count == -1:
            logger.warning(
                "tiktoken not installed, please `pip install tiktoken` first"
            )
        template_define_token_count = 0
        if len(self.prompt_template.template_define) > 0:
            template_define_token_count = await self.worker_manager.count_token(
                {
                    "model": self.llm_model,
                    "prompt": self.prompt_template.template_define,
                }
            )
            current_token_count += template_define_token_count
        if (
            current_token_count + self.prompt_template.max_new_tokens
        ) > model_metadata.context_length:
            prompt = prompt[
                : (
                    model_metadata.context_length
                    - self.prompt_template.max_new_tokens
                    - template_define_token_count
                )
            ]
        return prompt

    def generate(self, p) -> str:
        """
        generate context for LLM input
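The truncation rule above keeps the prompt plus the reserved generation budget inside the model's context window. A small self-contained sketch of the same arithmetic; the numbers are illustrative, and note that, exactly as in the method above, the final slice is applied to characters even though the counts are tokens:

def adapt_prompt(prompt, current_tokens, template_define_tokens, max_new_tokens, context_length):
    # Mirror of the budget check in prompt_context_token_adapt (sketch only).
    total = current_tokens + template_define_tokens
    if total + max_new_tokens > context_length:
        # e.g. context_length=4096, max_new_tokens=512, template_define_tokens=64
        # -> keep at most 4096 - 512 - 64 = 3520 units of the prompt
        keep = context_length - max_new_tokens - template_define_tokens
        prompt = prompt[:keep]
    return prompt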
@@ -63,14 +63,11 @@ class ChatDashboard(BaseChat):
        try:
            table_infos = await blocking_func_to_async(
                self._executor,
                client.get_similar_tables,
                client.get_db_summary,
                self.db_name,
                self.current_user_input,
                self.top_k,
            )
            # table_infos = client.get_similar_tables(
            #     dbname=self.db_name, query=self.current_user_input, topk=self.top_k
            # )
            print("dashboard vector find tables:{}", table_infos)
        except Exception as e:
            print("db summary find error!" + str(e))
@@ -19,22 +19,14 @@ class ChatFactory(metaclass=Singleton):
        from dbgpt.app.scene.chat_dashboard.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.v1.chat import ChatKnowledge
        from dbgpt.app.scene.chat_knowledge.v1.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.inner_db_summary.chat import (
            InnerChatDBSummary,
        )
        from dbgpt.app.scene.chat_knowledge.inner_db_summary.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
        from dbgpt.app.scene.chat_knowledge.extract_triplet.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.extract_entity.chat import ExtractEntity
        from dbgpt.app.scene.chat_knowledge.extract_entity.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.summary.chat import ExtractSummary
        from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.refine_summary.chat import (
            ExtractRefineSummary,
        )
        from dbgpt.app.scene.chat_knowledge.refine_summary.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.rewrite.chat import QueryRewrite
        from dbgpt.app.scene.chat_knowledge.rewrite.prompt import prompt
        from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
        from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
        from dbgpt.app.scene.chat_data.chat_excel.excel_learning.prompt import prompt
@@ -1,40 +0,0 @@
from typing import Dict
from dbgpt.app.scene import BaseChat, ChatScene

from dbgpt.util.tracer import trace


class InnerChatDBSummary(BaseChat):
    chat_scene: str = ChatScene.InnerChatDBSummary.value()

    """Number of results to return from the query"""

    def __init__(
        self,
        chat_session_id,
        user_input,
        db_select,
        db_summary,
    ):
        """ """
        super().__init__(
            chat_mode=ChatScene.InnerChatDBSummary,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
            select_param=db_select,
        )

        self.db_input = db_select
        self.db_summary = db_summary

    @trace()
    async def generate_input_values(self) -> Dict:
        input_values = {
            "db_input": self.db_input,
            "db_profile_summary": self.db_summary,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.InnerChatDBSummary.value
@@ -1,17 +0,0 @@
import logging
from dbgpt.core.interface.output_parser import BaseOutputParser

logger = logging.getLogger(__name__)


class NormalChatOutputParser(BaseOutputParser):
    def parse_prompt_response(self, model_out_text):
        clean_str = super().parse_prompt_response(model_out_text)
        print("clean prompt response:", clean_str)
        return clean_str

    def parse_view_response(self, ai_text, data) -> str:
        return ai_text

    def get_format_instructions(self) -> str:
        pass
@@ -1,45 +0,0 @@
import json

from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene

from dbgpt.app.scene.chat_knowledge.inner_db_summary.out_parser import (
    NormalChatOutputParser,
)


CFG = Config()

PROMPT_SCENE_DEFINE = """"""

_DEFAULT_TEMPLATE = """
Based on the following known database information?, answer which tables are involved in the user input.
Known database information:{db_profile_summary}
Input:{db_input}
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads


"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}
The response format must be JSON, and the key of JSON must be "table".
"""


RESPONSE_FORMAT = {"table": ["orders", "products"]}


PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.InnerChatDBSummary.value(),
    input_variables=["db_profile_summary", "db_input", "response"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
@@ -19,7 +19,6 @@ _DEFAULT_TEMPLATE_ZH = (
_DEFAULT_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1.2.and 3.

"""

_DEFAULT_TEMPLATE = (
@@ -1,32 +0,0 @@
from typing import Dict

from dbgpt.app.scene import BaseChat, ChatScene

from dbgpt.app.scene.chat_knowledge.rewrite.prompt import prompt


class QueryRewrite(BaseChat):
    chat_scene: str = ChatScene.QueryRewrite.value()

    """query rewrite by llm"""

    def __init__(self, chat_param: Dict):
        """ """
        chat_param["chat_mode"] = ChatScene.QueryRewrite
        super().__init__(
            chat_param=chat_param,
        )

        self.nums = chat_param["select_param"]
        self.current_user_input = chat_param["current_user_input"]

    async def generate_input_values(self):
        input_values = {
            "nums": self.nums,
            "original_query": self.current_user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.QueryRewrite.value
@@ -1,42 +0,0 @@
import logging

from dbgpt.core.interface.output_parser import BaseOutputParser

logger = logging.getLogger(__name__)


class QueryRewriteParser(BaseOutputParser):
    def __init__(self, is_stream_out: bool, **kwargs):
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_prompt_response(self, response, max_length: int = 128):
        lowercase = True
        try:
            results = []
            response = response.strip()

            if response.startswith("queries:"):
                response = response[len("queries:") :]

            queries = response.split(",")
            if len(queries) == 1:
                queries = response.split(",")
                if len(queries) == 1:
                    queries = response.split("?")
                    if len(queries) == 1:
                        queries = response.split("?")
            for k in queries:
                rk = k
                if lowercase:
                    rk = rk.lower()
                s = rk.strip()
                if s == "":
                    continue
                results.append(s)
        except Exception as e:
            logger.error(f"parse query rewrite prompt_response error: {e}")
            return []
        return results

    def parse_view_response(self, speak, data) -> str:
        return data
@@ -1,41 +0,0 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from .out_parser import QueryRewriteParser

CFG = Config()


PROMPT_SCENE_DEFINE = """You are a helpful assistant that generates multiple search queries based on a single input query."""

_DEFAULT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>':
"original_query:{original_query}\n"
"queries:\n"
"""

_DEFAULT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'\n":
"original query:: {original_query}\n"
"queries:\n"
"""

_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

PROMPT_RESPONSE = """"""


PROMPT_NEED_NEED_STREAM_OUT = True

prompt = PromptTemplate(
    template_scene=ChatScene.QueryRewrite.value(),
    input_variables=["nums", "original_query"],
    response_format=None,
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=QueryRewriteParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
@@ -1,28 +0,0 @@
from typing import Dict

from dbgpt.app.scene import BaseChat, ChatScene


class ExtractSummary(BaseChat):
    chat_scene: str = ChatScene.ExtractSummary.value()

    """get summary by llm"""

    def __init__(self, chat_param: Dict):
        """ """
        chat_param["chat_mode"] = ChatScene.ExtractSummary
        super().__init__(
            chat_param=chat_param,
        )

        self.user_input = chat_param["select_param"]

    async def generate_input_values(self):
        input_values = {
            "context": self.user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.ExtractSummary.value
@@ -1,28 +0,0 @@
import logging
from typing import List, Tuple

from dbgpt.core.interface.output_parser import BaseOutputParser, ResponseTye

logger = logging.getLogger(__name__)


class ExtractSummaryParser(BaseOutputParser):
    def __init__(self, is_stream_out: bool, **kwargs):
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_prompt_response(
        self, response, max_length: int = 128
    ) -> List[Tuple[str, str, str]]:
        # clean_str = super().parse_prompt_response(response)
        print("clean prompt response:", response)
        return response

    def parse_view_response(self, speak, data) -> str:
        ### tool out data to table view
        return data

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str) -> str:
        try:
            return super().parse_model_nostream_resp(response, sep)
        except Exception as e:
            return str(e)
@@ -1,46 +0,0 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene

from dbgpt.app.scene.chat_knowledge.summary.out_parser import ExtractSummaryParser

CFG = Config()

# PROMPT_SCENE_DEFINE = """You are an expert Q&A system that is trusted around the world.\nAlways answer the query using the provided context information, and not prior knowledge.\nSome rules to follow:\n1. Never directly reference the given context in your answer.\n2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."""

PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions."""

_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
{context}
答案尽量精确和简单,不要过长,长度控制在100字左右
"""

_DEFAULT_TEMPLATE_EN = """
Write a quick summary of the following context:
{context}
the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
"""

_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

PROMPT_RESPONSE = """"""


RESPONSE_FORMAT = """"""

PROMPT_NEED_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ExtractSummary.value(),
    input_variables=["context"],
    response_format=None,
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=ExtractSummaryParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
@@ -5,6 +5,7 @@ from typing import Dict, List

from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
from dbgpt.component import ComponentType

from dbgpt.configs.model_config import (
    EMBEDDING_MODEL_CONFIG,
@@ -16,7 +17,9 @@ from dbgpt.app.knowledge.document_db import (
    KnowledgeDocumentEntity,
)
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.tracer import trace

CFG = Config()
@@ -35,8 +38,7 @@ class ChatKnowledge(BaseChat):
        - model_name:(str) llm model name
        - select_param:(str) space name
        """
        from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
        from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

        self.knowledge_space = chat_param["select_param"]
        chat_param["chat_mode"] = ChatScene.ChatKnowledge
@@ -59,17 +61,37 @@ class ChatKnowledge(BaseChat):
            if self.space_context is None or self.space_context.get("prompt") is None
            else int(self.space_context["prompt"]["max_token"])
        )
        vector_store_config = {
            "vector_store_name": self.knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
        self.knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        from dbgpt.rag.retriever.embedding import EmbeddingRetriever
        from dbgpt.storage.vector_store.connector import VectorStoreConnector

        embedding_fn = embedding_factory.create(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        from dbgpt.storage.vector_store.base import VectorStoreConfig

        config = VectorStoreConfig(name=self.knowledge_space, embedding_fn=embedding_fn)
        vector_store_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=config,
        )
        query_rewrite = None
        self.worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager)
        if CFG.KNOWLEDGE_SEARCH_REWRITE:
            query_rewrite = QueryRewrite(
                llm_client=self.llm_client,
                model_name=self.llm_model,
                language=CFG.LANGUAGE,
            )
        self.embedding_retriever = EmbeddingRetriever(
            top_k=self.top_k,
            vector_store_connector=vector_store_connector,
            query_rewrite=query_rewrite,
        )
        self.prompt_template.template_is_strict = False
        self.relations = None
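The constructor hunk above wires an optional LLM-backed QueryRewrite into the EmbeddingRetriever. A minimal sketch of the same wiring outside of BaseChat; llm_client, the model name, the store type and the score threshold are placeholders, not values from this diff:

from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector


async def retrieve_with_rewrite(space, question, llm_client, embedding_fn, top_k=4):
    # Pass query_rewrite=None to disable rewriting, as the chat class does when
    # CFG.KNOWLEDGE_SEARCH_REWRITE is off.
    rewrite = QueryRewrite(llm_client=llm_client, model_name="my-llm", language="en")
    connector = VectorStoreConnector(
        vector_store_type="Chroma",  # assumption; mirrors CFG.VECTOR_STORE_TYPE
        vector_store_config=VectorStoreConfig(name=space, embedding_fn=embedding_fn),
    )
    retriever = EmbeddingRetriever(
        top_k=top_k, vector_store_connector=connector, query_rewrite=rewrite
    )
    # aretrieve_with_scores(query, score_threshold) is the call used by
    # execute_similar_search later in this diff.
    return await retriever.aretrieve_with_scores(question, 0.3)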
@@ -110,50 +132,31 @@ class ChatKnowledge(BaseChat):
        if self.space_context and self.space_context.get("prompt"):
            self.prompt_template.template_define = self.space_context["prompt"]["scene"]
            self.prompt_template.template = self.space_context["prompt"]["template"]
        from dbgpt.rag.retriever.reinforce import QueryReinforce
        from dbgpt.util.chat_util import run_async_tasks

        # query reinforce, get similar queries
        query_reinforce = QueryReinforce(
            query=self.current_user_input, model_name=self.llm_model
        )
        queries = []
        if CFG.KNOWLEDGE_SEARCH_REWRITE:
            queries = await query_reinforce.rewrite()
            print("rewrite queries:", queries)
        queries.append(self.current_user_input)
        from dbgpt._private.chat_util import run_async_tasks

        # similarity search from vector db
        tasks = [self.execute_similar_search(query) for query in queries]
        docs_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        candidates_with_scores = reduce(lambda x, y: x + y, docs_with_scores)
        # candidates document rerank
        from dbgpt.rag.retriever.rerank import DefaultRanker

        ranker = DefaultRanker(self.top_k)
        candidates_with_scores = ranker.rank(candidates_with_scores)
        tasks = [self.execute_similar_search(self.current_user_input)]
        candidates_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        candidates_with_scores = reduce(lambda x, y: x + y, candidates_with_scores)
        self.chunks_with_score = []
        if not candidates_with_scores or len(candidates_with_scores) == 0:
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            self.chunks_with_score = []
            for d, score in candidates_with_scores:
            for chunk in candidates_with_scores:
                chucks = self.chunk_dao.get_document_chunks(
                    query=DocumentChunkEntity(content=d.page_content),
                    query=DocumentChunkEntity(content=chunk.content),
                    document_ids=self.document_ids,
                )
                if len(chucks) > 0:
                    self.chunks_with_score.append((chucks[0], score))
                    self.chunks_with_score.append((chucks[0], chunk.score))

            context = [doc.page_content for doc, _ in candidates_with_scores]

            context = context[: self.max_token]
            context = "\n".join([doc.content for doc in candidates_with_scores])
        self.relations = list(
            set(
                [
                    os.path.basename(str(d.metadata.get("source", "")))
                    for d, _ in candidates_with_scores
                    for d in candidates_with_scores
                ]
            )
        )
@@ -201,7 +204,8 @@ class ChatKnowledge(BaseChat):
        references_list = list(references_dict.values())
        references_ele.set("references", json.dumps(references_list))
        html = ET.tostring(references_ele, encoding="utf-8")
        return html.decode("utf-8")
        reference = html.decode("utf-8")
        return reference.replace("\\n", "")

    @property
    def chat_type(self) -> str:
@@ -213,10 +217,6 @@ class ChatKnowledge(BaseChat):

    async def execute_similar_search(self, query):
        """execute similarity search"""
        return await blocking_func_to_async(
            self._executor,
            self.knowledge_embedding_client.similar_search_with_scores,
            query,
            self.top_k,
            self.recall_score,
        return await self.embedding_retriever.aretrieve_with_scores(
            query, self.recall_score
        )
@@ -2,16 +2,20 @@ from typing import Dict, Optional, List
from dataclasses import dataclass
import datetime
import os

from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.core.interface.message import OnceConversation
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

from dbgpt.rag.retriever.embedding import EmbeddingRetriever

from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# TODO move global config
CFG = Config()
@@ -184,23 +188,14 @@ class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):

    async def map(self, input_value: ChatContext) -> ChatContext:
        from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
        from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
        from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

        # TODO, decompose the current operator into some atomic operators
        knowledge_space = input_value.select_param
        vector_store_config = {
            "vector_store_name": knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = self.system_app.get_component(
            "embedding_factory", EmbeddingFactory
        )
        knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )

        space_context = await self._get_space_context(knowledge_space)
        top_k = (
            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
@@ -219,16 +214,28 @@ class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
            ]
            input_value.prompt_template.template = space_context["prompt"]["template"]

        config = VectorStoreConfig(
            name=knowledge_space,
            embedding_fn=embedding_factory.create(
                EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
            ),
        )
        vector_store_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=config,
        )
        embedding_retriever = EmbeddingRetriever(
            top_k=top_k, vector_store_connector=vector_store_connector
        )
        docs = await self.blocking_func_to_async(
            knowledge_embedding_client.similar_search,
            embedding_retriever.retrieve,
            input_value.current_user_input,
            top_k,
        )
        if not docs or len(docs) == 0:
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            context = [d.page_content for d in docs]
            context = [d.content for d in docs]
        context = context[:max_token]
        relations = list(
            set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
Diffs for the built frontend assets under dbgpt/app/static/_next/static/ are suppressed because one or more lines are too long. New files added in this commit:

16  dbgpt/app/static/_next/static/chunks/113-15fc0b8bd2b5b9a1.js
11  dbgpt/app/static/_next/static/chunks/450-bd680f0e37e9b4b9.js
78  dbgpt/app/static/_next/static/chunks/607-2dedaf19149304c0.js
16  dbgpt/app/static/_next/static/chunks/810-84757da754c6f3fc.js
3   dbgpt/app/static/_next/static/css/9444ab27198bc61e.css
@ -1 +0,0 @@
|
||||
self.__BUILD_MANIFEST=function(s,c,a,e,t,d,n,f,k,h,i,b){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":["static/chunks/29107295-90b90cb30c825230.js",s,c,a,d,n,f,k,"static/chunks/412-b911d4a677c64b70.js","static/chunks/981-ff77d5cc3ab95298.js","static/chunks/pages/index-d1740e3bc6dba7f5.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/agent":[s,c,e,d,t,n,"static/chunks/pages/agent-92e9dce47267e88d.js"],"/chat":["static/chunks/pages/chat-84fbba4764166684.js"],"/chat/[scene]/[id]":["static/chunks/pages/chat/[scene]/[id]-f665336966e79cc9.js"],"/database":[s,c,a,e,t,f,h,"static/chunks/643-d8f53f40dd3c5b40.js","static/chunks/pages/database-3140f507fe61ccb8.js"],"/knowledge":[i,s,c,e,d,t,n,f,"static/chunks/551-266086fbfa0925ec.js","static/chunks/pages/knowledge-8ada4ce8fa909bf5.js"],"/knowledge/chunk":[e,t,"static/chunks/pages/knowledge/chunk-9f117a5ed799edd3.js"],"/models":[i,s,c,a,b,h,"static/chunks/pages/models-80218c46bc1d8cfa.js"],"/prompt":[s,c,a,b,"static/chunks/837-e6d4d1eb9e057050.js",k,"static/chunks/607-b224c640f6907e4b.js","static/chunks/pages/prompt-7f839dfd56bc4c20.js"],sortedPages:["/","/_app","/_error","/agent","/chat","/chat/[scene]/[id]","/database","/knowledge","/knowledge/chunk","/models","/prompt"]}}("static/chunks/64-91b49d45b9846775.js","static/chunks/479-b20198841f9a6a1e.js","static/chunks/9-bb2c54d5c06ba4bf.js","static/chunks/442-197e6cbc1e54109a.js","static/chunks/813-cce9482e33f2430c.js","static/chunks/553-df5701294eedae07.js","static/chunks/924-ba8e16df4d61ff5c.js","static/chunks/411-d9eba2657c72f766.js","static/chunks/270-2f094a936d056513.js","static/chunks/928-74244889bd7f2699.js","static/chunks/75fc9c18-a784766a129ec5fb.js","static/chunks/947-5980a3ff49069ddd.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();
|
@ -0,0 +1 @@
|
||||
self.__BUILD_MANIFEST=function(s,c,a,e,t,n,d,b,f,k,h,i,u,j){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":["static/chunks/29107295-90b90cb30c825230.js",s,c,e,a,d,b,f,h,"static/chunks/861-78929b4f98dbbfd6.js","static/chunks/161-96143606b49cf4a1.js","static/chunks/pages/index-a5e7e7433070d21b.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/agent":[s,c,a,t,d,n,"static/chunks/pages/agent-a2599efbeb46e056.js"],"/chat":["static/chunks/pages/chat-47a20abbae16e858.js"],"/chat/[scene]/[id]":["static/chunks/pages/chat/[scene]/[id]-8df445f91cde33fa.js"],"/database":[s,c,e,a,t,n,f,k,"static/chunks/643-d8f53f40dd3c5b40.js","static/chunks/pages/database-d36f41810fc357a6.js"],"/knowledge":[i,s,c,a,t,d,n,f,u,k,h,"static/chunks/450-bd680f0e37e9b4b9.js","static/chunks/pages/knowledge-b9300e7addf1931f.js"],"/knowledge/chunk":[s,e,t,b,n,"static/chunks/pages/knowledge/chunk-652744b9d90c26c9.js"],"/models":[i,s,c,e,a,j,k,"static/chunks/pages/models-1145859ba0e2f20a.js"],"/prompt":[s,c,e,a,j,b,u,"static/chunks/346-b0aea1c99abd6f1e.js","static/chunks/607-2dedaf19149304c0.js","static/chunks/pages/prompt-fca5ed813d5018b1.js"],sortedPages:["/","/_app","/_error","/agent","/chat","/chat/[scene]/[id]","/database","/knowledge","/knowledge/chunk","/models","/prompt"]}}("static/chunks/113-15fc0b8bd2b5b9a1.js","static/chunks/17-d6c52cecd9ecc451.js","static/chunks/479-33b3ebe9be79a971.js","static/chunks/9-bb2c54d5c06ba4bf.js","static/chunks/442-197e6cbc1e54109a.js","static/chunks/813-cce9482e33f2430c.js","static/chunks/553-a89ad624ca0f1ffa.js","static/chunks/810-84757da754c6f3fc.js","static/chunks/411-b5d3e7f64bee2335.js","static/chunks/928-74244889bd7f2699.js","static/chunks/234-42f62dc360b2d9e4.js","static/chunks/75fc9c18-a784766a129ec5fb.js","static/chunks/45-9ff739c09925ea35.js","static/chunks/947-5980a3ff49069ddd.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
116
dbgpt/rag/chunk.py
Normal file
@ -0,0 +1,116 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""Document including document content, document metadata."""
|
||||
|
||||
content: str = Field(default="", description="document text content")

metadata: Dict[str, Any] = Field(
    default_factory=dict,
    description="metadata fields",
)
|
||||
|
||||
def set_content(self, content: str) -> None:
|
||||
"""Set the content"""
|
||||
self.content = content
|
||||
|
||||
def get_content(self) -> str:
|
||||
return self.content
|
||||
|
||||
@classmethod
|
||||
def langchain2doc(cls, document):
|
||||
"""Transformation from Langchain to Chunk Document format."""
|
||||
metadata = document.metadata or {}
|
||||
return cls(content=document.page_content, metadata=metadata)
|
||||
|
||||
@classmethod
|
||||
def doc2langchain(cls, chunk):
|
||||
"""Transformation from Chunk to Langchain Document format."""
|
||||
from langchain.schema import Document as LCDocument
|
||||
|
||||
return LCDocument(page_content=chunk.content, metadata=chunk.metadata)
|
||||
|
||||
|
||||
class Chunk(Document):
|
||||
"""
|
||||
Document Chunk including chunk content, chunk metadata, chunk summary, chunk relations.
|
||||
"""
|
||||
|
||||
chunk_id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="unique id for the chunk"
|
||||
)
|
||||
content: str = Field(default="", description="chunk text content")
|
||||
|
||||
metadata: Dict[str, Any] = Field(
    default_factory=dict,
    description="metadata fields",
)
|
||||
score: float = Field(default=0.0, description="chunk text similarity score")
|
||||
summary: str = Field(default="", description="chunk text summary")
|
||||
separator: str = Field(
|
||||
default="\n",
|
||||
description="Separator between metadata fields when converting to string.",
|
||||
)
|
||||
|
||||
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
data = self.dict(**kwargs)
|
||||
data["class_name"] = self.class_name()
|
||||
return data
|
||||
|
||||
def to_json(self, **kwargs: Any) -> str:
|
||||
data = self.to_dict(**kwargs)
|
||||
return json.dumps(data)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.chunk_id,))
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.chunk_id == other.chunk_id
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any], **kwargs: Any): # type: ignore
|
||||
if isinstance(kwargs, dict):
|
||||
data.update(kwargs)
|
||||
|
||||
data.pop("class_name", None)
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data_str: str, **kwargs: Any): # type: ignore
|
||||
data = json.loads(data_str)
|
||||
return cls.from_dict(data, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def langchain2chunk(cls, document):
|
||||
"""Transformation from Langchain to Chunk Document format."""
|
||||
metadata = document.metadata or {}
|
||||
return cls(content=document.page_content, metadata=metadata)
|
||||
|
||||
@classmethod
|
||||
def llamaindex2chunk(cls, node):
|
||||
"""Transformation from LLama-Index to Chunk Document format."""
|
||||
metadata = node.metadata or {}
|
||||
return cls(content=node.content, metadata=metadata)
|
||||
|
||||
@classmethod
|
||||
def chunk2langchain(cls, chunk):
|
||||
"""Transformation from Chunk to Langchain Document format."""
|
||||
from langchain.schema import Document as LCDocument
|
||||
|
||||
return LCDocument(page_content=chunk.content, metadata=chunk.metadata)
|
||||
|
||||
@classmethod
|
||||
def chunk2llamaindex(cls, chunk):
|
||||
"""Transformation from Chunk to LLama-Index Document format."""
|
||||
from llama_index.schema import TextNode
|
||||
|
||||
return TextNode(text=chunk.content, metadata=chunk.metadata)
|
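A short usage sketch for the new Chunk model may help readers see how it replaces the Langchain Document that used to flow through the pipeline; it only exercises the constructors and converters defined in this file, and the sample text and metadata are made up.

from dbgpt.rag.chunk import Chunk

# Build a chunk directly; chunk_id is generated automatically and metadata records the source.
chunk = Chunk(
    content="DB-GPT supports several vector stores.",
    metadata={"source": "intro.md"},
)

# Convert to a Langchain Document and back (requires langchain to be installed).
lc_doc = Chunk.chunk2langchain(chunk)
again = Chunk.langchain2chunk(lc_doc)
print(again.content, again.metadata)

# Chunks hash and compare by chunk_id, so duplicates collapse in a set.
assert chunk in {chunk}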
137
dbgpt/rag/chunk_manager.py
Normal file
@ -0,0 +1,137 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.extractor.base import Extractor
|
||||
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
|
||||
|
||||
|
||||
class SplitterType(Enum):
|
||||
"""splitter type"""
|
||||
|
||||
LANGCHAIN = "langchain"
|
||||
LLAMA_INDEX = "llama-index"
|
||||
USER_DEFINE = "user_define"
|
||||
|
||||
|
||||
class ChunkParameters(BaseModel):
|
||||
"""ChunkParameters"""
|
||||
|
||||
chunk_strategy: str = Field(
|
||||
default=None,
|
||||
description="chunk strategy",
|
||||
)
|
||||
text_splitter: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="text splitter",
|
||||
)
|
||||
|
||||
splitter_type: SplitterType = Field(
|
||||
default=SplitterType.USER_DEFINE,
|
||||
description="splitter type",
|
||||
)
|
||||
|
||||
chunk_size: int = Field(
|
||||
default=512,
|
||||
description="chunk size",
|
||||
)
|
||||
chunk_overlap: int = Field(
|
||||
default=50,
|
||||
description="chunk overlap",
|
||||
)
|
||||
separator: str = Field(
|
||||
default="\n",
|
||||
description="chunk separator",
|
||||
)
|
||||
|
||||
|
||||
class ChunkManager:
|
||||
"""ChunkManager"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
knowledge: Knowledge = None,
|
||||
chunk_parameter: Optional[ChunkParameters] = None,
|
||||
extractor: Optional[Extractor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
knowledge: (Knowledge) Knowledge datasource.
|
||||
chunk_parameter: (Optional[ChunkParameter]) Chunk parameter.
|
||||
extractor: (Optional[Extractor]) Extractor to use for summarization.
|
||||
"""
|
||||
self._knowledge = knowledge
|
||||
|
||||
self._extractor = extractor
|
||||
self._chunk_parameters = chunk_parameter or ChunkParameters()
|
||||
self._chunk_strategy = (
    self._chunk_parameters.chunk_strategy
    or self._knowledge.default_chunk_strategy().name
)
self._text_splitter = self._chunk_parameters.text_splitter
self._splitter_type = self._chunk_parameters.splitter_type
|
||||
|
||||
def split(self, documents) -> List[Chunk]:
|
||||
"""Split a document into chunks."""
|
||||
text_splitter = self._select_text_splitter()
|
||||
if SplitterType.LANGCHAIN == self._splitter_type:
|
||||
documents = text_splitter.split_documents(documents)
|
||||
return [Chunk.langchain2chunk(document) for document in documents]
|
||||
elif SplitterType.LLAMA_INDEX == self._splitter_type:
|
||||
nodes = text_splitter.split_text(documents)
|
||||
return [Chunk.llamaindex2chunk(node) for node in nodes]
|
||||
else:
|
||||
return text_splitter.split_documents(documents)
|
||||
|
||||
def split_with_summary(
|
||||
self, document: Any, chunk_strategy: ChunkStrategy
|
||||
) -> List[Chunk]:
|
||||
"""Split a document into chunks and summary"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def chunk_parameters(self) -> ChunkParameters:
|
||||
return self._chunk_parameters
|
||||
|
||||
def set_text_splitter(
|
||||
self,
|
||||
text_splitter,
|
||||
splitter_type: Optional[SplitterType] = SplitterType.LANGCHAIN,
|
||||
) -> None:
|
||||
"""Add text splitter."""
|
||||
self._text_splitter = text_splitter
|
||||
self._splitter_type = splitter_type
|
||||
|
||||
def get_text_splitter(
|
||||
self,
|
||||
) -> Any:
|
||||
"""get text splitter."""
|
||||
return self._select_text_splitter()
|
||||
|
||||
def _select_text_splitter(
|
||||
self,
|
||||
):
|
||||
"""Select text splitter by chunk strategy."""
|
||||
if self._text_splitter:
|
||||
return self._text_splitter
|
||||
if not self._chunk_strategy or "Automatic" == self._chunk_strategy:
|
||||
self._chunk_strategy = self._knowledge.default_chunk_strategy().name
|
||||
if self._chunk_strategy not in [
|
||||
support_chunk_strategy.name
|
||||
for support_chunk_strategy in self._knowledge.support_chunk_strategy()
|
||||
]:
|
||||
current_type = self._knowledge.type().value
|
||||
if self._knowledge.document_type():
|
||||
current_type = self._knowledge.document_type().value
|
||||
raise ValueError(
|
||||
f"{current_type} knowledge not supported chunk strategy {self._chunk_strategy} "
|
||||
)
|
||||
strategy = ChunkStrategy[self._chunk_strategy]
|
||||
return strategy.match(
|
||||
chunk_size=self._chunk_parameters.chunk_size,
|
||||
chunk_overlap=self._chunk_parameters.chunk_overlap,
|
||||
separator=self._chunk_parameters.separator,
|
||||
)
|
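A rough sketch of driving ChunkManager with an explicit Langchain splitter follows, so no Knowledge datasource is needed for the demo; the chunk-strategy name, the sizes, and the sample document are illustrative assumptions rather than values from the commit.

from langchain.schema import Document as LCDocument
from langchain.text_splitter import RecursiveCharacterTextSplitter

from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters, SplitterType

# "CHUNK_BY_SIZE" is assumed to be a ChunkStrategy member (defined in dbgpt.rag.knowledge.base);
# setting it explicitly avoids the default_chunk_strategy() lookup on a Knowledge object,
# and it is not used for splitter selection because a splitter is registered below.
params = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE", chunk_size=256, chunk_overlap=20)
manager = ChunkManager(knowledge=None, chunk_parameter=params)

# Register a Langchain splitter; split() will then return dbgpt.rag.chunk.Chunk objects.
manager.set_text_splitter(
    RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=20),
    splitter_type=SplitterType.LANGCHAIN,
)

docs = [LCDocument(page_content="Some long document text. " * 50, metadata={"source": "demo.md"})]
chunks = manager.split(docs)
print(len(chunks), chunks[0].content[:60])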
@ -3,12 +3,15 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Type, TYPE_CHECKING
|
||||
|
||||
from dbgpt.component import BaseComponent
|
||||
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from dbgpt.rag.embedding.embeddings import Embeddings
|
||||
|
||||
|
||||
class EmbeddingFactory(BaseComponent, ABC):
|
||||
"""Abstract base class for EmbeddingFactory."""
|
||||
|
||||
name = "embedding_factory"
|
||||
|
||||
@abstractmethod
|
||||
@ -41,6 +44,4 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
if embedding_cls:
|
||||
return embedding_cls(**new_kwargs)
|
||||
else:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
return HuggingFaceEmbeddings(**new_kwargs)
|
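Outside the SystemApp component registry, the default factory can also be used directly; a hedged sketch, assuming create(model_name=...) forwards the name to the underlying HuggingFaceEmbeddings in the same way the old EmbeddingEngine relied on. The model name is only a placeholder and sentence-transformers must be installed.

from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory

# Placeholder model name; the factory falls back to HuggingFaceEmbeddings when no
# embedding class is supplied.
embedding_fn = DefaultEmbeddingFactory().create(
    model_name="sentence-transformers/all-mpnet-base-v2"
)
vectors = embedding_fn.embed_documents(["hello world", "vector stores"])
print(len(vectors), len(vectors[0]))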
363
dbgpt/rag/embedding/embeddings.py
Normal file
@ -0,0 +1,363 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import Field, Extra, BaseModel
|
||||
|
||||
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
||||
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
|
||||
DEFAULT_BGE_MODEL = "BAAI/bge-large-en"
|
||||
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
|
||||
DEFAULT_QUERY_INSTRUCTION = (
|
||||
"Represent the question for retrieving supporting documents: "
|
||||
)
|
||||
DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
|
||||
"Represent this question for searching relevant passages: "
|
||||
)
|
||||
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"
|
||||
|
||||
|
||||
class Embeddings(ABC):
|
||||
"""Interface for embedding models."""
|
||||
|
||||
"""refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings"""
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronous Embed search docs."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.embed_documents, texts
|
||||
)
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronous Embed query text."""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.embed_query, text
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
"""HuggingFace sentence_transformers embedding models.
|
||||
To use, you should have the ``sentence_transformers`` python package installed.
|
||||
Refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from .embeddings import HuggingFaceEmbeddings
|
||||
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
encode_kwargs = {'normalize_embeddings': False}
|
||||
hf = HuggingFaceEmbeddings(
|
||||
model_name=model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
encode_kwargs=encode_kwargs
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = DEFAULT_MODEL_NAME
|
||||
"""Model name to use."""
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the model."""
|
||||
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method of the model."""
|
||||
multi_process: bool = False
|
||||
"""Run encode() on multiple GPUs."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import sentence_transformers
|
||||
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Could not import sentence_transformers python package. "
|
||||
"Please install it with `pip install sentence-transformers`."
|
||||
) from exc
|
||||
|
||||
self.client = sentence_transformers.SentenceTransformer(
|
||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||
)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
import sentence_transformers
|
||||
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
if self.multi_process:
|
||||
pool = self.client.start_multi_process_pool()
|
||||
embeddings = self.client.encode_multi_process(texts, pool)
|
||||
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
||||
else:
|
||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
||||
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around sentence_transformers embedding models.
|
||||
|
||||
To use, you should have the ``sentence_transformers``
|
||||
and ``InstructorEmbedding`` python packages installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
||||
|
||||
model_name = "hkunlp/instructor-large"
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
encode_kwargs = {'normalize_embeddings': True}
|
||||
hf = HuggingFaceInstructEmbeddings(
|
||||
model_name=model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
encode_kwargs=encode_kwargs
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = DEFAULT_INSTRUCT_MODEL
|
||||
"""Model name to use."""
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the model."""
|
||||
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method of the model."""
|
||||
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
|
||||
"""Instruction to use for embedding documents."""
|
||||
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
|
||||
"""Instruction to use for embedding query."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
from InstructorEmbedding import INSTRUCTOR
|
||||
|
||||
self.client = INSTRUCTOR(
|
||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError("Dependencies for InstructorEmbedding not found.") from e
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
instruction_pairs = [[self.embed_instruction, text] for text in texts]
|
||||
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace instruct model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
instruction_pair = [self.query_instruction, text]
|
||||
embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
|
||||
return embedding.tolist()
|
||||
|
||||
|
||||
class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
|
||||
"""HuggingFace BGE sentence_transformers embedding models.
|
||||
|
||||
To use, you should have the ``sentence_transformers`` python package installed.
|
||||
refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
|
||||
model_name = "BAAI/bge-large-en"
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
encode_kwargs = {'normalize_embeddings': True}
|
||||
hf = HuggingFaceBgeEmbeddings(
|
||||
model_name=model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
encode_kwargs=encode_kwargs
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = DEFAULT_BGE_MODEL
|
||||
"""Model name to use."""
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the model."""
|
||||
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method of the model."""
|
||||
query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN
|
||||
"""Instruction to use for embedding query."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import sentence_transformers
|
||||
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Could not import sentence_transformers python package. "
|
||||
"Please install it with `pip install sentence_transformers`."
|
||||
) from exc
|
||||
|
||||
self.client = sentence_transformers.SentenceTransformer(
|
||||
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
|
||||
)
|
||||
if "-zh" in self.model_name:
|
||||
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
texts = [t.replace("\n", " ") for t in texts]
|
||||
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
||||
return embeddings.tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
text = text.replace("\n", " ")
|
||||
embedding = self.client.encode(
|
||||
self.query_instruction + text, **self.encode_kwargs
|
||||
)
|
||||
return embedding.tolist()
|
||||
|
||||
|
||||
class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
|
||||
"""Embed texts using the HuggingFace API.
|
||||
|
||||
Requires a HuggingFace Inference API key and a model name.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
"""Your API key for the HuggingFace Inference API."""
|
||||
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
"""The name of the model to use for text embeddings."""
|
||||
|
||||
@property
|
||||
def _api_url(self) -> str:
|
||||
return (
|
||||
"https://api-inference.huggingface.co"
|
||||
"/pipeline"
|
||||
"/feature-extraction"
|
||||
f"/{self.model_name}"
|
||||
)
|
||||
|
||||
@property
|
||||
def _headers(self) -> dict:
|
||||
return {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Get the embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
texts (Documents): A list of texts to get embeddings for.
|
||||
|
||||
Returns:
|
||||
Embedded texts as List[List[float]], where each inner List[float]
|
||||
corresponds to a single input text.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
|
||||
|
||||
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
|
||||
api_key="your_api_key",
|
||||
model_name="sentence-transformers/all-MiniLM-l6-v2"
|
||||
)
|
||||
texts = ["Hello, world!", "How are you?"]
|
||||
hf_embeddings.embed_documents(texts)
|
||||
"""
|
||||
response = requests.post(
|
||||
self._api_url,
|
||||
headers=self._headers,
|
||||
json={
|
||||
"inputs": texts,
|
||||
"options": {"wait_for_model": True, "use_cache": True},
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
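Since the Embeddings interface above adds async wrappers around the blocking encode calls, a small sketch of the async path may be useful; the model name and device are placeholders, and sentence-transformers must be installed.

import asyncio

from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings

async def main():
    hf = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",  # placeholder model
        model_kwargs={"device": "cpu"},
    )
    # aembed_* run the blocking encode calls in the default executor.
    doc_vectors = await hf.aembed_documents(["DB-GPT RAG refactor", "vector search"])
    query_vector = await hf.aembed_query("what changed in the rag module?")
    print(len(doc_vectors), len(query_vector))

asyncio.run(main())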
@ -1,12 +0,0 @@
|
||||
from dbgpt.rag.embedding_engine.source_embedding import SourceEmbedding, register
|
||||
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType
|
||||
from dbgpt.rag.embedding_engine.pre_text_splitter import PreTextSplitter
|
||||
|
||||
__all__ = [
|
||||
"SourceEmbedding",
|
||||
"register",
|
||||
"EmbeddingEngine",
|
||||
"KnowledgeType",
|
||||
"PreTextSplitter",
|
||||
]
|
@ -1,64 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import (
|
||||
TextSplitter,
|
||||
SpacyTextSplitter,
|
||||
RecursiveCharacterTextSplitter,
|
||||
)
|
||||
|
||||
from dbgpt.rag.embedding_engine import SourceEmbedding, register
|
||||
from dbgpt.rag.embedding_engine.loader.csv_loader import NewCSVLoader
|
||||
|
||||
|
||||
class CSVEmbedding(SourceEmbedding):
|
||||
"""csv embedding for read csv document."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
vector_store_config,
|
||||
source_reader: Optional = None,
|
||||
text_splitter: Optional[TextSplitter] = None,
|
||||
):
|
||||
"""Initialize with csv path.
|
||||
Args:
|
||||
- file_path: data source path
|
||||
- vector_store_config: vector store config params.
|
||||
- source_reader: Optional[BaseLoader]
|
||||
- text_splitter: Optional[TextSplitter]
|
||||
"""
|
||||
super().__init__(
|
||||
file_path, vector_store_config, source_reader=None, text_splitter=None
|
||||
)
|
||||
self.file_path = file_path
|
||||
self.vector_store_config = vector_store_config
|
||||
self.source_reader = source_reader or None
|
||||
self.text_splitter = text_splitter or None
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
"""Load from csv path."""
|
||||
if self.source_reader is None:
|
||||
self.source_reader = NewCSVLoader(self.file_path)
|
||||
if self.text_splitter is None:
|
||||
try:
|
||||
self.text_splitter = SpacyTextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=100,
|
||||
chunk_overlap=100,
|
||||
)
|
||||
except Exception:
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=100, chunk_overlap=50
|
||||
)
|
||||
|
||||
return self.source_reader.load_and_split(self.text_splitter)
|
||||
|
||||
@register
|
||||
def data_process(self, documents: List[Document]):
|
||||
i = 0
|
||||
for d in documents:
|
||||
documents[i].page_content = d.page_content.replace("\n", "")
|
||||
i += 1
|
||||
return documents
|
@ -1,145 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import (
|
||||
EmbeddingFactory,
|
||||
DefaultEmbeddingFactory,
|
||||
)
|
||||
from dbgpt.rag.embedding_engine.knowledge_type import (
|
||||
get_knowledge_embedding,
|
||||
KnowledgeType,
|
||||
)
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class EmbeddingEngine:
|
||||
"""EmbeddingEngine provide a chain process include(read->text_split->data_process->index_store) for knowledge document embedding into vector store.
|
||||
1.knowledge_embedding:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
|
||||
2.similar_search: similarity search from vector_store
|
||||
how to use reference:https://db-gpt.readthedocs.io/en/latest/modules/knowledge.html
|
||||
how to integrate:https://db-gpt.readthedocs.io/en/latest/modules/knowledge/pdf/pdf_embedding.html
|
||||
Example:
|
||||
.. code-block:: python
|
||||
embedding_model = "your_embedding_model"
|
||||
vector_store_type = "Chroma"
|
||||
chroma_persist_path = "your_persist_path"
|
||||
vector_store_config = {
|
||||
"vector_store_name": "document_test",
|
||||
"vector_store_type": vector_store_type,
|
||||
"chroma_persist_path": chroma_persist_path,
|
||||
}
|
||||
|
||||
# it can be .md,.pdf,.docx, .csv, .html
|
||||
document_path = "your_path/test.md"
|
||||
embedding_engine = EmbeddingEngine(
|
||||
knowledge_source=document_path,
|
||||
knowledge_type=KnowledgeType.DOCUMENT.value,
|
||||
model_name=embedding_model,
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
# embedding document content to vector store
|
||||
embedding_engine.knowledge_embedding()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
vector_store_config,
|
||||
knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
|
||||
knowledge_source: Optional[str] = None,
|
||||
source_reader: Optional = None,
|
||||
text_splitter: Optional[TextSplitter] = None,
|
||||
embedding_factory: EmbeddingFactory = None,
|
||||
):
|
||||
"""Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source
|
||||
Args:
|
||||
- model_name: model_name
|
||||
- vector_store_config: vector store config: Dict
|
||||
- knowledge_type: Optional[KnowledgeType]
|
||||
- knowledge_source: Optional[str]
|
||||
- source_reader: Optional[BaseLoader]
|
||||
- text_splitter: Optional[TextSplitter]
|
||||
- embedding_factory: EmbeddingFactory
|
||||
"""
|
||||
self.knowledge_source = knowledge_source
|
||||
self.model_name = model_name
|
||||
self.vector_store_config = vector_store_config
|
||||
self.knowledge_type = knowledge_type
|
||||
if not embedding_factory:
|
||||
embedding_factory = DefaultEmbeddingFactory()
|
||||
self.embeddings = embedding_factory.create(model_name=self.model_name)
|
||||
self.vector_store_config["embeddings"] = self.embeddings
|
||||
self.source_reader = source_reader
|
||||
self.text_splitter = text_splitter
|
||||
|
||||
def knowledge_embedding(self):
|
||||
"""source embedding is chain process.read->text_split->data_process->index_store"""
|
||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||
self.knowledge_embedding_client.source_embedding()
|
||||
|
||||
def knowledge_embedding_batch(self, docs):
|
||||
"""Deprecation"""
|
||||
# docs = self.knowledge_embedding_client.read_batch()
|
||||
return self.knowledge_embedding_client.index_to_store(docs)
|
||||
|
||||
def read(self):
|
||||
"""Deprecation"""
|
||||
self.knowledge_embedding_client = self.init_knowledge_embedding()
|
||||
return self.knowledge_embedding_client.read_batch()
|
||||
|
||||
def init_knowledge_embedding(self):
|
||||
return get_knowledge_embedding(
|
||||
self.knowledge_type,
|
||||
self.knowledge_source,
|
||||
self.vector_store_config,
|
||||
self.source_reader,
|
||||
self.text_splitter,
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk):
|
||||
"""vector db similar search in vector database.
|
||||
Return topk docs.
|
||||
Args:
|
||||
- text: query text
|
||||
- topk: top k
|
||||
"""
|
||||
vector_client = VectorStoreConnector(
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
# https://github.com/chroma-core/chroma/issues/657
|
||||
ans = vector_client.similar_search(text, topk)
|
||||
return ans
|
||||
|
||||
def similar_search_with_scores(self, text, topk, score_threshold: float = 0.3):
|
||||
"""
|
||||
similar_search_with_score in vector database.
|
||||
Return docs and relevance scores in the range [0, 1].
|
||||
Args:
|
||||
doc(str): query text
|
||||
topk(int): return docs nums. Defaults to 4.
|
||||
score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
|
||||
"""
|
||||
vector_client = VectorStoreConnector(
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
ans = vector_client.similar_search_with_scores(text, topk, score_threshold)
|
||||
return ans
|
||||
|
||||
def vector_exist(self):
|
||||
"""vector db is exist"""
|
||||
vector_client = VectorStoreConnector(
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
return vector_client.vector_name_exists()
|
||||
|
||||
def delete_by_ids(self, ids):
|
||||
"""delete vector db by ids
|
||||
Args:
|
||||
- ids: vector ids
|
||||
"""
|
||||
vector_client = VectorStoreConnector(
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
vector_client.delete_by_ids(ids=ids)
|
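The removed similar_search_with_scores has a rough counterpart on the new retriever; a hedged sketch follows, assuming aretrieve_with_scores(query, score_threshold) keeps the 0-to-1 relevance convention documented above and reusing a vector_store_connector built as in the earlier sketch.

import asyncio

from dbgpt.rag.retriever.embedding import EmbeddingRetriever

async def search_with_scores(vector_store_connector, query: str, score_threshold: float = 0.3):
    retriever = EmbeddingRetriever(top_k=4, vector_store_connector=vector_store_connector)
    # Chunks whose relevance score (0 = dissimilar, 1 = most similar) passes the threshold.
    return await retriever.aretrieve_with_scores(query, score_threshold)

# Example usage (vector_store_connector built as shown earlier):
# chunks = asyncio.run(search_with_scores(vector_store_connector, "how to add a datasource?"))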
@ -1,26 +0,0 @@
|
||||
from typing import List, Optional
|
||||
import chardet
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class EncodeTextLoader(BaseLoader):
|
||||
"""Load text files."""
|
||||
|
||||
def __init__(self, file_path: str, encoding: Optional[str] = None):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
self.encoding = encoding
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load from file path."""
|
||||
with open(self.file_path, "rb") as f:
|
||||
raw_text = f.read()
|
||||
result = chardet.detect(raw_text)
|
||||
if result["encoding"] is None:
|
||||
text = raw_text.decode("utf-8")
|
||||
else:
|
||||
text = raw_text.decode(result["encoding"])
|
||||
metadata = {"source": self.file_path}
|
||||
return [Document(page_content=text, metadata=metadata)]
|
@ -1,107 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from dbgpt.rag.embedding_engine.csv_embedding import CSVEmbedding
|
||||
from dbgpt.rag.embedding_engine.markdown_embedding import MarkdownEmbedding
|
||||
from dbgpt.rag.embedding_engine.pdf_embedding import PDFEmbedding
|
||||
from dbgpt.rag.embedding_engine.ppt_embedding import PPTEmbedding
|
||||
from dbgpt.rag.embedding_engine.string_embedding import StringEmbedding
|
||||
from dbgpt.rag.embedding_engine.url_embedding import URLEmbedding
|
||||
from dbgpt.rag.embedding_engine.word_embedding import WordEmbedding
|
||||
|
||||
DocumentEmbeddingType = {
|
||||
".txt": (MarkdownEmbedding, {}),
|
||||
".md": (MarkdownEmbedding, {}),
|
||||
".html": (MarkdownEmbedding, {}),
|
||||
".pdf": (PDFEmbedding, {}),
|
||||
".doc": (WordEmbedding, {}),
|
||||
".docx": (WordEmbedding, {}),
|
||||
".csv": (CSVEmbedding, {}),
|
||||
".ppt": (PPTEmbedding, {}),
|
||||
".pptx": (PPTEmbedding, {}),
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeType(Enum):
|
||||
DOCUMENT = "DOCUMENT"
|
||||
URL = "URL"
|
||||
TEXT = "TEXT"
|
||||
OSS = "OSS"
|
||||
S3 = "S3"
|
||||
NOTION = "NOTION"
|
||||
MYSQL = "MYSQL"
|
||||
TIDB = "TIDB"
|
||||
CLICKHOUSE = "CLICKHOUSE"
|
||||
OCEANBASE = "OCEANBASE"
|
||||
ELASTICSEARCH = "ELASTICSEARCH"
|
||||
HIVE = "HIVE"
|
||||
PRESTO = "PRESTO"
|
||||
KAFKA = "KAFKA"
|
||||
SPARK = "SPARK"
|
||||
YOUTUBE = "YOUTUBE"
|
||||
|
||||
|
||||
def get_knowledge_embedding(
|
||||
knowledge_type,
|
||||
knowledge_source,
|
||||
vector_store_config=None,
|
||||
source_reader=None,
|
||||
text_splitter=None,
|
||||
):
|
||||
match knowledge_type:
|
||||
case KnowledgeType.DOCUMENT.value:
|
||||
extension = "." + knowledge_source.rsplit(".", 1)[-1]
|
||||
if extension in DocumentEmbeddingType:
|
||||
knowledge_class, knowledge_args = DocumentEmbeddingType[extension]
|
||||
embedding = knowledge_class(
|
||||
knowledge_source,
|
||||
vector_store_config=vector_store_config,
|
||||
source_reader=source_reader,
|
||||
text_splitter=text_splitter,
|
||||
**knowledge_args,
|
||||
)
|
||||
return embedding
|
||||
raise ValueError(f"Unsupported knowledge document type '{extension}'")
|
||||
case KnowledgeType.URL.value:
|
||||
embedding = URLEmbedding(
|
||||
file_path=knowledge_source,
|
||||
vector_store_config=vector_store_config,
|
||||
source_reader=source_reader,
|
||||
text_splitter=text_splitter,
|
||||
)
|
||||
return embedding
|
||||
case KnowledgeType.TEXT.value:
|
||||
embedding = StringEmbedding(
|
||||
file_path=knowledge_source,
|
||||
vector_store_config=vector_store_config,
|
||||
source_reader=source_reader,
|
||||
text_splitter=text_splitter,
|
||||
)
|
||||
return embedding
|
||||
case KnowledgeType.OSS.value:
|
||||
raise Exception("OSS have not integrate")
|
||||
case KnowledgeType.S3.value:
|
||||
raise Exception("S3 have not integrate")
|
||||
case KnowledgeType.NOTION.value:
|
||||
raise Exception("NOTION have not integrate")
|
||||
case KnowledgeType.MYSQL.value:
|
||||
raise Exception("MYSQL have not integrate")
|
||||
case KnowledgeType.TIDB.value:
|
||||
raise Exception("TIDB have not integrate")
|
||||
case KnowledgeType.CLICKHOUSE.value:
|
||||
raise Exception("CLICKHOUSE have not integrate")
|
||||
case KnowledgeType.OCEANBASE.value:
|
||||
raise Exception("OCEANBASE have not integrate")
|
||||
case KnowledgeType.ELASTICSEARCH.value:
|
||||
raise Exception("ELASTICSEARCH have not integrate")
|
||||
case KnowledgeType.HIVE.value:
|
||||
raise Exception("HIVE have not integrate")
|
||||
case KnowledgeType.PRESTO.value:
|
||||
raise Exception("PRESTO have not integrate")
|
||||
case KnowledgeType.KAFKA.value:
|
||||
raise Exception("KAFKA have not integrate")
|
||||
case KnowledgeType.SPARK.value:
|
||||
raise Exception("SPARK have not integrate")
|
||||
case KnowledgeType.YOUTUBE.value:
|
||||
raise Exception("YOUTUBE have not integrate")
|
||||
case _:
|
||||
raise Exception("unknown knowledge type")
|
@ -1,55 +0,0 @@
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
|
||||
|
||||
class CHNDocumentSplitter(CharacterTextSplitter):
|
||||
def __init__(self, pdf: bool = False, sentence_size: int = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pdf = pdf
|
||||
self.sentence_size = sentence_size
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
if self.pdf:
|
||||
text = re.sub(r"\n{3,}", r"\n", text)
|
||||
text = re.sub("\s", " ", text)
|
||||
text = re.sub("\n\n", "", text)
|
||||
|
||||
text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text)
|
||||
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
|
||||
text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)
|
||||
text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text)
|
||||
text = text.rstrip()
|
||||
ls = [i for i in text.split("\n") if i]
|
||||
for ele in ls:
|
||||
if len(ele) > self.sentence_size:
|
||||
ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele)
|
||||
ele1_ls = ele1.split("\n")
|
||||
for ele_ele1 in ele1_ls:
|
||||
if len(ele_ele1) > self.sentence_size:
|
||||
ele_ele2 = re.sub(
|
||||
r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1
|
||||
)
|
||||
ele2_ls = ele_ele2.split("\n")
|
||||
for ele_ele2 in ele2_ls:
|
||||
if len(ele_ele2) > self.sentence_size:
|
||||
ele_ele3 = re.sub(
|
||||
'( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
|
||||
)
|
||||
ele2_id = ele2_ls.index(ele_ele2)
|
||||
ele2_ls = (
|
||||
ele2_ls[:ele2_id]
|
||||
+ [i for i in ele_ele3.split("\n") if i]
|
||||
+ ele2_ls[ele2_id + 1 :]
|
||||
)
|
||||
ele_id = ele1_ls.index(ele_ele1)
|
||||
ele1_ls = (
|
||||
ele1_ls[:ele_id]
|
||||
+ [i for i in ele2_ls if i]
|
||||
+ ele1_ls[ele_id + 1 :]
|
||||
)
|
||||
|
||||
id = ls.index(ele)
|
||||
ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :]
|
||||
return ls
|
@ -1,76 +0,0 @@
|
||||
"""Loads a CSV file into a list of documents.
|
||||
|
||||
Each document represents one row of the CSV file. Every row is converted into a
|
||||
key/value pair and outputted to a new line in the document's page_content.
|
||||
|
||||
The source for each document loaded from csv is set to the value of the
|
||||
`file_path` argument for all documents by default.
|
||||
You can override this by setting the `source_column` argument to the
|
||||
name of a column in the CSV file.
|
||||
The source of each document will then be set to the value of the column
|
||||
with the name specified in `source_column`.
|
||||
|
||||
Output Example:
|
||||
.. code-block:: txt
|
||||
|
||||
column1: value1
|
||||
column2: value2
|
||||
column3: value3
|
||||
"""
|
||||
from typing import Optional, Dict, List
|
||||
import csv
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
class NewCSVLoader(BaseLoader):
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[Dict] = None,
|
||||
encoding: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
file_path: The path to the CSV file.
|
||||
source_column: The name of the column in the CSV file to use as the source.
|
||||
Optional. Defaults to None.
|
||||
csv_args: A dictionary of arguments to pass to the csv.DictReader.
|
||||
Optional. Defaults to None.
|
||||
encoding: The encoding of the CSV file. Optional. Defaults to None.
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.source_column = source_column
|
||||
self.encoding = encoding
|
||||
self.csv_args = csv_args or {}
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load data into document objects."""
|
||||
|
||||
docs = []
|
||||
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv_reader):
|
||||
strs = []
|
||||
for k, v in row.items():
|
||||
if k is None or v is None:
|
||||
continue
|
||||
strs.append(f"{k.strip()}: {v.strip()}")
|
||||
content = "\n".join(strs)
|
||||
try:
|
||||
source = (
|
||||
row[self.source_column]
|
||||
if self.source_column is not None
|
||||
else self.file_path
|
||||
)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Source column '{self.source_column}' not found in CSV file."
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
@ -1,28 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
import docx
|
||||
|
||||
|
||||
class DocxLoader(BaseLoader):
|
||||
"""Load docx files."""
|
||||
|
||||
def __init__(self, file_path: str, encoding: Optional[str] = None):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
self.encoding = encoding
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load from file path."""
|
||||
docs = []
|
||||
doc = docx.Document(self.file_path)
|
||||
content = []
|
||||
for i in range(len(doc.paragraphs)):
|
||||
para = doc.paragraphs[i]
|
||||
text = para.text
|
||||
content.append(text)
|
||||
docs.append(
|
||||
Document(page_content="".join(content), metadata={"source": self.file_path})
|
||||
)
|
||||
return docs
|
@ -1,55 +0,0 @@
|
||||
"""Loader that loads image files."""
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import fitz
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
|
||||
class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
|
||||
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
|
||||
|
||||
def _get_elements(self) -> List:
|
||||
def pdf_ocr_txt(filepath, dir_path="tmp_files"):
|
||||
full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
|
||||
if not os.path.exists(full_dir_path):
|
||||
os.makedirs(full_dir_path)
|
||||
filename = os.path.split(filepath)[-1]
|
||||
ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
|
||||
doc = fitz.open(filepath)
|
||||
txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
|
||||
img_name = os.path.join(full_dir_path, ".tmp.png")
|
||||
with open(txt_file_path, "w", encoding="utf-8") as fout:
|
||||
for i in range(doc.page_count):
|
||||
page = doc[i]
|
||||
text = page.get_text("")
|
||||
fout.write(text)
|
||||
fout.write("\n")
|
||||
|
||||
img_list = page.get_images()
|
||||
for img in img_list:
|
||||
pix = fitz.Pixmap(doc, img[0])
|
||||
|
||||
pix.save(img_name)
|
||||
|
||||
result = ocr.ocr(img_name)
|
||||
ocr_result = [i[1][0] for line in result for i in line]
|
||||
fout.write("\n".join(ocr_result))
|
||||
os.remove(img_name)
|
||||
return txt_file_path
|
||||
|
||||
txt_file_path = pdf_ocr_txt(self.file_path)
|
||||
from unstructured.partition.text import partition_text
|
||||
|
||||
return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
filepath = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test_py.pdf"
|
||||
)
|
||||
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
|
||||
docs = loader.load()
|
||||
for doc in docs:
|
||||
print(doc)
|
@ -1,28 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from pptx import Presentation
|
||||
|
||||
|
||||
class PPTLoader(BaseLoader):
|
||||
"""Load PPT files."""
|
||||
|
||||
def __init__(self, file_path: str, encoding: Optional[str] = None):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
self.encoding = encoding
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load from file path."""
|
||||
pr = Presentation(self.file_path)
|
||||
docs = []
|
||||
for slide in pr.slides:
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text") and shape.text:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=shape.text, metadata={"source": slide.slide_id}
|
||||
)
|
||||
)
|
||||
return docs
|
@ -1,68 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import List, Optional
|
||||
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import (
|
||||
SpacyTextSplitter,
|
||||
CharacterTextSplitter,
|
||||
RecursiveCharacterTextSplitter,
|
||||
TextSplitter,
|
||||
)
|
||||
|
||||
from dbgpt.rag.embedding_engine import SourceEmbedding, register
|
||||
from dbgpt.rag.embedding_engine.encode_text_loader import EncodeTextLoader
|
||||
|
||||
|
||||
class MarkdownEmbedding(SourceEmbedding):
|
||||
"""markdown embedding for read markdown document."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
vector_store_config,
|
||||
source_reader: Optional = None,
|
||||
text_splitter: Optional[TextSplitter] = None,
|
||||
):
|
||||
"""Initialize raw text word path."""
|
||||
super().__init__(
|
||||
file_path, vector_store_config, source_reader=None, text_splitter=None
|
||||
)
|
||||
self.file_path = file_path
|
||||
self.vector_store_config = vector_store_config
|
||||
self.source_reader = source_reader or None
|
||||
self.text_splitter = text_splitter or None
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
"""Load from markdown path."""
|
||||
if self.source_reader is None:
|
||||
self.source_reader = EncodeTextLoader(self.file_path)
|
||||
if self.text_splitter is None:
|
||||
try:
|
||||
self.text_splitter = SpacyTextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=100,
|
||||
chunk_overlap=100,
|
||||
)
|
||||
except Exception:
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=100, chunk_overlap=50
|
||||
)
|
||||
|
||||
return self.source_reader.load_and_split(self.text_splitter)
|
||||
|
||||
@register
|
||||
def data_process(self, documents: List[Document]):
|
||||
i = 0
|
||||
for d in documents:
|
||||
content = markdown.markdown(d.page_content)
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
for tag in soup(["!doctype", "meta", "i.fa"]):
|
||||
tag.extract()
|
||||
documents[i].page_content = soup.get_text()
|
||||
documents[i].page_content = documents[i].page_content.replace("\n", " ")
|
||||
i += 1
|
||||
return documents
|
@ -1,66 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import (
|
||||
SpacyTextSplitter,
|
||||
RecursiveCharacterTextSplitter,
|
||||
TextSplitter,
|
||||
)
|
||||
|
||||
from dbgpt.rag.embedding_engine import SourceEmbedding, register
|
||||
|
||||
|
||||
class PDFEmbedding(SourceEmbedding):
|
||||
"""pdf embedding for read pdf document."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
vector_store_config,
|
||||
source_reader: Optional = None,
|
||||
text_splitter: Optional[TextSplitter] = None,
|
||||
):
|
||||
"""Initialize pdf word path.
|
||||
Args:
|
||||
- file_path: data source path
|
||||
- vector_store_config: vector store config params.
|
||||
- source_reader: Optional[BaseLoader]
|
||||
- text_splitter: Optional[TextSplitter]
|
||||
"""
|
||||
super().__init__(
|
||||
file_path, vector_store_config, source_reader=None, text_splitter=None
|
||||
)
|
||||
self.file_path = file_path
|
||||
self.vector_store_config = vector_store_config
|
||||
self.source_reader = source_reader or None
|
||||
self.text_splitter = text_splitter or None
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
"""Load from pdf path."""
|
||||
if self.source_reader is None:
|
||||
self.source_reader = PyPDFLoader(self.file_path)
|
||||
if self.text_splitter is None:
|
||||
try:
|
||||
self.text_splitter = SpacyTextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=100,
|
||||
chunk_overlap=100,
|
||||
)
|
||||
except Exception:
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=100, chunk_overlap=50
|
||||
)
|
||||
|
||||
return self.source_reader.load_and_split(self.text_splitter)
|
||||
|
||||
@register
|
||||
def data_process(self, documents: List[Document]):
|
||||
i = 0
|
||||
for d in documents:
|
||||
documents[i].page_content = d.page_content.replace("\n", "")
|
||||
i += 1
|
||||
return documents
|
@ -1,66 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List, Optional

from langchain.schema import Document
from langchain.text_splitter import (
    SpacyTextSplitter,
    RecursiveCharacterTextSplitter,
    TextSplitter,
)

from dbgpt.rag.embedding_engine import SourceEmbedding, register
from dbgpt.rag.embedding_engine.loader.ppt_loader import PPTLoader


class PPTEmbedding(SourceEmbedding):
    """ppt embedding for read ppt document."""

    def __init__(
        self,
        file_path,
        vector_store_config,
        source_reader: Optional = None,
        text_splitter: Optional[TextSplitter] = None,
    ):
        """Initialize ppt word path.
        Args:
           - file_path: data source path
           - vector_store_config: vector store config params.
           - source_reader: Optional[BaseLoader]
           - text_splitter: Optional[TextSplitter]
        """
        super().__init__(
            file_path, vector_store_config, source_reader=None, text_splitter=None
        )
        self.file_path = file_path
        self.vector_store_config = vector_store_config
        self.source_reader = source_reader or None
        self.text_splitter = text_splitter or None

    @register
    def read(self):
        """Load from ppt path."""
        if self.source_reader is None:
            self.source_reader = PPTLoader(self.file_path)
        if self.text_splitter is None:
            try:
                self.text_splitter = SpacyTextSplitter(
                    pipeline="zh_core_web_sm",
                    chunk_size=100,
                    chunk_overlap=100,
                )
            except Exception:
                self.text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=100, chunk_overlap=50
                )

        return self.source_reader.load_and_split(self.text_splitter)

    @register
    def data_process(self, documents: List[Document]):
        i = 0
        for d in documents:
            documents[i].page_content = d.page_content.replace("\n", "")
            i += 1
        return documents
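Both removed read() implementations repeat the same fallback: try the Chinese spaCy pipeline, otherwise fall back to a character splitter. A hedged sketch of that pattern factored into one helper (the name build_splitter is illustrative, not an existing DB-GPT function):

from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SpacyTextSplitter,
    TextSplitter,
)


def build_splitter(chunk_size: int = 100, chunk_overlap: int = 50) -> TextSplitter:
    """Prefer the zh_core_web_sm spaCy pipeline; fall back to character splitting."""
    try:
        # SpacyTextSplitter loads the pipeline at construction time and raises
        # if spaCy or the zh_core_web_sm model is not installed.
        return SpacyTextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    except Exception:
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )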
@ -1,61 +0,0 @@
# from langchain.embeddings import HuggingFaceEmbeddings
# from langchain.vectorstores import Milvus
# from pymilvus import Collection,utility
# from pymilvus import datasource, DataType, FieldSchema, CollectionSchema
#
# # milvus = datasource.connect(
# #     alias="default",
# #     host='localhost',
# #     port="19530"
# # )
# # collection = Collection("book")
#
#
# # Get an existing collection.
# # collection.load()
# #
# # search_params = {"metric_type": "L2", "params": {}, "offset": 5}
# #
# # results = collection.search(
# #     data=[[0.1, 0.2]],
# #     anns_field="book_intro",
# #     param=search_params,
# #     limit=10,
# #     expr=None,
# #     output_fields=['book_id'],
# #     consistency_level="Strong"
# # )
# #
# # # get the IDs of all returned hits
# # results[0].ids
# #
# # # get the distances to the query vector from all returned hits
# # results[0].distances
# #
# # # get the value of an output field specified in the search request.
# # # vector fields are not supported yet.
# # hit = results[0][0]
# # hit.entity.get('title')
#
# # milvus = datasource.connect(
# #     alias="default",
# #     host='localhost',
# #     port="19530"
# # )
# from dbgpt.vector_store.milvus_store import MilvusStore
#
# data = ["aaa", "bbb"]
# model_name = "xx/all-MiniLM-L6-v2"
# embeddings = HuggingFaceEmbeddings(model_name=model_name)
#
# # text_embeddings = Text2Vectors()
# mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
#
# mivuls.insert(["textc","tezt2"])
# print("success")
# ct
# # mivuls.from_texts(texts=data, embedding=embeddings)
# #     docs,
# #     embedding=embeddings,
# #     connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"}
# # )
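The commented-out scratch above sketches a Milvus vector search; note that it imports datasource, which does not appear to be a pymilvus module -- the stock entry point is connections. A hedged sketch of the same flow against the standard pymilvus API, assuming a running local Milvus instance with an existing "book" collection whose book_intro vector field matches the query dimensionality:

from pymilvus import Collection, connections

connections.connect(alias="default", host="localhost", port="19530")
collection = Collection("book")
collection.load()

# Same search parameters as the commented snippet.
search_params = {"metric_type": "L2", "params": {}, "offset": 5}
results = collection.search(
    data=[[0.1, 0.2]],           # query vector(s); must match the field dimension
    anns_field="book_intro",
    param=search_params,
    limit=10,
    expr=None,
    output_fields=["book_id"],
    consistency_level="Strong",
)
# IDs and distances of the returned hits.
print(results[0].ids, results[0].distances)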
@ -1,126 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from langchain.text_splitter import TextSplitter

from dbgpt.storage.vector_store.connector import VectorStoreConnector

registered_methods = []


def register(method):
    registered_methods.append(method.__name__)
    return method


class SourceEmbedding(ABC):
    """base class for read data source embedding pipeline.
    include data read, data process, data split, data to vector, data index vector store
    Implementations should implement the method
    """

    def __init__(
        self,
        file_path,
        vector_store_config: {},
        source_reader: Optional = None,
        text_splitter: Optional[TextSplitter] = None,
        embedding_args: Optional[Dict] = None,
    ):
        """Initialize with Loader url, model_name, vector_store_config
        Args:
           - file_path: data source path
           - vector_store_config: vector store config params.
           - source_reader: Optional[BaseLoader]
           - text_splitter: Optional[TextSplitter]
           - embedding_args: Optional
        """
        self.file_path = file_path
        self.vector_store_config = vector_store_config or {}
        self.source_reader = source_reader or None
        self.text_splitter = text_splitter or None
        self.embedding_args = embedding_args
        self.embeddings = self.vector_store_config.get("embeddings", None)

    @abstractmethod
    @register
    def read(self) -> List[ABC]:
        """read datasource into document objects."""

    @register
    def data_process(self, text):
        """pre process data.
        Args:
           - text: raw text
        """

    @register
    def text_splitter(self, text_splitter: TextSplitter):
        """add text split chunk
        Args:
           - text_splitter: TextSplitter
        """
        pass

    @register
    def text_to_vector(self, docs):
        """transform vector
        Args:
           - docs: List[Document]
        """
        pass

    @register
    def index_to_store(self, docs):
        """index to vector store
        Args:
           - docs: List[Document]
        """
        self.vector_client = VectorStoreConnector(
            self.vector_store_config["vector_store_type"], self.vector_store_config
        )
        return self.vector_client.load_document(docs)

    @register
    def similar_search(self, doc, topk):
        """vector store similarity_search
        Args:
           - query: query
        """
        self.vector_client = VectorStoreConnector(
            self.vector_store_config["vector_store_type"], self.vector_store_config
        )
        # https://github.com/chroma-core/chroma/issues/657
        ans = self.vector_client.similar_search(doc, topk)
        # ans = self.vector_client.similar_search(doc, 1)
        return ans

    def vector_name_exist(self):
        self.vector_client = VectorStoreConnector(
            self.vector_store_config["vector_store_type"], self.vector_store_config
        )
        return self.vector_client.vector_name_exists()

    def source_embedding(self):
        """read()->data_process()->text_split()->index_to_store()"""
        if "read" in registered_methods:
            text = self.read()
        if "data_process" in registered_methods:
            text = self.data_process(text)
        if "text_split" in registered_methods:
            self.text_split(text)
        if "text_to_vector" in registered_methods:
            self.text_to_vector(text)
        if "index_to_store" in registered_methods:
            self.index_to_store(text)

    def read_batch(self):
        if "read" in registered_methods:
            text = self.read()
        if "data_process" in registered_methods:
            text = self.data_process(text)
        if "text_split" in registered_methods:
            self.text_split(text)
        return text
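The removed SourceEmbedding base class drove the whole flow through source_embedding(), which dispatches to the @register-decorated steps in order (read, data_process, text_split, text_to_vector, index_to_store). A hedged usage sketch against the deleted classes above -- the base class only reads the vector_store_type and embeddings config keys, so everything else (model name, file path, connector type) is an illustrative assumption rather than a DB-GPT default:

from langchain.embeddings import HuggingFaceEmbeddings

# Illustrative embedding model; any LangChain Embeddings instance works here.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

pdf_embedding = PDFEmbedding(
    file_path="./docs/example.pdf",          # placeholder path
    vector_store_config={
        "vector_store_type": "Chroma",       # assumed connector type
        "embeddings": embeddings,
    },
)

# Runs read() -> data_process() -> ... -> index_to_store() in one call.
pdf_embedding.source_embedding()

# Afterwards, queries go through the same vector store connector.
docs = pdf_embedding.similar_search("What is DB-GPT?", topk=5)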
Some files were not shown because too many files have changed in this diff.