refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Aries-ckt 2024-01-03 09:45:26 +08:00 committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions

View File

@ -17,7 +17,7 @@ setup: ## Set up the Python development environment
$(VENV_BIN)/pip install -r requirements/lint-requirements.txt
testenv: setup ## Set up the Python test environment
$(VENV_BIN)/pip install -e ".[simple_framework]"
$(VENV_BIN)/pip install -e ".[default]"
.PHONY: fmt
fmt: setup ## Format Python code

View File

@ -30,7 +30,7 @@ def initialize_components(
system_app.register_instance(controller)
# Register global default RAGGraphFactory
# from dbgpt.graph_engine.graph_factory import DefaultRAGGraphFactory
# from dbgpt.graph.graph_factory import DefaultRAGGraphFactory
# system_app.register(DefaultRAGGraphFactory)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from typing import Any, Type, TYPE_CHECKING
from dbgpt.component import ComponentType, SystemApp
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings

View File

@ -14,10 +14,10 @@ from dbgpt.app.knowledge.request.request import (
DocumentQueryRequest,
)
from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType
from dbgpt.app.knowledge.request.request import DocumentSyncRequest
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.rag.knowledge.base import KnowledgeType
HTTP_HEADERS = {"Content-Type": "application/json"}

View File

@ -2,6 +2,7 @@ import os
import shutil
import tempfile
import logging
from typing import List
from fastapi import APIRouter, File, UploadFile, Form
@ -13,10 +14,10 @@ from dbgpt.configs.model_config import (
from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.app.knowledge.request.request import (
KnowledgeQueryRequest,
KnowledgeQueryResponse,
@ -27,9 +28,14 @@ from dbgpt.app.knowledge.request.request import (
SpaceArgumentRequest,
EntityExtractRequest,
DocumentSummaryRequest,
KnowledgeSyncRequest,
)
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.rag.knowledge.base import ChunkStrategy
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.tracer import root_tracer, SpanType
logger = logging.getLogger(__name__)
@ -103,6 +109,39 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
return Result.failed(code="E000X", msg=f"document add error {e}")
@router.get("/knowledge/document/chunkstrategies")
def chunk_strategies():
"""Get chunk strategies"""
print(f"/document/chunkstrategies:")
try:
return Result.succ(
[
{
"strategy": strategy.name,
"name": strategy.value[2],
"description": strategy.value[3],
"parameters": strategy.value[1],
"suffix": [
knowledge.document_type().value
for knowledge in KnowledgeFactory.subclasses()
if strategy in knowledge.support_chunk_strategy()
and knowledge.document_type() is not None
],
"type": set(
[
knowledge.type().value
for knowledge in KnowledgeFactory.subclasses()
if strategy in knowledge.support_chunk_strategy()
]
),
}
for strategy in ChunkStrategy
]
)
except Exception as e:
return Result.failed(code="E000X", msg=f"chunk strategies error {e}")
@router.post("/knowledge/{space_name}/document/list")
def document_list(space_name: str, query_request: DocumentQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
@ -189,6 +228,18 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
return Result.failed(code="E000X", msg=f"document sync error {e}")
@router.post("/knowledge/{space_name}/document/sync_batch")
def batch_document_sync(space_name: str, request: List[KnowledgeSyncRequest]):
logger.info(f"Received params: {space_name}, {request}")
try:
doc_ids = knowledge_space_service.batch_document_sync(
space_name=space_name, sync_requests=request
)
return Result.succ({"tasks": doc_ids})
except Exception as e:
return Result.failed(code="E000X", msg=f"document sync error {e}")
@router.post("/knowledge/{space_name}/chunk/list")
def document_list(space_name: str, query_request: ChunkQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
@ -204,15 +255,23 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={"vector_store_name": space_name},
embedding_factory=embedding_factory,
config = VectorStoreConfig(
name=space_name,
embedding_fn=embedding_factory.create(
EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
),
)
docs = client.similar_search(query_request.query, query_request.top_k)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
retriever = EmbeddingRetriever(
top_k=query_request.top_k, vector_store_connector=vector_store_connector
)
chunks = retriever.retrieve(query_request.query)
res = [
KnowledgeQueryResponse(text=d.page_content, source=d.metadata["source"])
for d in docs
KnowledgeQueryResponse(text=d.content, source=d.metadata["source"])
for d in chunks
]
return {"response": res}
@ -254,7 +313,7 @@ async def entity_extract(request: EntityExtractRequest):
logger.info(f"Received params: {request}")
try:
from dbgpt.app.scene import ChatScene
from dbgpt._private.chat_util import llm_chat_response_nostream
from dbgpt.util.chat_util import llm_chat_response_nostream
import uuid
chat_param = {
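
The two endpoints added above can be exercised end to end. A minimal sketch, assuming a local DB-GPT webserver; the host, space name, and doc ids are hypothetical, and the payload mirrors the KnowledgeSyncRequest model shown further down:

import requests

BASE = "http://localhost:5000"  # assumed webserver address; adjust to your deployment

# Discover the chunk strategies the server supports, with their parameters and suffixes.
print(requests.get(f"{BASE}/knowledge/document/chunkstrategies").json())

# Batch-sync documents: the body is a list of KnowledgeSyncRequest objects.
payload = [
    {
        "doc_id": 42,  # hypothetical document id
        "chunk_parameters": {
            "chunk_strategy": "CHUNK_BY_SIZE",
            "chunk_size": 512,
            "chunk_overlap": 50,
        },
    },
    {"doc_id": 43, "chunk_parameters": {"chunk_strategy": "Automatic"}},
]
resp = requests.post(f"{BASE}/knowledge/my_space/document/sync_batch", json=payload)
print(resp.json())  # roughly {"success": true, "data": {"tasks": [42, 43]}}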

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, func
@ -51,6 +52,12 @@ class KnowledgeDocumentDao(BaseDao):
return doc_id
def get_knowledge_documents(self, query, page=1, page_size=20):
"""Get a list of documents that match the given query.
Args:
query: A KnowledgeDocumentEntity object containing the query parameters.
page: The page number to return.
page_size: The number of documents to return per page.
"""
session = self.get_raw_session()
print(f"current session:{session}")
knowledge_documents = session.query(KnowledgeDocumentEntity)
@ -85,6 +92,23 @@ class KnowledgeDocumentDao(BaseDao):
session.close()
return result
def documents_by_ids(self, ids) -> List[KnowledgeDocumentEntity]:
"""Get a list of documents by their IDs.
Args:
ids: A list of document IDs.
Returns:
A list of KnowledgeDocumentEntity objects.
"""
session = self.get_raw_session()
print(f"current session:{session}")
knowledge_documents = session.query(KnowledgeDocumentEntity)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.id.in_(ids)
)
result = knowledge_documents.all()
session.close()
return result
def get_documents(self, query):
session = self.get_raw_session()
print(f"current session:{session}")

View File

@ -3,6 +3,8 @@ from typing import List, Optional
from dbgpt._private.pydantic import BaseModel
from fastapi import UploadFile
from dbgpt.rag.chunk_manager import ChunkParameters
class KnowledgeQueryRequest(BaseModel):
"""query: knowledge query"""
@ -43,6 +45,8 @@ class DocumentQueryRequest(BaseModel):
"""doc_name: doc path"""
doc_name: str = None
"""doc_ids: doc ids"""
doc_ids: Optional[List] = None
"""doc_type: doc type"""
doc_type: str = None
"""status: status"""
@ -76,6 +80,20 @@ class DocumentSyncRequest(BaseModel):
chunk_overlap: Optional[int] = None
class KnowledgeSyncRequest(BaseModel):
"""Sync request"""
"""doc_ids: doc ids"""
doc_id: int
"""model_name: model name"""
model_name: Optional[str] = None
"""chunk_parameters: chunk parameters
"""
chunk_parameters: ChunkParameters
class ChunkQueryRequest(BaseModel):
"""id: id"""

View File

@ -1,13 +1,26 @@
import json
import logging
from datetime import datetime
from typing import List
from dbgpt.model import DefaultLLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.text_splitter.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt._private.config import Config
from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from dbgpt.component import ComponentType
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@ -32,6 +45,7 @@ from dbgpt.app.knowledge.request.request import (
SpaceArgumentRequest,
DocumentSyncRequest,
DocumentSummaryRequest,
KnowledgeSyncRequest,
)
from enum import Enum
@ -106,7 +120,10 @@ class KnowledgeService:
content=request.content,
result="",
)
return knowledge_document_dao.create_knowledge_document(document)
doc_id = knowledge_document_dao.create_knowledge_document(document)
if doc_id is None:
raise Exception(f"create document failed, {request.doc_name}")
return doc_id
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
"""get knowledge space
@ -171,36 +188,75 @@ class KnowledgeService:
Args:
- space: Knowledge Space Name
- request: DocumentQueryRequest
Returns:
- res DocumentQueryResponse
"""
query = KnowledgeDocumentEntity(
doc_name=request.doc_name,
doc_type=request.doc_type,
space=space,
status=request.status,
)
res = DocumentQueryResponse()
res.data = knowledge_document_dao.get_knowledge_documents(
query, page=request.page, page_size=request.page_size
)
res.total = knowledge_document_dao.get_knowledge_documents_count(query)
res.page = request.page
if request.doc_ids and len(request.doc_ids) > 0:
res.data = knowledge_document_dao.documents_by_ids(request.doc_ids)
else:
query = KnowledgeDocumentEntity(
doc_name=request.doc_name,
doc_type=request.doc_type,
space=space,
status=request.status,
)
res.data = knowledge_document_dao.get_knowledge_documents(
query, page=request.page, page_size=request.page_size
)
res.total = knowledge_document_dao.get_knowledge_documents_count(query)
res.page = request.page
return res
def batch_document_sync(
self, space_name, sync_requests: List[KnowledgeSyncRequest]
) -> List[int]:
"""batch sync knowledge document chunk into vector store
Args:
- space: Knowledge Space Name
- sync_requests: List[KnowledgeSyncRequest]
Returns:
- List[int]: document ids
"""
doc_ids = []
for sync_request in sync_requests:
docs = knowledge_document_dao.documents_by_ids([sync_request.doc_id])
if len(docs) == 0:
raise Exception(
f"there are document called, doc_id: {sync_request.doc_id}"
)
doc = docs[0]
if (
doc.status == SyncStatus.RUNNING.name
or doc.status == SyncStatus.FINISHED.name
):
raise Exception(
f" doc:{doc.doc_name} status is {doc.status}, can not sync"
)
chunk_parameters = sync_request.chunk_parameters
if "Automatic" == chunk_parameters.chunk_strategy:
space_context = self.get_space_context(space_name)
chunk_parameters.chunk_size = (
CFG.KNOWLEDGE_CHUNK_SIZE
if space_context is None
else int(space_context["embedding"]["chunk_size"])
)
chunk_parameters.chunk_overlap = (
CFG.KNOWLEDGE_CHUNK_OVERLAP
if space_context is None
else int(space_context["embedding"]["chunk_overlap"])
)
self._sync_knowledge_document(space_name, doc, chunk_parameters)
doc_ids.append(doc.id)
return doc_ids
def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
"""sync knowledge document chunk into vector store
Args:
- space: Knowledge Space Name
- sync_request: DocumentSyncRequest
"""
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding_engine.pre_text_splitter import PreTextSplitter
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
# NOTE: importing langchain is very, very slow!
from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter
doc_ids = sync_request.doc_ids
self.model_name = sync_request.model_name or CFG.LLM_MODEL
@ -234,6 +290,11 @@ class KnowledgeService:
if sync_request.chunk_overlap:
chunk_overlap = sync_request.chunk_overlap
separators = sync_request.separators or None
from dbgpt.rag.chunk_manager import ChunkParameters
chunk_parameters = ChunkParameters(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter(
separators=separators,
@ -244,7 +305,7 @@ class KnowledgeService:
else:
if separators and len(separators) > 1:
raise ValueError(
"SpacyTextSplitter do not support multiple separators"
"SpacyTextSplitter do not support multipsle separators"
)
try:
separator = "\n\n" if not separators else separators[0]
@ -266,48 +327,51 @@ class KnowledgeService:
pre_separator=sync_request.pre_separator,
text_splitter_impl=text_splitter,
)
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
client = EmbeddingEngine(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": space_name,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
},
text_splitter=text_splitter,
embedding_factory=embedding_factory,
)
chunk_docs = client.read()
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(self.async_doc_embedding, client, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [
DocumentChunkEntity(
doc_name=doc.doc_name,
doc_type=doc.doc_type,
document_id=doc.id,
content=chunk_doc.page_content,
meta_info=str(chunk_doc.metadata),
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
for chunk_doc in chunk_docs
]
document_chunk_dao.create_documents_chunks(chunk_entities)
chunk_parameters.text_splitter = text_splitter
self._sync_knowledge_document(space_name, doc, chunk_parameters)
return doc.id
def _sync_knowledge_document(
self,
space_name,
doc: KnowledgeDocumentEntity,
chunk_parameters: ChunkParameters,
) -> List[Chunk]:
"""sync knowledge document chunk into vector store"""
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
from dbgpt.storage.vector_store.base import VectorStoreConfig
config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
knowledge = KnowledgeFactory.create(
datasource=doc.content,
knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
)
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
vector_store_connector=vector_store_connector,
)
chunk_docs = assembler.get_chunks()
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
return chunk_docs
async def document_summary(self, request: DocumentSummaryRequest):
"""get document summary
Args:
@ -318,20 +382,46 @@ class KnowledgeService:
if len(documents) != 1:
raise Exception(f"can not found document for {request.doc_id}")
document = documents[0]
query = DocumentChunkEntity(
document_id=request.doc_id,
)
chunks = document_chunk_dao.get_document_chunks(query, page=1, page_size=100)
if len(chunks) == 0:
raise Exception(f"can not found chunks for {request.doc_id}")
from langchain.schema import Document
from dbgpt.model.cluster import WorkerManagerFactory
chunk_docs = [Document(page_content=chunk.content) for chunk in chunks]
return await self.async_document_summary(
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
chunk_parameters = ChunkParameters(
chunk_strategy="CHUNK_BY_SIZE",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=CFG.KNOWLEDGE_CHUNK_OVERLAP,
)
chunk_entities = document_chunk_dao.get_document_chunks(
DocumentChunkEntity(document_id=document.id)
)
if (
document.status not in [SyncStatus.RUNNING.name]
and len(chunk_entities) == 0
):
self._sync_knowledge_document(
space_name=document.space,
doc=document,
chunk_parameters=chunk_parameters,
)
knowledge = KnowledgeFactory.create(
datasource=document.content,
knowledge_type=KnowledgeType.get_by_value(document.doc_type),
)
assembler = SummaryAssembler(
knowledge=knowledge,
model_name=request.model_name,
chunk_docs=chunk_docs,
doc=document,
conn_uid=request.conv_uid,
llm_client=DefaultLLMClient(worker_manager=worker_manager),
language=CFG.LANGUAGE,
chunk_parameters=chunk_parameters,
)
summary = await assembler.generate_summary()
if len(assembler.get_chunks()) == 0:
raise Exception(f"can not found chunks for {request.doc_id}")
return await self._llm_extract_summary(
summary, request.conv_uid, request.model_name
)
def update_knowledge_space(
@ -354,15 +444,13 @@ class KnowledgeService:
if len(spaces) == 0:
raise Exception(f"delete error, no space name:{space_name} in database")
space = spaces[0]
vector_config = {}
vector_config["vector_store_name"] = space.name
vector_config["vector_store_type"] = CFG.VECTOR_STORE_TYPE
vector_config["chroma_persist_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
vector_client = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE, ctx=vector_config
config = VectorStoreConfig(name=space.name)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
# delete vectors
vector_client.delete_vector_name(space.name)
vector_store_connector.delete_vector_name(space.name)
document_query = KnowledgeDocumentEntity(space=space.name)
# delete chunks
documents = knowledge_document_dao.get_documents(document_query)
@ -385,15 +473,13 @@ class KnowledgeService:
raise Exception(f"there are no or more than one document called {doc_name}")
vector_ids = documents[0].vector_ids
if vector_ids is not None:
vector_config = {}
vector_config["vector_store_name"] = space_name
vector_config["vector_store_type"] = CFG.VECTOR_STORE_TYPE
vector_config["chroma_persist_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
vector_client = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE, ctx=vector_config
config = VectorStoreConfig(name=space_name)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
# delete vector by ids
vector_client.delete_by_ids(vector_ids)
vector_store_connector.delete_by_ids(vector_ids)
# delete chunks
document_chunk_dao.raw_delete(documents[0].id)
# delete document
@ -432,7 +518,7 @@ class KnowledgeService:
f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
)
try:
from dbgpt.rag.graph_engine.graph_factory import RAGGraphFactory
from dbgpt.rag.graph.graph_factory import RAGGraphFactory
rag_engine = CFG.SYSTEM_APP.get_component(
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
@ -446,54 +532,38 @@ class KnowledgeService:
logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
return knowledge_document_dao.update_knowledge_document(doc)
async def async_document_summary(self, model_name, chunk_docs, doc, conn_uid):
"""async document extract summary
Args:
- model_name: str
- chunk_docs: List[Document]
- doc: KnowledgeDocumentEntity
"""
texts = [doc.page_content for doc in chunk_docs]
from dbgpt.util.prompt_util import PromptHelper
prompt_helper = PromptHelper()
from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
texts = prompt_helper.repack(prompt_template=prompt.template, text_chunks=texts)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
)
space_context = self.get_space_context(doc.space)
if space_context and space_context.get("summary"):
summary = await self._mapreduce_extract_summary(
docs=texts,
model_name=model_name,
max_iteration=int(space_context["summary"]["max_iteration"]),
concurrency_limit=int(space_context["summary"]["concurrency_limit"]),
)
else:
summary = await self._mapreduce_extract_summary(
docs=texts, model_name=model_name
)
return await self._llm_extract_summary(summary, conn_uid, model_name)
def async_doc_embedding(self, client, chunk_docs, doc):
def async_doc_embedding(self, assembler, chunk_docs, doc):
"""async document embedding into vector db
Args:
- client: EmbeddingEngine Client
- chunk_docs: List[Document]
- doc: KnowledgeDocumentEntity
"""
logger.info(
f"async doc sync, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
)
try:
vector_ids = client.knowledge_embedding_batch(chunk_docs)
vector_ids = assembler.persist()
doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success"
if vector_ids is not None:
doc.vector_ids = ",".join(vector_ids)
logger.info(f"async document embedding, success:{doc.doc_name}")
# save chunk details
chunk_entities = [
DocumentChunkEntity(
doc_name=doc.doc_name,
doc_type=doc.doc_type,
document_id=doc.id,
content=chunk_doc.content,
meta_info=str(chunk_doc.metadata),
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
for chunk_doc in chunk_docs
]
document_chunk_dao.create_documents_chunks(chunk_entities)
except Exception as e:
doc.status = SyncStatus.FAILED.name
doc.result = "document embedding failed" + str(e)
@ -577,65 +647,3 @@ class KnowledgeService:
**{"chat_param": chat_param},
)
return chat
async def _mapreduce_extract_summary(
self,
docs,
model_name: str = None,
max_iteration: int = 5,
concurrency_limit: int = 3,
):
"""Extract summary by mapreduce mode
map -> multi async call llm to generate summary
reduce -> merge the summaries by map process
Args:
docs:List[str]
model_name:model name str
max_iteration:max iteration will call llm to summary
concurrency_limit:the max concurrency threads to call llm
Returns:
Document: refine summary context document.
"""
from dbgpt.app.scene import ChatScene
from dbgpt._private.chat_util import llm_chat_response_nostream
import uuid
tasks = []
if len(docs) == 1:
return docs[0]
else:
max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
for doc in docs[0:max_iteration]:
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": "",
"select_param": doc,
"model_name": model_name,
"model_cache_enable": True,
}
tasks.append(
llm_chat_response_nostream(
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
)
from dbgpt._private.chat_util import run_async_tasks
summary_iters = await run_async_tasks(
tasks=tasks, concurrency_limit=concurrency_limit
)
summary_iters = list(
filter(
lambda content: "LLMServer Generate Error" not in content,
summary_iters,
)
)
from dbgpt.util.prompt_util import PromptHelper
from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
prompt_helper = PromptHelper()
summary_iters = prompt_helper.repack(
prompt_template=prompt.template, text_chunks=summary_iters
)
return await self._mapreduce_extract_summary(
summary_iters, model_name, max_iteration, concurrency_limit
)
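
The service refactor above replaces the old EmbeddingEngine flow with a Knowledge -> ChunkParameters -> EmbeddingAssembler -> VectorStoreConnector pipeline. A minimal ingestion sketch mirroring _sync_knowledge_document; the store type, datasource, and knowledge type are assumptions, and embedding_fn should come from the app's EmbeddingFactory:

from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector


def sync_document(space_name: str, datasource: str, embedding_fn):
    """Chunk one document and persist its embeddings, as _sync_knowledge_document does."""
    connector = VectorStoreConnector(
        vector_store_type="Chroma",  # assumed; the service uses CFG.VECTOR_STORE_TYPE
        vector_store_config=VectorStoreConfig(name=space_name, embedding_fn=embedding_fn),
    )
    knowledge = KnowledgeFactory.create(
        datasource=datasource,  # e.g. a file path (hypothetical)
        knowledge_type=KnowledgeType.DOCUMENT,  # assumed member; the service derives it per doc
    )
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=ChunkParameters(
            chunk_strategy="CHUNK_BY_SIZE", chunk_size=512, chunk_overlap=50
        ),
        vector_store_connector=connector,
    )
    chunks = assembler.get_chunks()   # chunks are available before embedding starts
    vector_ids = assembler.persist()  # embed and write to the vector store
    return chunks, vector_ids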

View File

@ -11,6 +11,7 @@ from dbgpt.component import ComponentType
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core.interface.message import OnceConversation
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
@ -58,6 +59,9 @@ class BaseChat(ABC):
chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
)
self.llm_echo = False
self.worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
self.model_cache_enable = chat_param.get("model_cache_enable", False)
### load prompt template
@ -162,6 +166,10 @@ class BaseChat(ABC):
"BaseChat.__call_base.prompt_template.format", metadata=metadata
):
current_prompt = self.prompt_template.format(**input_values)
### prompt context token adapt according to llm max context length
current_prompt = await self.prompt_context_token_adapt(
prompt=current_prompt
)
self.current_message.add_system_message(current_prompt)
llm_messages = self.generate_llm_messages()
@ -169,6 +177,7 @@ class BaseChat(ABC):
# Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
# fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
llm_messages = list(map(lambda m: m.dict(), llm_messages))
payload = {
"model": self.llm_model,
"prompt": self.generate_llm_text(),
@ -431,6 +440,39 @@ class BaseChat(ABC):
return message.content
return None
async def prompt_context_token_adapt(self, prompt) -> str:
"""prompt token adapt according to llm max context length"""
model_metadata = await self.worker_manager.get_model_metadata(
{"model": self.llm_model}
)
current_token_count = await self.worker_manager.count_token(
{"model": self.llm_model, "prompt": prompt}
)
if current_token_count == -1:
logger.warning(
"tiktoken not installed, please `pip install tiktoken` first"
)
template_define_token_count = 0
if len(self.prompt_template.template_define) > 0:
template_define_token_count = await self.worker_manager.count_token(
{
"model": self.llm_model,
"prompt": self.prompt_template.template_define,
}
)
current_token_count += template_define_token_count
if (
current_token_count + self.prompt_template.max_new_tokens
) > model_metadata.context_length:
prompt = prompt[
: (
model_metadata.context_length
- self.prompt_template.max_new_tokens
- template_define_token_count
)
]
return prompt
def generate(self, p) -> str:
"""
generate context for LLM input
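
The new prompt_context_token_adapt keeps prompt tokens plus the reserved generation budget within the model's context length. A standalone sketch of the same rule with hypothetical numbers; note that, like the method above, it slices characters by a token count, which is a rough heuristic:

def adapt_prompt(prompt: str, prompt_tokens: int, define_tokens: int,
                 max_new_tokens: int, context_length: int) -> str:
    """Trim the prompt so prompt + template-define + generation tokens fit the context."""
    if prompt_tokens + define_tokens + max_new_tokens > context_length:
        keep = context_length - max_new_tokens - define_tokens
        prompt = prompt[:keep]  # character slice driven by a token budget
    return prompt

# Hypothetical: a 4096-token context with 1024 tokens reserved for generation.
trimmed = adapt_prompt("x" * 6000, prompt_tokens=5000, define_tokens=100,
                       max_new_tokens=1024, context_length=4096)
assert len(trimmed) == 4096 - 1024 - 100  # 2972 characters kept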

View File

@ -63,14 +63,11 @@ class ChatDashboard(BaseChat):
try:
table_infos = await blocking_func_to_async(
self._executor,
client.get_similar_tables,
client.get_db_summary,
self.db_name,
self.current_user_input,
self.top_k,
)
# table_infos = client.get_similar_tables(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
print("dashboard vector find tables:{}", table_infos)
except Exception as e:
print("db summary find error!" + str(e))

View File

@ -19,22 +19,14 @@ class ChatFactory(metaclass=Singleton):
from dbgpt.app.scene.chat_dashboard.prompt import prompt
from dbgpt.app.scene.chat_knowledge.v1.chat import ChatKnowledge
from dbgpt.app.scene.chat_knowledge.v1.prompt import prompt
from dbgpt.app.scene.chat_knowledge.inner_db_summary.chat import (
InnerChatDBSummary,
)
from dbgpt.app.scene.chat_knowledge.inner_db_summary.prompt import prompt
from dbgpt.app.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
from dbgpt.app.scene.chat_knowledge.extract_triplet.prompt import prompt
from dbgpt.app.scene.chat_knowledge.extract_entity.chat import ExtractEntity
from dbgpt.app.scene.chat_knowledge.extract_entity.prompt import prompt
from dbgpt.app.scene.chat_knowledge.summary.chat import ExtractSummary
from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
from dbgpt.app.scene.chat_knowledge.refine_summary.chat import (
ExtractRefineSummary,
)
from dbgpt.app.scene.chat_knowledge.refine_summary.prompt import prompt
from dbgpt.app.scene.chat_knowledge.rewrite.chat import QueryRewrite
from dbgpt.app.scene.chat_knowledge.rewrite.prompt import prompt
from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
from dbgpt.app.scene.chat_data.chat_excel.excel_learning.prompt import prompt

View File

@ -1,40 +0,0 @@
from typing import Dict
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.tracer import trace
class InnerChatDBSummary(BaseChat):
chat_scene: str = ChatScene.InnerChatDBSummary.value()
"""Number of results to return from the query"""
def __init__(
self,
chat_session_id,
user_input,
db_select,
db_summary,
):
""" """
super().__init__(
chat_mode=ChatScene.InnerChatDBSummary,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=db_select,
)
self.db_input = db_select
self.db_summary = db_summary
@trace()
async def generate_input_values(self) -> Dict:
input_values = {
"db_input": self.db_input,
"db_profile_summary": self.db_summary,
}
return input_values
@property
def chat_type(self) -> str:
return ChatScene.InnerChatDBSummary.value

View File

@ -1,17 +0,0 @@
import logging
from dbgpt.core.interface.output_parser import BaseOutputParser
logger = logging.getLogger(__name__)
class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
return clean_str
def parse_view_response(self, ai_text, data) -> str:
return ai_text
def get_format_instructions(self) -> str:
pass

View File

@ -1,45 +0,0 @@
import json
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_knowledge.inner_db_summary.out_parser import (
NormalChatOutputParser,
)
CFG = Config()
PROMPT_SCENE_DEFINE = """"""
_DEFAULT_TEMPLATE = """
Based on the following known database information?, answer which tables are involved in the user input.
Known database information:{db_profile_summary}
Input:{db_input}
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}
The response format must be JSON, and the key of JSON must be "table".
"""
RESPONSE_FORMAT = {"table": ["orders", "products"]}
PROMPT_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.InnerChatDBSummary.value(),
input_variables=["db_profile_summary", "db_input", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -19,7 +19,6 @@ _DEFAULT_TEMPLATE_ZH = (
_DEFAULT_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1.2.and 3.
"""
_DEFAULT_TEMPLATE = (

View File

@ -1,32 +0,0 @@
from typing import Dict
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.app.scene.chat_knowledge.rewrite.prompt import prompt
class QueryRewrite(BaseChat):
chat_scene: str = ChatScene.QueryRewrite.value()
"""query rewrite by llm"""
def __init__(self, chat_param: Dict):
""" """
chat_param["chat_mode"] = ChatScene.QueryRewrite
super().__init__(
chat_param=chat_param,
)
self.nums = chat_param["select_param"]
self.current_user_input = chat_param["current_user_input"]
async def generate_input_values(self):
input_values = {
"nums": self.nums,
"original_query": self.current_user_input,
}
return input_values
@property
def chat_type(self) -> str:
return ChatScene.QueryRewrite.value

View File

@ -1,42 +0,0 @@
import logging
from dbgpt.core.interface.output_parser import BaseOutputParser
logger = logging.getLogger(__name__)
class QueryRewriteParser(BaseOutputParser):
def __init__(self, is_stream_out: bool, **kwargs):
super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_prompt_response(self, response, max_length: int = 128):
lowercase = True
try:
results = []
response = response.strip()
if response.startswith("queries:"):
response = response[len("queries:") :]
queries = response.split(",")
if len(queries) == 1:
queries = response.split("")
if len(queries) == 1:
queries = response.split("?")
if len(queries) == 1:
queries = response.split("")
for k in queries:
rk = k
if lowercase:
rk = rk.lower()
s = rk.strip()
if s == "":
continue
results.append(s)
except Exception as e:
logger.error(f"parse query rewrite prompt_response error: {e}")
return []
return results
def parse_view_response(self, speak, data) -> str:
return data

View File

@ -1,41 +0,0 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from .out_parser import QueryRewriteParser
CFG = Config()
PROMPT_SCENE_DEFINE = """You are a helpful assistant that generates multiple search queries based on a single input query."""
_DEFAULT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries<queries>'
"original_query{original_query}\n"
"queries\n"
"""
_DEFAULT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'\n":
"original query:: {original_query}\n"
"queries:\n"
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
PROMPT_RESPONSE = """"""
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.QueryRewrite.value(),
input_variables=["nums", "original_query"],
response_format=None,
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=QueryRewriteParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -1,28 +0,0 @@
from typing import Dict
from dbgpt.app.scene import BaseChat, ChatScene
class ExtractSummary(BaseChat):
chat_scene: str = ChatScene.ExtractSummary.value()
"""get summary by llm"""
def __init__(self, chat_param: Dict):
""" """
chat_param["chat_mode"] = ChatScene.ExtractSummary
super().__init__(
chat_param=chat_param,
)
self.user_input = chat_param["select_param"]
async def generate_input_values(self):
input_values = {
"context": self.user_input,
}
return input_values
@property
def chat_type(self) -> str:
return ChatScene.ExtractSummary.value

View File

@ -1,28 +0,0 @@
import logging
from typing import List, Tuple
from dbgpt.core.interface.output_parser import BaseOutputParser, ResponseTye
logger = logging.getLogger(__name__)
class ExtractSummaryParser(BaseOutputParser):
def __init__(self, is_stream_out: bool, **kwargs):
super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_prompt_response(
self, response, max_length: int = 128
) -> List[Tuple[str, str, str]]:
# clean_str = super().parse_prompt_response(response)
print("clean prompt response:", response)
return response
def parse_view_response(self, speak, data) -> str:
### tool out data to table view
return data
def parse_model_nostream_resp(self, response: ResponseTye, sep: str) -> str:
try:
return super().parse_model_nostream_resp(response, sep)
except Exception as e:
return str(e)

View File

@ -1,46 +0,0 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_knowledge.summary.out_parser import ExtractSummaryParser
CFG = Config()
# PROMPT_SCENE_DEFINE = """You are an expert Q&A system that is trusted around the world.\nAlways answer the query using the provided context information, and not prior knowledge.\nSome rules to follow:\n1. Never directly reference the given context in your answer.\n2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."""
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions."""
_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
{context}
答案尽量精确和简单,不要过长长度控制在100字左右
"""
_DEFAULT_TEMPLATE_EN = """
Write a quick summary of the following context:
{context}
the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
PROMPT_RESPONSE = """"""
RESPONSE_FORMAT = """"""
PROMPT_NEED_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ExtractSummary.value(),
input_variables=["context"],
response_format=None,
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ExtractSummaryParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -5,6 +5,7 @@ from typing import Dict, List
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
@ -16,7 +17,9 @@ from dbgpt.app.knowledge.document_db import (
KnowledgeDocumentEntity,
)
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.tracer import trace
CFG = Config()
@ -35,8 +38,7 @@ class ChatKnowledge(BaseChat):
- model_name:(str) llm model name
- select_param:(str) space name
"""
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
self.knowledge_space = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatKnowledge
@ -59,17 +61,37 @@ class ChatKnowledge(BaseChat):
if self.space_context is None or self.space_context.get("prompt") is None
else int(self.space_context["prompt"]["max_token"])
)
vector_store_config = {
"vector_store_name": self.knowledge_space,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
self.knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
from dbgpt.storage.vector_store.base import VectorStoreConfig
config = VectorStoreConfig(name=self.knowledge_space, embedding_fn=embedding_fn)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
query_rewrite = None
self.worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager)
if CFG.KNOWLEDGE_SEARCH_REWRITE:
query_rewrite = QueryRewrite(
llm_client=self.llm_client,
model_name=self.llm_model,
language=CFG.LANGUAGE,
)
self.embedding_retriever = EmbeddingRetriever(
top_k=self.top_k,
vector_store_connector=vector_store_connector,
query_rewrite=query_rewrite,
)
self.prompt_template.template_is_strict = False
self.relations = None
@ -110,50 +132,31 @@ class ChatKnowledge(BaseChat):
if self.space_context and self.space_context.get("prompt"):
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
from dbgpt.rag.retriever.reinforce import QueryReinforce
from dbgpt.util.chat_util import run_async_tasks
# query reinforce, get similar queries
query_reinforce = QueryReinforce(
query=self.current_user_input, model_name=self.llm_model
)
queries = []
if CFG.KNOWLEDGE_SEARCH_REWRITE:
queries = await query_reinforce.rewrite()
print("rewrite queries:", queries)
queries.append(self.current_user_input)
from dbgpt._private.chat_util import run_async_tasks
# similarity search from vector db
tasks = [self.execute_similar_search(query) for query in queries]
docs_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
candidates_with_scores = reduce(lambda x, y: x + y, docs_with_scores)
# candidates document rerank
from dbgpt.rag.retriever.rerank import DefaultRanker
ranker = DefaultRanker(self.top_k)
candidates_with_scores = ranker.rank(candidates_with_scores)
tasks = [self.execute_similar_search(self.current_user_input)]
candidates_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
candidates_with_scores = reduce(lambda x, y: x + y, candidates_with_scores)
self.chunks_with_score = []
if not candidates_with_scores or len(candidates_with_scores) == 0:
print("no relevant docs to retrieve")
context = "no relevant docs to retrieve"
else:
self.chunks_with_score = []
for d, score in candidates_with_scores:
for chunk in candidates_with_scores:
chunks = self.chunk_dao.get_document_chunks(
query=DocumentChunkEntity(content=d.page_content),
query=DocumentChunkEntity(content=chunk.content),
document_ids=self.document_ids,
)
if len(chunks) > 0:
self.chunks_with_score.append((chunks[0], score))
self.chunks_with_score.append((chunks[0], chunk.score))
context = [doc.page_content for doc, _ in candidates_with_scores]
context = context[: self.max_token]
context = "\n".join([doc.content for doc in candidates_with_scores])
self.relations = list(
set(
[
os.path.basename(str(d.metadata.get("source", "")))
for d, _ in candidates_with_scores
for d in candidates_with_scores
]
)
)
@ -201,7 +204,8 @@ class ChatKnowledge(BaseChat):
references_list = list(references_dict.values())
references_ele.set("references", json.dumps(references_list))
html = ET.tostring(references_ele, encoding="utf-8")
return html.decode("utf-8")
reference = html.decode("utf-8")
return reference.replace("\\n", "")
@property
def chat_type(self) -> str:
@ -213,10 +217,6 @@ class ChatKnowledge(BaseChat):
async def execute_similar_search(self, query):
"""execute similarity search"""
return await blocking_func_to_async(
self._executor,
self.knowledge_embedding_client.similar_search_with_scores,
query,
self.top_k,
self.recall_score,
return await self.embedding_retriever.aretrieve_with_scores(
query, self.recall_score
)
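
ChatKnowledge now delegates retrieval entirely to EmbeddingRetriever. A minimal wiring sketch; the store type, top_k, and language are assumptions, and query rewrite stays off unless an LLM client is supplied:

from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector


def build_retriever(space_name: str, embedding_fn, llm_client=None, model_name=None):
    connector = VectorStoreConnector(
        vector_store_type="Chroma",  # assumed; the chat uses CFG.VECTOR_STORE_TYPE
        vector_store_config=VectorStoreConfig(name=space_name, embedding_fn=embedding_fn),
    )
    query_rewrite = None
    if llm_client is not None:  # mirrors the CFG.KNOWLEDGE_SEARCH_REWRITE switch above
        query_rewrite = QueryRewrite(
            llm_client=llm_client, model_name=model_name, language="en"
        )
    return EmbeddingRetriever(
        top_k=5, vector_store_connector=connector, query_rewrite=query_rewrite
    )

# retriever.retrieve(query) returns plain chunks; the async
# retriever.aretrieve_with_scores(query, recall_score) also attaches similarity
# scores, as execute_similar_search does above.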

View File

@ -2,16 +2,20 @@ from typing import Dict, Optional, List
from dataclasses import dataclass
import datetime
import os
from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.core.interface.message import OnceConversation
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
# TODO move global config
CFG = Config()
@ -184,23 +188,14 @@ class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
async def map(self, input_value: ChatContext) -> ChatContext:
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
# TODO, decompose the current operator into some atomic operators
knowledge_space = input_value.select_param
vector_store_config = {
"vector_store_name": knowledge_space,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
embedding_factory = self.system_app.get_component(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
space_context = await self._get_space_context(knowledge_space)
top_k = (
CFG.KNOWLEDGE_SEARCH_TOP_SIZE
@ -219,16 +214,28 @@ class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
]
input_value.prompt_template.template = space_context["prompt"]["template"]
config = VectorStoreConfig(
name=knowledge_space,
embedding_fn=embedding_factory.create(
EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
),
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
)
embedding_retriever = EmbeddingRetriever(
top_k=top_k, vector_store_connector=vector_store_connector
)
docs = await self.blocking_func_to_async(
knowledge_embedding_client.similar_search,
embedding_retriever.retrieve,
input_value.current_user_input,
top_k,
)
if not docs or len(docs) == 0:
print("no relevant docs to retrieve")
context = "no relevant docs to retrieve"
else:
context = [d.page_content for d in docs]
context = [d.content for d in docs]
context = context[:max_token]
relations = list(
set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])

34 file diffs suppressed because one or more lines are too long

View File

@ -1 +0,0 @@
self.__BUILD_MANIFEST=function(s,c,a,e,t,d,n,f,k,h,i,b){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":["static/chunks/29107295-90b90cb30c825230.js",s,c,a,d,n,f,k,"static/chunks/412-b911d4a677c64b70.js","static/chunks/981-ff77d5cc3ab95298.js","static/chunks/pages/index-d1740e3bc6dba7f5.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/agent":[s,c,e,d,t,n,"static/chunks/pages/agent-92e9dce47267e88d.js"],"/chat":["static/chunks/pages/chat-84fbba4764166684.js"],"/chat/[scene]/[id]":["static/chunks/pages/chat/[scene]/[id]-f665336966e79cc9.js"],"/database":[s,c,a,e,t,f,h,"static/chunks/643-d8f53f40dd3c5b40.js","static/chunks/pages/database-3140f507fe61ccb8.js"],"/knowledge":[i,s,c,e,d,t,n,f,"static/chunks/551-266086fbfa0925ec.js","static/chunks/pages/knowledge-8ada4ce8fa909bf5.js"],"/knowledge/chunk":[e,t,"static/chunks/pages/knowledge/chunk-9f117a5ed799edd3.js"],"/models":[i,s,c,a,b,h,"static/chunks/pages/models-80218c46bc1d8cfa.js"],"/prompt":[s,c,a,b,"static/chunks/837-e6d4d1eb9e057050.js",k,"static/chunks/607-b224c640f6907e4b.js","static/chunks/pages/prompt-7f839dfd56bc4c20.js"],sortedPages:["/","/_app","/_error","/agent","/chat","/chat/[scene]/[id]","/database","/knowledge","/knowledge/chunk","/models","/prompt"]}}("static/chunks/64-91b49d45b9846775.js","static/chunks/479-b20198841f9a6a1e.js","static/chunks/9-bb2c54d5c06ba4bf.js","static/chunks/442-197e6cbc1e54109a.js","static/chunks/813-cce9482e33f2430c.js","static/chunks/553-df5701294eedae07.js","static/chunks/924-ba8e16df4d61ff5c.js","static/chunks/411-d9eba2657c72f766.js","static/chunks/270-2f094a936d056513.js","static/chunks/928-74244889bd7f2699.js","static/chunks/75fc9c18-a784766a129ec5fb.js","static/chunks/947-5980a3ff49069ddd.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();

View File

@ -0,0 +1 @@
self.__BUILD_MANIFEST=function(s,c,a,e,t,n,d,b,f,k,h,i,u,j){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":["static/chunks/29107295-90b90cb30c825230.js",s,c,e,a,d,b,f,h,"static/chunks/861-78929b4f98dbbfd6.js","static/chunks/161-96143606b49cf4a1.js","static/chunks/pages/index-a5e7e7433070d21b.js"],"/_error":["static/chunks/pages/_error-dee72aff9b2e2c12.js"],"/agent":[s,c,a,t,d,n,"static/chunks/pages/agent-a2599efbeb46e056.js"],"/chat":["static/chunks/pages/chat-47a20abbae16e858.js"],"/chat/[scene]/[id]":["static/chunks/pages/chat/[scene]/[id]-8df445f91cde33fa.js"],"/database":[s,c,e,a,t,n,f,k,"static/chunks/643-d8f53f40dd3c5b40.js","static/chunks/pages/database-d36f41810fc357a6.js"],"/knowledge":[i,s,c,a,t,d,n,f,u,k,h,"static/chunks/450-bd680f0e37e9b4b9.js","static/chunks/pages/knowledge-b9300e7addf1931f.js"],"/knowledge/chunk":[s,e,t,b,n,"static/chunks/pages/knowledge/chunk-652744b9d90c26c9.js"],"/models":[i,s,c,e,a,j,k,"static/chunks/pages/models-1145859ba0e2f20a.js"],"/prompt":[s,c,e,a,j,b,u,"static/chunks/346-b0aea1c99abd6f1e.js","static/chunks/607-2dedaf19149304c0.js","static/chunks/pages/prompt-fca5ed813d5018b1.js"],sortedPages:["/","/_app","/_error","/agent","/chat","/chat/[scene]/[id]","/database","/knowledge","/knowledge/chunk","/models","/prompt"]}}("static/chunks/113-15fc0b8bd2b5b9a1.js","static/chunks/17-d6c52cecd9ecc451.js","static/chunks/479-33b3ebe9be79a971.js","static/chunks/9-bb2c54d5c06ba4bf.js","static/chunks/442-197e6cbc1e54109a.js","static/chunks/813-cce9482e33f2430c.js","static/chunks/553-a89ad624ca0f1ffa.js","static/chunks/810-84757da754c6f3fc.js","static/chunks/411-b5d3e7f64bee2335.js","static/chunks/928-74244889bd7f2699.js","static/chunks/234-42f62dc360b2d9e4.js","static/chunks/75fc9c18-a784766a129ec5fb.js","static/chunks/45-9ff739c09925ea35.js","static/chunks/947-5980a3ff49069ddd.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();

9 file diffs suppressed because one or more lines are too long

dbgpt/rag/chunk.py (new file, 116 lines)
View File

@ -0,0 +1,116 @@
import json
import uuid
from typing import Any, Dict
from pydantic import Field, BaseModel
class Document(BaseModel):
"""Document including document content, document metadata."""
content: str = Field(default="", description="document text content")
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="metadata fields",
)
def set_content(self, content: str) -> None:
"""Set the content"""
self.content = content
def get_content(self) -> str:
return self.content
@classmethod
def langchain2doc(cls, document):
"""Transformation from Langchain to Chunk Document format."""
metadata = document.metadata or {}
return cls(content=document.page_content, metadata=metadata)
@classmethod
def doc2langchain(cls, chunk):
"""Transformation from Chunk to Langchain Document format."""
from langchain.schema import Document as LCDocument
return LCDocument(page_content=chunk.content, metadata=chunk.metadata)
class Chunk(Document):
"""
Document Chunk including chunk content, chunk metadata, chunk summary, chunk relations.
"""
chunk_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="unique id for the chunk"
)
content: str = Field(default="", description="chunk text content")
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="metadata fields",
)
score: float = Field(default=0.0, description="chunk text similarity score")
summary: str = Field(default="", description="chunk text summary")
separator: str = Field(
default="\n",
description="Separator between metadata fields when converting to string.",
)
@classmethod
def class_name(cls) -> str:
"""Return the class name, stored as a marker in serialized dicts."""
return cls.__name__
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
data = self.dict(**kwargs)
data["class_name"] = self.class_name()
return data
def to_json(self, **kwargs: Any) -> str:
data = self.to_dict(**kwargs)
return json.dumps(data)
def __hash__(self):
return hash((self.chunk_id,))
def __eq__(self, other):
return self.chunk_id == other.chunk_id
@classmethod
def from_dict(cls, data: Dict[str, Any], **kwargs: Any): # type: ignore
if isinstance(kwargs, dict):
data.update(kwargs)
data.pop("class_name", None)
return cls(**data)
@classmethod
def from_json(cls, data_str: str, **kwargs: Any): # type: ignore
data = json.loads(data_str)
return cls.from_dict(data, **kwargs)
@classmethod
def langchain2chunk(cls, document):
"""Transformation from Langchain to Chunk Document format."""
metadata = document.metadata or {}
return cls(content=document.page_content, metadata=metadata)
@classmethod
def llamaindex2chunk(cls, node):
"""Transformation from LLama-Index to Chunk Document format."""
metadata = node.metadata or {}
return cls(content=node.get_content(), metadata=metadata)
@classmethod
def chunk2langchain(cls, chunk):
"""Transformation from Chunk to Langchain Document format."""
from langchain.schema import Document as LCDocument
return LCDocument(page_content=chunk.content, metadata=chunk.metadata)
@classmethod
def chunk2llamaindex(cls, chunk):
"""Transformation from Chunk to LLama-Index Document format."""
from llama_index.schema import TextNode
return TextNode(text=chunk.content, metadata=chunk.metadata)
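
A short usage sketch of the new Chunk model; the content and metadata are made up, and the langchain conversions require langchain to be installed:

from dbgpt.rag.chunk import Chunk

chunk = Chunk(
    content="DB-GPT stores document chunks in a vector store.",
    metadata={"source": "manual.md"},
)
payload = chunk.to_json()            # serialized fields plus a class_name marker
restored = Chunk.from_json(payload)  # from_dict strips the marker before reconstruction
assert restored == chunk             # equality and hashing use chunk_id only
lc_doc = Chunk.chunk2langchain(chunk)  # -> langchain.schema.Document
back = Chunk.langchain2chunk(lc_doc)   # back to a Chunk (with a fresh chunk_id)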

dbgpt/rag/chunk_manager.py (new file, 137 lines)
View File

@ -0,0 +1,137 @@
from enum import Enum
from typing import Optional, List, Any
from pydantic import BaseModel, Field
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge
class SplitterType(Enum):
"""splitter type"""
LANGCHAIN = "langchain"
LLAMA_INDEX = "llama-index"
USER_DEFINE = "user_define"
class ChunkParameters(BaseModel):
"""ChunkParameters"""
chunk_strategy: Optional[str] = Field(
default=None,
description="chunk strategy",
)
text_splitter: Optional[Any] = Field(
default=None,
description="text splitter",
)
splitter_type: SplitterType = Field(
default=SplitterType.USER_DEFINE,
description="splitter type",
)
chunk_size: int = Field(
default=512,
description="chunk size",
)
chunk_overlap: int = Field(
default=50,
description="chunk overlap",
)
separator: str = Field(
default="\n",
description="chunk separator",
)
class ChunkManager:
"""ChunkManager"""
def __init__(
self,
knowledge: Optional[Knowledge] = None,
chunk_parameter: Optional[ChunkParameters] = None,
extractor: Optional[Extractor] = None,
):
"""
Args:
knowledge: (Knowledge) Knowledge datasource.
chunk_parameter: (Optional[ChunkParameters]) Chunk parameter.
extractor: (Optional[Extractor]) Extractor to use for summarization.
"""
self._knowledge = knowledge
self._extractor = extractor
self._chunk_parameters = chunk_parameter or ChunkParameters()
self._chunk_strategy = (
self._chunk_parameters.chunk_strategy
or self._knowledge.default_chunk_strategy().name
)
self._text_splitter = self._chunk_parameters.text_splitter
self._splitter_type = self._chunk_parameters.splitter_type
def split(self, documents) -> List[Chunk]:
"""Split a document into chunks."""
text_splitter = self._select_text_splitter()
if SplitterType.LANGCHAIN == self._splitter_type:
documents = text_splitter.split_documents(documents)
return [Chunk.langchain2chunk(document) for document in documents]
elif SplitterType.LLAMA_INDEX == self._splitter_type:
nodes = text_splitter.split_text(documents)
return [Chunk.llamaindex2chunk(node) for node in nodes]
else:
return text_splitter.split_documents(documents)
def split_with_summary(
self, document: Any, chunk_strategy: ChunkStrategy
) -> List[Chunk]:
"""Split a document into chunks and summary"""
raise NotImplementedError
@property
def chunk_parameters(self) -> ChunkParameters:
return self._chunk_parameters
def set_text_splitter(
self,
text_splitter,
splitter_type: Optional[SplitterType] = SplitterType.LANGCHAIN,
) -> None:
"""Add text splitter."""
self._text_splitter = text_splitter
self._splitter_type = splitter_type
def get_text_splitter(
self,
) -> Any:
"""get text splitter."""
return self._select_text_splitter()
def _select_text_splitter(
self,
):
"""Select text splitter by chunk strategy."""
if self._text_splitter:
return self._text_splitter
if not self._chunk_strategy or "Automatic" == self._chunk_strategy:
self._chunk_strategy = self._knowledge.default_chunk_strategy().name
if self._chunk_strategy not in [
support_chunk_strategy.name
for support_chunk_strategy in self._knowledge.support_chunk_strategy()
]:
current_type = self._knowledge.type().value
if self._knowledge.document_type():
current_type = self._knowledge.document_type().value
raise ValueError(
f"{current_type} knowledge not supported chunk strategy {self._chunk_strategy} "
)
strategy = ChunkStrategy[self._chunk_strategy]
return strategy.match(
chunk_size=self._chunk_parameters.chunk_size,
chunk_overlap=self._chunk_parameters.chunk_overlap,
separator=self._chunk_parameters.separator,
)
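A sketch of the intended call pattern follows. KnowledgeFactory.from_file_path, knowledge.load(), and the CHUNK_BY_SIZE strategy name are assumptions about the factory and ChunkStrategy enum introduced elsewhere in this commit:

from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.knowledge.factory import KnowledgeFactory

# Build a Knowledge datasource from a local file (assumed factory method)
knowledge = KnowledgeFactory.from_file_path("./docs/example.md")

params = ChunkParameters(
    chunk_strategy="CHUNK_BY_SIZE",  # assumed ChunkStrategy member name
    chunk_size=512,
    chunk_overlap=50,
)
chunk_manager = ChunkManager(knowledge=knowledge, chunk_parameter=params)
chunks = chunk_manager.split(knowledge.load())  # knowledge.load() is assumed
print(f"split into {len(chunks)} chunks")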

dbgpt/rag/embedding/embedding_factory.py
View File

@ -3,12 +3,15 @@ from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING
from dbgpt.component import BaseComponent
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
from dbgpt.rag.embedding.embeddings import Embeddings
class EmbeddingFactory(BaseComponent, ABC):
"""Abstract base class for EmbeddingFactory."""
name = "embedding_factory"
@abstractmethod
@ -41,6 +44,4 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
if embedding_cls:
return embedding_cls(**new_kwargs)
else:
from langchain.embeddings import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(**new_kwargs)
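For components that need a different backend, a minimal custom factory sketch; the create signature mirrors DefaultEmbeddingFactory above, and init_app is assumed to be the BaseComponent hook:

from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings


class LocalEmbeddingFactory(EmbeddingFactory):
    """Always serve the in-repo HuggingFace wrapper (illustrative only)."""

    def init_app(self, system_app):  # assumed BaseComponent hook
        pass

    def create(self, model_name: str = None, embedding_cls=None):
        return HuggingFaceEmbeddings(
            model_name=model_name or "sentence-transformers/all-MiniLM-L6-v2"
        )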

dbgpt/rag/embedding/embeddings.py
View File

@ -0,0 +1,363 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import requests
from pydantic import Field, Extra, BaseModel
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
DEFAULT_BGE_MODEL = "BAAI/bge-large-en"
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
"Represent the question for retrieving supporting documents: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_EN = (
"Represent this question for searching relevant passages: "
)
DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:"
class Embeddings(ABC):
"""Interface for embedding models.
Refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings
"""
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_documents, texts
)
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return await asyncio.get_running_loop().run_in_executor(
None, self.embed_query, text
)
class HuggingFaceEmbeddings(BaseModel, Embeddings):
"""HuggingFace sentence_transformers embedding models.
To use, you should have the ``sentence_transformers`` python package installed.
Refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings
Example:
.. code-block:: python
from .embeddings import HuggingFaceEmbeddings
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
hf = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
"""
client: Any #: :meta private:
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
cache_folder: Optional[str] = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
multi_process: bool = False
"""Run encode() on multiple GPUs."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers
except ImportError as exc:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
) from exc
self.client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
import sentence_transformers
texts = list(map(lambda x: x.replace("\n", " "), texts))
if self.multi_process:
pool = self.client.start_multi_process_pool()
embeddings = self.client.encode_multi_process(texts, pool)
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
else:
embeddings = self.client.encode(texts, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
"""Wrapper around sentence_transformers embedding models.
To use, you should have the ``sentence_transformers``
and ``InstructorEmbedding`` python packages installed.
Example:
.. code-block:: python
from langchain.embeddings import HuggingFaceInstructEmbeddings
model_name = "hkunlp/instructor-large"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': True}
hf = HuggingFaceInstructEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
"""
client: Any #: :meta private:
model_name: str = DEFAULT_INSTRUCT_MODEL
"""Model name to use."""
cache_folder: Optional[str] = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
"""Instruction to use for embedding documents."""
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
"""Instruction to use for embedding query."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
from InstructorEmbedding import INSTRUCTOR
self.client = INSTRUCTOR(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
except ImportError as e:
raise ImportError("Dependencies for InstructorEmbedding not found.") from e
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace instruct model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
instruction_pairs = [[self.embed_instruction, text] for text in texts]
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace instruct model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
instruction_pair = [self.query_instruction, text]
embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
return embedding.tolist()
class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
"""HuggingFace BGE sentence_transformers embedding models.
To use, you should have the ``sentence_transformers`` python package installed.
refer to https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/embeddings
Example:
.. code-block:: python
from langchain.embeddings import HuggingFaceBgeEmbeddings
model_name = "BAAI/bge-large-en"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': True}
hf = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
"""
client: Any #: :meta private:
model_name: str = DEFAULT_BGE_MODEL
"""Model name to use."""
cache_folder: Optional[str] = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN
"""Instruction to use for embedding query."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers
except ImportError as exc:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence_transformers`."
) from exc
self.client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
if "-zh" in self.model_name:
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
texts = [t.replace("\n", " ") for t in texts]
embeddings = self.client.encode(texts, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
text = text.replace("\n", " ")
embedding = self.client.encode(
self.query_instruction + text, **self.encode_kwargs
)
return embedding.tolist()
class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
"""Embed texts using the HuggingFace API.
Requires a HuggingFace Inference API key and a model name.
"""
api_key: str
"""Your API key for the HuggingFace Inference API."""
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
"""The name of the model to use for text embeddings."""
@property
def _api_url(self) -> str:
return (
"https://api-inference.huggingface.co"
"/pipeline"
"/feature-extraction"
f"/{self.model_name}"
)
@property
def _headers(self) -> dict:
return {"Authorization": f"Bearer {self.api_key}"}
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (Documents): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
Example:
.. code-block:: python
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
api_key="your_api_key",
model_name="sentence-transformers/all-MiniLM-l6-v2"
)
texts = ["Hello, world!", "How are you?"]
hf_embeddings.embed_documents(texts)
"""
response = requests.post(
self._api_url,
headers=self._headers,
json={
"inputs": texts,
"options": {"wait_for_model": True, "use_cache": True},
},
)
return response.json()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
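End to end, the wrappers above are drop-in replacements for their langchain counterparts. A minimal run with the default MPNet model (requires sentence-transformers; all-mpnet-base-v2 produces 768-dimensional vectors):

from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings

hf = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)
doc_vectors = hf.embed_documents(["Hello, world!", "How are you?"])
query_vector = hf.embed_query("a friendly greeting")
print(len(doc_vectors), len(query_vector))  # 2 document vectors, 768-dim query vector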

dbgpt/rag/embedding_engine/__init__.py
View File

@ -1,12 +0,0 @@
from dbgpt.rag.embedding_engine.source_embedding import SourceEmbedding, register
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType
from dbgpt.rag.embedding_engine.pre_text_splitter import PreTextSplitter
__all__ = [
"SourceEmbedding",
"register",
"EmbeddingEngine",
"KnowledgeType",
"PreTextSplitter",
]
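Downstream imports need to migrate off this package; the replacement paths below are the ones this commit itself uses in dbgpt/app:

# Before (removed in this commit)
# from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
# from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType

# After
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType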

dbgpt/rag/embedding_engine/csv_embedding.py
View File

@ -1,64 +0,0 @@
from typing import List, Optional
from langchain.schema import Document
from langchain.text_splitter import (
TextSplitter,
SpacyTextSplitter,
RecursiveCharacterTextSplitter,
)
from dbgpt.rag.embedding_engine import SourceEmbedding, register
from dbgpt.rag.embedding_engine.loader.csv_loader import NewCSVLoader
class CSVEmbedding(SourceEmbedding):
"""csv embedding for read csv document."""
def __init__(
self,
file_path,
vector_store_config,
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize with csv path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)
self.file_path = file_path
self.vector_store_config = vector_store_config
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None
@register
def read(self):
"""Load from csv path."""
if self.source_reader is None:
self.source_reader = NewCSVLoader(self.file_path)
if self.text_splitter is None:
try:
self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=100,
chunk_overlap=100,
)
except Exception:
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=100, chunk_overlap=50
)
return self.source_reader.load_and_split(self.text_splitter)
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
documents[i].page_content = d.page_content.replace("\n", "")
i += 1
return documents

dbgpt/rag/embedding_engine/embedding_engine.py
View File

@ -1,145 +0,0 @@
from typing import Optional
from langchain.text_splitter import TextSplitter
from dbgpt.rag.embedding_engine.embedding_factory import (
EmbeddingFactory,
DefaultEmbeddingFactory,
)
from dbgpt.rag.embedding_engine.knowledge_type import (
get_knowledge_embedding,
KnowledgeType,
)
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingEngine:
"""EmbeddingEngine provide a chain process include(read->text_split->data_process->index_store) for knowledge document embedding into vector store.
1.knowledge_embedding:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
2.similar_search: similarity search from vector_store
how to use reference:https://db-gpt.readthedocs.io/en/latest/modules/knowledge.html
how to integrate:https://db-gpt.readthedocs.io/en/latest/modules/knowledge/pdf/pdf_embedding.html
Example:
.. code-block:: python
embedding_model = "your_embedding_model"
vector_store_type = "Chroma"
chroma_persist_path = "your_persist_path"
vector_store_config = {
"vector_store_name": "document_test",
"vector_store_type": vector_store_type,
"chroma_persist_path": chroma_persist_path,
}
# it can be .md,.pdf,.docx, .csv, .html
document_path = "your_path/test.md"
embedding_engine = EmbeddingEngine(
knowledge_source=document_path,
knowledge_type=KnowledgeType.DOCUMENT.value,
model_name=embedding_model,
vector_store_config=vector_store_config,
)
# embedding document content to vector store
embedding_engine.knowledge_embedding()
"""
def __init__(
self,
model_name,
vector_store_config,
knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
knowledge_source: Optional[str] = None,
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
embedding_factory: EmbeddingFactory = None,
):
"""Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source
Args:
- model_name: model_name
- vector_store_config: vector store config: Dict
- knowledge_type: Optional[KnowledgeType]
- knowledge_source: Optional[str]
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
- embedding_factory: EmbeddingFactory
"""
self.knowledge_source = knowledge_source
self.model_name = model_name
self.vector_store_config = vector_store_config
self.knowledge_type = knowledge_type
if not embedding_factory:
embedding_factory = DefaultEmbeddingFactory()
self.embeddings = embedding_factory.create(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
self.source_reader = source_reader
self.text_splitter = text_splitter
def knowledge_embedding(self):
"""source embedding is chain process.read->text_split->data_process->index_store"""
self.knowledge_embedding_client = self.init_knowledge_embedding()
self.knowledge_embedding_client.source_embedding()
def knowledge_embedding_batch(self, docs):
"""Deprecation"""
# docs = self.knowledge_embedding_client.read_batch()
return self.knowledge_embedding_client.index_to_store(docs)
def read(self):
"""Deprecation"""
self.knowledge_embedding_client = self.init_knowledge_embedding()
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
return get_knowledge_embedding(
self.knowledge_type,
self.knowledge_source,
self.vector_store_config,
self.source_reader,
self.text_splitter,
)
def similar_search(self, text, topk):
"""vector db similar search in vector database.
Return topk docs.
Args:
- text: query text
- topk: top k
"""
vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
# https://github.com/chroma-core/chroma/issues/657
ans = vector_client.similar_search(text, topk)
return ans
def similar_search_with_scores(self, text, topk, score_threshold: float = 0.3):
"""
similar_search_with_score in vector database.
Return docs and relevance scores in the range [0, 1].
Args:
doc(str): query text
topk(int): return docs nums. Defaults to 4.
score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
"""
vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
ans = vector_client.similar_search_with_scores(text, topk, score_threshold)
return ans
def vector_exist(self):
"""vector db is exist"""
vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
return vector_client.vector_name_exists()
def delete_by_ids(self, ids):
"""delete vector db by ids
Args:
- ids: vector ids
"""
vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
vector_client.delete_by_ids(ids=ids)
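The read -> split -> embed -> store chain that EmbeddingEngine provided is now assembled from the new building blocks. A rough sketch; the KnowledgeFactory method and VectorStoreConfig fields are assumptions about the new API, while load_document and similar_search match the connector calls seen elsewhere in this diff:

from dbgpt.rag.chunk_manager import ChunkManager, ChunkParameters
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# read + split
knowledge = KnowledgeFactory.from_file_path("./docs/example.md")  # assumed factory method
chunk_manager = ChunkManager(knowledge=knowledge, chunk_parameter=ChunkParameters())
chunks = chunk_manager.split(knowledge.load())

# embed + store (config fields are illustrative)
config = VectorStoreConfig(name="document_test")
connector = VectorStoreConnector(vector_store_type="Chroma", vector_store_config=config)
connector.load_document(chunks)

# retrieve
docs = connector.similar_search("What is DB-GPT?", topk=4)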

dbgpt/rag/embedding_engine/encode_text_loader.py
View File

@ -1,26 +0,0 @@
from typing import List, Optional
import chardet
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
class EncodeTextLoader(BaseLoader):
"""Load text files."""
def __init__(self, file_path: str, encoding: Optional[str] = None):
"""Initialize with file path."""
self.file_path = file_path
self.encoding = encoding
def load(self) -> List[Document]:
"""Load from file path."""
with open(self.file_path, "rb") as f:
raw_text = f.read()
result = chardet.detect(raw_text)
if result["encoding"] is None:
text = raw_text.decode("utf-8")
else:
text = raw_text.decode(result["encoding"])
metadata = {"source": self.file_path}
return [Document(page_content=text, metadata=metadata)]

dbgpt/rag/embedding_engine/knowledge_type.py
View File

@ -1,107 +0,0 @@
from enum import Enum
from dbgpt.rag.embedding_engine.csv_embedding import CSVEmbedding
from dbgpt.rag.embedding_engine.markdown_embedding import MarkdownEmbedding
from dbgpt.rag.embedding_engine.pdf_embedding import PDFEmbedding
from dbgpt.rag.embedding_engine.ppt_embedding import PPTEmbedding
from dbgpt.rag.embedding_engine.string_embedding import StringEmbedding
from dbgpt.rag.embedding_engine.url_embedding import URLEmbedding
from dbgpt.rag.embedding_engine.word_embedding import WordEmbedding
DocumentEmbeddingType = {
".txt": (MarkdownEmbedding, {}),
".md": (MarkdownEmbedding, {}),
".html": (MarkdownEmbedding, {}),
".pdf": (PDFEmbedding, {}),
".doc": (WordEmbedding, {}),
".docx": (WordEmbedding, {}),
".csv": (CSVEmbedding, {}),
".ppt": (PPTEmbedding, {}),
".pptx": (PPTEmbedding, {}),
}
class KnowledgeType(Enum):
DOCUMENT = "DOCUMENT"
URL = "URL"
TEXT = "TEXT"
OSS = "OSS"
S3 = "S3"
NOTION = "NOTION"
MYSQL = "MYSQL"
TIDB = "TIDB"
CLICKHOUSE = "CLICKHOUSE"
OCEANBASE = "OCEANBASE"
ELASTICSEARCH = "ELASTICSEARCH"
HIVE = "HIVE"
PRESTO = "PRESTO"
KAFKA = "KAFKA"
SPARK = "SPARK"
YOUTUBE = "YOUTUBE"
def get_knowledge_embedding(
knowledge_type,
knowledge_source,
vector_store_config=None,
source_reader=None,
text_splitter=None,
):
match knowledge_type:
case KnowledgeType.DOCUMENT.value:
extension = "." + knowledge_source.rsplit(".", 1)[-1]
if extension in DocumentEmbeddingType:
knowledge_class, knowledge_args = DocumentEmbeddingType[extension]
embedding = knowledge_class(
knowledge_source,
vector_store_config=vector_store_config,
source_reader=source_reader,
text_splitter=text_splitter,
**knowledge_args,
)
return embedding
raise ValueError(f"Unsupported knowledge document type '{extension}'")
case KnowledgeType.URL.value:
embedding = URLEmbedding(
file_path=knowledge_source,
vector_store_config=vector_store_config,
source_reader=source_reader,
text_splitter=text_splitter,
)
return embedding
case KnowledgeType.TEXT.value:
embedding = StringEmbedding(
file_path=knowledge_source,
vector_store_config=vector_store_config,
source_reader=source_reader,
text_splitter=text_splitter,
)
return embedding
case KnowledgeType.OSS.value:
raise Exception("OSS have not integrate")
case KnowledgeType.S3.value:
raise Exception("S3 have not integrate")
case KnowledgeType.NOTION.value:
raise Exception("NOTION have not integrate")
case KnowledgeType.MYSQL.value:
raise Exception("MYSQL have not integrate")
case KnowledgeType.TIDB.value:
raise Exception("TIDB have not integrate")
case KnowledgeType.CLICKHOUSE.value:
raise Exception("CLICKHOUSE have not integrate")
case KnowledgeType.OCEANBASE.value:
raise Exception("OCEANBASE have not integrate")
case KnowledgeType.ELASTICSEARCH.value:
raise Exception("ELASTICSEARCH have not integrate")
case KnowledgeType.HIVE.value:
raise Exception("HIVE have not integrate")
case KnowledgeType.PRESTO.value:
raise Exception("PRESTO have not integrate")
case KnowledgeType.KAFKA.value:
raise Exception("KAFKA have not integrate")
case KnowledgeType.SPARK.value:
raise Exception("SPARK have not integrate")
case KnowledgeType.YOUTUBE.value:
raise Exception("YOUTUBE have not integrate")
case _:
raise Exception("unknown knowledge type")

View File

@ -1,55 +0,0 @@
import re
from typing import List
from langchain.text_splitter import CharacterTextSplitter
class CHNDocumentSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, sentence_size: int = None, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf
self.sentence_size = sentence_size
def split_text(self, text: str) -> List[str]:
if self.pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub("\s", " ", text)
text = re.sub("\n\n", "", text)
text = re.sub(r"([;.!?。!?\?])([^”’])", r"\1\n\2", text)
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)
text = re.sub(r'([;!?。!?\?]["’”」』]{0,2})([^;!?,。!?\?])', r"\1\n\2", text)
text = text.rstrip()
ls = [i for i in text.split("\n") if i]
for ele in ls:
if len(ele) > self.sentence_size:
ele1 = re.sub(r'([,.]["’”」』]{0,2})([^,.])', r"\1\n\2", ele)
ele1_ls = ele1.split("\n")
for ele_ele1 in ele1_ls:
if len(ele_ele1) > self.sentence_size:
ele_ele2 = re.sub(
r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1
)
ele2_ls = ele_ele2.split("\n")
for ele_ele2 in ele2_ls:
if len(ele_ele2) > self.sentence_size:
ele_ele3 = re.sub(
'( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
)
ele2_id = ele2_ls.index(ele_ele2)
ele2_ls = (
ele2_ls[:ele2_id]
+ [i for i in ele_ele3.split("\n") if i]
+ ele2_ls[ele2_id + 1 :]
)
ele_id = ele1_ls.index(ele_ele1)
ele1_ls = (
ele1_ls[:ele_id]
+ [i for i in ele2_ls if i]
+ ele1_ls[ele_id + 1 :]
)
id = ls.index(ele)
ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :]
return ls

dbgpt/rag/embedding_engine/loader/csv_loader.py
View File

@ -1,76 +0,0 @@
"""Loads a CSV file into a list of documents.
Each document represents one row of the CSV file. Every row is converted into a
key/value pair and outputted to a new line in the document's page_content.
The source for each document loaded from csv is set to the value of the
`file_path` argument for all documents by default.
You can override this by setting the `source_column` argument to the
name of a column in the CSV file.
The source of each document will then be set to the value of the column
with the name specified in `source_column`.
Output Example:
.. code-block:: txt
column1: value1
column2: value2
column3: value3
"""
from typing import Optional, Dict, List
import csv
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
class NewCSVLoader(BaseLoader):
def __init__(
self,
file_path: str,
source_column: Optional[str] = None,
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
):
"""
Args:
file_path: The path to the CSV file.
source_column: The name of the column in the CSV file to use as the source.
Optional. Defaults to None.
csv_args: A dictionary of arguments to pass to the csv.DictReader.
Optional. Defaults to None.
encoding: The encoding of the CSV file. Optional. Defaults to None.
"""
self.file_path = file_path
self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {}
def load(self) -> List[Document]:
"""Load data into document objects."""
docs = []
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
strs = []
for k, v in row.items():
if k is None or v is None:
continue
strs.append(f"{k.strip()}: {v.strip()}")
content = "\n".join(strs)
try:
source = (
row[self.source_column]
if self.source_column is not None
else self.file_path
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -1,28 +0,0 @@
from typing import List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
import docx
class DocxLoader(BaseLoader):
"""Load docx files."""
def __init__(self, file_path: str, encoding: Optional[str] = None):
"""Initialize with file path."""
self.file_path = file_path
self.encoding = encoding
def load(self) -> List[Document]:
"""Load from file path."""
docs = []
doc = docx.Document(self.file_path)
content = []
for i in range(len(doc.paragraphs)):
para = doc.paragraphs[i]
text = para.text
content.append(text)
docs.append(
Document(page_content="".join(content), metadata={"source": self.file_path})
)
return docs

View File

@ -1,55 +0,0 @@
"""Loader that loads image files."""
import os
from typing import List
import fitz
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR
class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
def _get_elements(self) -> List:
def pdf_ocr_txt(filepath, dir_path="tmp_files"):
full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
if not os.path.exists(full_dir_path):
os.makedirs(full_dir_path)
filename = os.path.split(filepath)[-1]
ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
doc = fitz.open(filepath)
txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
img_name = os.path.join(full_dir_path, ".tmp.png")
with open(txt_file_path, "w", encoding="utf-8") as fout:
for i in range(doc.page_count):
page = doc[i]
text = page.get_text("")
fout.write(text)
fout.write("\n")
img_list = page.get_images()
for img in img_list:
pix = fitz.Pixmap(doc, img[0])
pix.save(img_name)
result = ocr.ocr(img_name)
ocr_result = [i[1][0] for line in result for i in line]
fout.write("\n".join(ocr_result))
os.remove(img_name)
return txt_file_path
txt_file_path = pdf_ocr_txt(self.file_path)
from unstructured.partition.text import partition_text
return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
if __name__ == "__main__":
filepath = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test_py.pdf"
)
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
docs = loader.load()
for doc in docs:
print(doc)

dbgpt/rag/embedding_engine/loader/ppt_loader.py
View File

@ -1,28 +0,0 @@
from typing import List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from pptx import Presentation
class PPTLoader(BaseLoader):
"""Load PPT files."""
def __init__(self, file_path: str, encoding: Optional[str] = None):
"""Initialize with file path."""
self.file_path = file_path
self.encoding = encoding
def load(self) -> List[Document]:
"""Load from file path."""
pr = Presentation(self.file_path)
docs = []
for slide in pr.slides:
for shape in slide.shapes:
if hasattr(shape, "text") and shape.text:
docs.append(
Document(
page_content=shape.text, metadata={"source": slide.slide_id}
)
)
return docs

dbgpt/rag/embedding_engine/markdown_embedding.py
View File

@ -1,68 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List, Optional
import markdown
from bs4 import BeautifulSoup
from langchain.schema import Document
from langchain.text_splitter import (
SpacyTextSplitter,
CharacterTextSplitter,
RecursiveCharacterTextSplitter,
TextSplitter,
)
from dbgpt.rag.embedding_engine import SourceEmbedding, register
from dbgpt.rag.embedding_engine.encode_text_loader import EncodeTextLoader
class MarkdownEmbedding(SourceEmbedding):
"""markdown embedding for read markdown document."""
def __init__(
self,
file_path,
vector_store_config,
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize raw text word path."""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)
self.file_path = file_path
self.vector_store_config = vector_store_config
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None
@register
def read(self):
"""Load from markdown path."""
if self.source_reader is None:
self.source_reader = EncodeTextLoader(self.file_path)
if self.text_splitter is None:
try:
self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=100,
chunk_overlap=100,
)
except Exception:
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=100, chunk_overlap=50
)
return self.source_reader.load_and_split(self.text_splitter)
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
content = markdown.markdown(d.page_content)
soup = BeautifulSoup(content, "html.parser")
for tag in soup(["!doctype", "meta", "i.fa"]):
tag.extract()
documents[i].page_content = soup.get_text()
documents[i].page_content = documents[i].page_content.replace("\n", " ")
i += 1
return documents

dbgpt/rag/embedding_engine/pdf_embedding.py
View File

@ -1,66 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List, Optional
from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document
from langchain.text_splitter import (
SpacyTextSplitter,
RecursiveCharacterTextSplitter,
TextSplitter,
)
from dbgpt.rag.embedding_engine import SourceEmbedding, register
class PDFEmbedding(SourceEmbedding):
"""pdf embedding for read pdf document."""
def __init__(
self,
file_path,
vector_store_config,
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize pdf word path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)
self.file_path = file_path
self.vector_store_config = vector_store_config
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None
@register
def read(self):
"""Load from pdf path."""
if self.source_reader is None:
self.source_reader = PyPDFLoader(self.file_path)
if self.text_splitter is None:
try:
self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=100,
chunk_overlap=100,
)
except Exception:
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=100, chunk_overlap=50
)
return self.source_reader.load_and_split(self.text_splitter)
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
documents[i].page_content = d.page_content.replace("\n", "")
i += 1
return documents

dbgpt/rag/embedding_engine/ppt_embedding.py
View File

@ -1,66 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List, Optional
from langchain.schema import Document
from langchain.text_splitter import (
SpacyTextSplitter,
RecursiveCharacterTextSplitter,
TextSplitter,
)
from dbgpt.rag.embedding_engine import SourceEmbedding, register
from dbgpt.rag.embedding_engine.loader.ppt_loader import PPTLoader
class PPTEmbedding(SourceEmbedding):
"""ppt embedding for read ppt document."""
def __init__(
self,
file_path,
vector_store_config,
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize ppt word path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)
self.file_path = file_path
self.vector_store_config = vector_store_config
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None
@register
def read(self):
"""Load from ppt path."""
if self.source_reader is None:
self.source_reader = PPTLoader(self.file_path)
if self.text_splitter is None:
try:
self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=100,
chunk_overlap=100,
)
except Exception:
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=100, chunk_overlap=50
)
return self.source_reader.load_and_split(self.text_splitter)
@register
def data_process(self, documents: List[Document]):
i = 0
for d in documents:
documents[i].page_content = d.page_content.replace("\n", "")
i += 1
return documents

View File

@ -1,61 +0,0 @@
# from langchain.embeddings import HuggingFaceEmbeddings
# from langchain.vectorstores import Milvus
# from pymilvus import Collection,utility
# from pymilvus import datasource, DataType, FieldSchema, CollectionSchema
#
# # milvus = datasource.connect(
# # alias="default",
# # host='localhost',
# # port="19530"
# # )
# # collection = Collection("book")
#
#
# # Get an existing collection.
# # collection.load()
# #
# # search_params = {"metric_type": "L2", "params": {}, "offset": 5}
# #
# # results = collection.search(
# # data=[[0.1, 0.2]],
# # anns_field="book_intro",
# # param=search_params,
# # limit=10,
# # expr=None,
# # output_fields=['book_id'],
# # consistency_level="Strong"
# # )
# #
# # # get the IDs of all returned hits
# # results[0].ids
# #
# # # get the distances to the query vector from all returned hits
# # results[0].distances
# #
# # # get the value of an output field specified in the search request.
# # # vector fields are not supported yet.
# # hit = results[0][0]
# # hit.entity.get('title')
#
# # milvus = datasource.connect(
# # alias="default",
# # host='localhost',
# # port="19530"
# # )
# from dbgpt.vector_store.milvus_store import MilvusStore
#
# data = ["aaa", "bbb"]
# model_name = "xx/all-MiniLM-L6-v2"
# embeddings = HuggingFaceEmbeddings(model_name=model_name)
#
# # text_embeddings = Text2Vectors()
# mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
#
# mivuls.insert(["textc","tezt2"])
# print("success")
# ct
# # mivuls.from_texts(texts=data, embedding=embeddings)
# # docs,
# # embedding=embeddings,
# # connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"}
# # )

dbgpt/rag/embedding_engine/source_embedding.py
View File

@ -1,126 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from langchain.text_splitter import TextSplitter
from dbgpt.storage.vector_store.connector import VectorStoreConnector
registered_methods = []
def register(method):
registered_methods.append(method.__name__)
return method
class SourceEmbedding(ABC):
"""base class for read data source embedding pipeline.
include data read, data process, data split, data to vector, data index vector store
Implementations should implement the method
"""
def __init__(
self,
file_path,
vector_store_config: {},
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
embedding_args: Optional[Dict] = None,
):
"""Initialize with Loader url, model_name, vector_store_config
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
- embedding_args: Optional
"""
self.file_path = file_path
self.vector_store_config = vector_store_config or {}
self.source_reader = source_reader or None
self.text_splitter = text_splitter or None
self.embedding_args = embedding_args
self.embeddings = self.vector_store_config.get("embeddings", None)
@abstractmethod
@register
def read(self) -> List[ABC]:
"""read datasource into document objects."""
@register
def data_process(self, text):
"""pre process data.
Args:
- text: raw text
"""
@register
def text_splitter(self, text_splitter: TextSplitter):
"""add text split chunk
Args:
- text_splitter: TextSplitter
"""
pass
@register
def text_to_vector(self, docs):
"""transform vector
Args:
- docs: List[Document]
"""
pass
@register
def index_to_store(self, docs):
"""index to vector store
Args:
- docs: List[Document]
"""
self.vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
return self.vector_client.load_document(docs)
@register
def similar_search(self, doc, topk):
"""vector store similarity_search
Args:
- query: query
"""
self.vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
# https://github.com/chroma-core/chroma/issues/657
ans = self.vector_client.similar_search(doc, topk)
# ans = self.vector_client.similar_search(doc, 1)
return ans
def vector_name_exist(self):
self.vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
return self.vector_client.vector_name_exists()
def source_embedding(self):
"""read()->data_process()->text_split()->index_to_store()"""
if "read" in registered_methods:
text = self.read()
if "data_process" in registered_methods:
text = self.data_process(text)
if "text_split" in registered_methods:
self.text_split(text)
if "text_to_vector" in registered_methods:
self.text_to_vector(text)
if "index_to_store" in registered_methods:
self.index_to_store(text)
def read_batch(self):
if "read" in registered_methods:
text = self.read()
if "data_process" in registered_methods:
text = self.data_process(text)
if "text_split" in registered_methods:
self.text_split(text)
return text

Some files were not shown because too many files have changed in this diff.