mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 13:58:58 +00:00
feat(RAG):add rag operators and rag awel examples (#1061)
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from dbgpt.util.chat_util import run_async_tasks
|
|||||||
|
|
||||||
SUMMARY_PROMPT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
|
SUMMARY_PROMPT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
|
||||||
{context}
|
{context}
|
||||||
答案尽量精确和简单,不要过长,长度控制在100字左右
|
答案尽量精确和简单,不要过长,长度控制在100字左右, 注意:请用<中文>来进行总结。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SUMMARY_PROMPT_TEMPLATE_EN = """
|
SUMMARY_PROMPT_TEMPLATE_EN = """
|
||||||
@@ -18,6 +18,13 @@ Write a quick summary of the following context:
|
|||||||
the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
|
the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context}\n 请根据你之前推理的内容进行最终的总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。"""
|
||||||
|
|
||||||
|
REFINE_SUMMARY_TEMPLATE_EN = """
|
||||||
|
We have provided an existing summary up to a certain point: {context}, We have the opportunity to refine the existing summary (only if needed) with some more context below.
|
||||||
|
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1.2.and 3.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class SummaryExtractor(Extractor):
|
class SummaryExtractor(Extractor):
|
||||||
"""Summary Extractor, it can extract document summary."""
|
"""Summary Extractor, it can extract document summary."""
|
||||||
@@ -41,6 +48,11 @@ class SummaryExtractor(Extractor):
|
|||||||
if language == "en"
|
if language == "en"
|
||||||
else SUMMARY_PROMPT_TEMPLATE_ZH
|
else SUMMARY_PROMPT_TEMPLATE_ZH
|
||||||
)
|
)
|
||||||
|
self._refine_prompt_template = (
|
||||||
|
REFINE_SUMMARY_TEMPLATE_EN
|
||||||
|
if language == "en"
|
||||||
|
else REFINE_SUMMARY_TEMPLATE_ZH
|
||||||
|
)
|
||||||
self._concurrency_limit_with_llm = concurrency_limit_with_llm
|
self._concurrency_limit_with_llm = concurrency_limit_with_llm
|
||||||
self._max_iteration_with_llm = max_iteration_with_llm
|
self._max_iteration_with_llm = max_iteration_with_llm
|
||||||
self._concurrency_limit_with_llm = concurrency_limit_with_llm
|
self._concurrency_limit_with_llm = concurrency_limit_with_llm
|
||||||
@@ -64,15 +76,23 @@ class SummaryExtractor(Extractor):
|
|||||||
texts = [doc.content for doc in chunks]
|
texts = [doc.content for doc in chunks]
|
||||||
from dbgpt.util.prompt_util import PromptHelper
|
from dbgpt.util.prompt_util import PromptHelper
|
||||||
|
|
||||||
|
# repack chunk into prompt to adapt llm model max context window
|
||||||
prompt_helper = PromptHelper()
|
prompt_helper = PromptHelper()
|
||||||
texts = prompt_helper.repack(
|
texts = prompt_helper.repack(
|
||||||
prompt_template=self._prompt_template, text_chunks=texts
|
prompt_template=self._prompt_template, text_chunks=texts
|
||||||
)
|
)
|
||||||
if len(texts) == 1:
|
if len(texts) == 1:
|
||||||
summary_outs = await self._llm_run_tasks(chunk_texts=texts)
|
summary_outs = await self._llm_run_tasks(
|
||||||
|
chunk_texts=texts, prompt_template=self._refine_prompt_template
|
||||||
|
)
|
||||||
return summary_outs[0]
|
return summary_outs[0]
|
||||||
else:
|
else:
|
||||||
return await self._mapreduce_extract_summary(docs=texts)
|
map_reduce_texts = await self._mapreduce_extract_summary(docs=texts)
|
||||||
|
summary_outs = await self._llm_run_tasks(
|
||||||
|
chunk_texts=[map_reduce_texts],
|
||||||
|
prompt_template=self._refine_prompt_template,
|
||||||
|
)
|
||||||
|
return summary_outs[0]
|
||||||
|
|
||||||
def _extract(self, chunks: List[Chunk]) -> str:
|
def _extract(self, chunks: List[Chunk]) -> str:
|
||||||
"""document extract summary
|
"""document extract summary
|
||||||
@@ -98,7 +118,8 @@ class SummaryExtractor(Extractor):
|
|||||||
return docs[0]
|
return docs[0]
|
||||||
else:
|
else:
|
||||||
summary_outs = await self._llm_run_tasks(
|
summary_outs = await self._llm_run_tasks(
|
||||||
chunk_texts=docs[0 : self._max_iteration_with_llm]
|
chunk_texts=docs[0 : self._max_iteration_with_llm],
|
||||||
|
prompt_template=self._prompt_template,
|
||||||
)
|
)
|
||||||
from dbgpt.util.prompt_util import PromptHelper
|
from dbgpt.util.prompt_util import PromptHelper
|
||||||
|
|
||||||
@@ -108,10 +129,13 @@ class SummaryExtractor(Extractor):
|
|||||||
)
|
)
|
||||||
return await self._mapreduce_extract_summary(docs=summary_outs)
|
return await self._mapreduce_extract_summary(docs=summary_outs)
|
||||||
|
|
||||||
async def _llm_run_tasks(self, chunk_texts: List[str]) -> List[str]:
|
async def _llm_run_tasks(
|
||||||
|
self, chunk_texts: List[str], prompt_template: str
|
||||||
|
) -> List[str]:
|
||||||
"""llm run tasks
|
"""llm run tasks
|
||||||
Args:
|
Args:
|
||||||
chunk_texts: List[str]
|
chunk_texts: List[str]
|
||||||
|
prompt_template: str
|
||||||
Returns:
|
Returns:
|
||||||
summary_outs: List[str]
|
summary_outs: List[str]
|
||||||
"""
|
"""
|
||||||
@@ -119,7 +143,7 @@ class SummaryExtractor(Extractor):
|
|||||||
for chunk_text in chunk_texts:
|
for chunk_text in chunk_texts:
|
||||||
from dbgpt.core import ModelMessage
|
from dbgpt.core import ModelMessage
|
||||||
|
|
||||||
prompt = self._prompt_template.format(context=chunk_text)
|
prompt = prompt_template.format(context=chunk_text)
|
||||||
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
|
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
|
||||||
request = ModelRequest(model=self._model_name, messages=messages)
|
request = ModelRequest(model=self._model_name, messages=messages)
|
||||||
tasks.append(self._llm_client.generate(request))
|
tasks.append(self._llm_client.generate(request))
|
||||||
|
37
dbgpt/rag/operator/db_schema.py
Normal file
37
dbgpt/rag/operator/db_schema.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from dbgpt.core.awel.task.base import IN
|
||||||
|
from dbgpt.core.interface.retriever import RetrieverOperator
|
||||||
|
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||||
|
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
|
class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||||
|
"""The DBSchema Retriever Operator.
|
||||||
|
Args:
|
||||||
|
connection (RDBMSDatabase): The connection.
|
||||||
|
top_k (int, optional): The top k. Defaults to 4.
|
||||||
|
vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
top_k: int = 4,
|
||||||
|
connection: Optional[RDBMSDatabase] = None,
|
||||||
|
vector_store_connector: Optional[VectorStoreConnector] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._retriever = DBSchemaRetriever(
|
||||||
|
top_k=top_k,
|
||||||
|
connection=connection,
|
||||||
|
vector_store_connector=vector_store_connector,
|
||||||
|
)
|
||||||
|
|
||||||
|
def retrieve(self, query: IN) -> Any:
|
||||||
|
"""retrieve table schemas.
|
||||||
|
Args:
|
||||||
|
query (IN): query.
|
||||||
|
"""
|
||||||
|
return self._retriever.retrieve(query)
|
39
dbgpt/rag/operator/embedding.py
Normal file
39
dbgpt/rag/operator/embedding.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from functools import reduce
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from dbgpt.core.awel.task.base import IN
|
||||||
|
from dbgpt.core.interface.retriever import RetrieverOperator
|
||||||
|
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||||
|
from dbgpt.rag.retriever.rerank import Ranker
|
||||||
|
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
top_k: int,
|
||||||
|
score_threshold: Optional[float] = 0.3,
|
||||||
|
query_rewrite: Optional[QueryRewrite] = None,
|
||||||
|
rerank: Ranker = None,
|
||||||
|
vector_store_connector: VectorStoreConnector = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._score_threshold = score_threshold
|
||||||
|
self._retriever = EmbeddingRetriever(
|
||||||
|
top_k=top_k,
|
||||||
|
query_rewrite=query_rewrite,
|
||||||
|
rerank=rerank,
|
||||||
|
vector_store_connector=vector_store_connector,
|
||||||
|
)
|
||||||
|
|
||||||
|
def retrieve(self, query: IN) -> Any:
|
||||||
|
if isinstance(query, str):
|
||||||
|
return self._retriever.retrieve_with_scores(query, self._score_threshold)
|
||||||
|
elif isinstance(query, list):
|
||||||
|
candidates = [
|
||||||
|
self._retriever.retrieve_with_scores(q, self._score_threshold)
|
||||||
|
for q in query
|
||||||
|
]
|
||||||
|
return reduce(lambda x, y: x + y, candidates)
|
26
dbgpt/rag/operator/knowledge.py
Normal file
26
dbgpt/rag/operator/knowledge.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from dbgpt.core.awel import MapOperator
|
||||||
|
from dbgpt.core.awel.task.base import IN
|
||||||
|
from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge
|
||||||
|
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeOperator(MapOperator[Any, Any]):
|
||||||
|
"""Knowledge Operator."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, **kwargs
|
||||||
|
):
|
||||||
|
"""Init the query rewrite operator.
|
||||||
|
Args:
|
||||||
|
knowledge_type: (Optional[KnowledgeType]) The knowledge type.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._knowledge_type = knowledge_type
|
||||||
|
|
||||||
|
async def map(self, datasource: IN) -> Knowledge:
|
||||||
|
"""knowledge operator."""
|
||||||
|
return await self.blocking_func_to_async(
|
||||||
|
KnowledgeFactory.create, datasource, self._knowledge_type
|
||||||
|
)
|
43
dbgpt/rag/operator/rerank.py
Normal file
43
dbgpt/rag/operator/rerank.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
|
from dbgpt.core import LLMClient
|
||||||
|
from dbgpt.core.awel import MapOperator
|
||||||
|
from dbgpt.core.awel.task.base import IN
|
||||||
|
from dbgpt.rag.chunk import Chunk
|
||||||
|
from dbgpt.rag.retriever.rerank import DefaultRanker
|
||||||
|
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||||
|
|
||||||
|
|
||||||
|
class RerankOperator(MapOperator[Any, Any]):
|
||||||
|
"""The Rewrite Operator."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
topk: Optional[int] = 3,
|
||||||
|
algorithm: Optional[str] = "default",
|
||||||
|
rank_fn: Optional[callable] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""Init the query rewrite operator.
|
||||||
|
Args:
|
||||||
|
topk (int): The number of the candidates.
|
||||||
|
algorithm (Optional[str]): The rerank algorithm name.
|
||||||
|
rank_fn (Optional[callable]): The rank function.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._algorithm = algorithm
|
||||||
|
self._rerank = DefaultRanker(
|
||||||
|
topk=topk,
|
||||||
|
rank_fn=rank_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def map(self, candidates_with_scores: IN) -> List[Chunk]:
|
||||||
|
"""rerank the candidates.
|
||||||
|
Args:
|
||||||
|
candidates_with_scores (IN): The candidates with scores.
|
||||||
|
Returns:
|
||||||
|
List[Chunk]: The reranked candidates.
|
||||||
|
"""
|
||||||
|
return await self.blocking_func_to_async(
|
||||||
|
self._rerank.rank, candidates_with_scores
|
||||||
|
)
|
41
dbgpt/rag/operator/rewrite.py
Normal file
41
dbgpt/rag/operator/rewrite.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
|
from dbgpt.core import LLMClient
|
||||||
|
from dbgpt.core.awel import MapOperator
|
||||||
|
from dbgpt.core.awel.task.base import IN
|
||||||
|
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||||
|
|
||||||
|
|
||||||
|
class QueryRewriteOperator(MapOperator[Any, Any]):
|
||||||
|
"""The Rewrite Operator."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_client: Optional[LLMClient],
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
language: Optional[str] = "en",
|
||||||
|
nums: Optional[int] = 1,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""Init the query rewrite operator.
|
||||||
|
Args:
|
||||||
|
llm_client (Optional[LLMClient]): The LLM client.
|
||||||
|
model_name (Optional[str]): The model name.
|
||||||
|
language (Optional[str]): The prompt language.
|
||||||
|
nums (Optional[int]): The number of the rewrite results.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._nums = nums
|
||||||
|
self._rewrite = QueryRewrite(
|
||||||
|
llm_client=llm_client,
|
||||||
|
model_name=model_name,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def map(self, query_context: IN) -> List[str]:
|
||||||
|
"""Rewrite the query."""
|
||||||
|
query = query_context.get("query")
|
||||||
|
context = query_context.get("context")
|
||||||
|
return await self._rewrite.rewrite(
|
||||||
|
origin_query=query, context=context, nums=self._nums
|
||||||
|
)
|
49
dbgpt/rag/operator/summary.py
Normal file
49
dbgpt/rag/operator/summary.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from dbgpt.core import LLMClient
|
||||||
|
from dbgpt.core.awel.task.base import IN
|
||||||
|
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
|
||||||
|
from dbgpt.serve.rag.operators.base import AssemblerOperator
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_client: Optional[LLMClient],
|
||||||
|
model_name: Optional[str] = "gpt-3.5-turbo",
|
||||||
|
language: Optional[str] = "en",
|
||||||
|
max_iteration_with_llm: Optional[int] = 5,
|
||||||
|
concurrency_limit_with_llm: Optional[int] = 3,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Init the summary assemble operator.
|
||||||
|
Args:
|
||||||
|
llm_client: (Optional[LLMClient]) The LLM client.
|
||||||
|
model_name: (Optional[str]) The model name.
|
||||||
|
language: (Optional[str]) The prompt language.
|
||||||
|
max_iteration_with_llm: (Optional[int]) The max iteration with llm.
|
||||||
|
concurrency_limit_with_llm: (Optional[int]) The concurrency limit with llm.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._llm_client = llm_client
|
||||||
|
self._model_name = model_name
|
||||||
|
self._language = language
|
||||||
|
self._max_iteration_with_llm = max_iteration_with_llm
|
||||||
|
self._concurrency_limit_with_llm = concurrency_limit_with_llm
|
||||||
|
|
||||||
|
async def map(self, knowledge: IN) -> Any:
|
||||||
|
"""Assemble the summary."""
|
||||||
|
assembler = SummaryAssembler.load_from_knowledge(
|
||||||
|
knowledge=knowledge,
|
||||||
|
llm_client=self._llm_client,
|
||||||
|
model_name=self._model_name,
|
||||||
|
language=self._language,
|
||||||
|
max_iteration_with_llm=self._max_iteration_with_llm,
|
||||||
|
concurrency_limit_with_llm=self._concurrency_limit_with_llm,
|
||||||
|
)
|
||||||
|
return await assembler.generate_summary()
|
||||||
|
|
||||||
|
def assemble(self, knowledge: IN) -> Any:
|
||||||
|
"""assemble knowledge for input value."""
|
||||||
|
pass
|
@@ -1,6 +1,7 @@
|
|||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||||
from dbgpt.util.chat_util import run_async_tasks
|
from dbgpt.util.chat_util import run_async_tasks
|
||||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||||
from dbgpt.rag.chunk import Chunk
|
from dbgpt.rag.chunk import Chunk
|
||||||
@@ -9,14 +10,13 @@ from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
|
|||||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
class DBStructRetriever(BaseRetriever):
|
class DBSchemaRetriever(BaseRetriever):
|
||||||
"""DBStruct retriever."""
|
"""DBSchema retriever."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
top_k: int = 4,
|
top_k: int = 4,
|
||||||
connection: Optional[RDBMSDatabase] = None,
|
connection: Optional[RDBMSDatabase] = None,
|
||||||
is_embeddings: bool = True,
|
|
||||||
query_rewrite: bool = False,
|
query_rewrite: bool = False,
|
||||||
rerank: Ranker = None,
|
rerank: Ranker = None,
|
||||||
vector_store_connector: Optional[VectorStoreConnector] = None,
|
vector_store_connector: Optional[VectorStoreConnector] = None,
|
||||||
@@ -26,14 +26,13 @@ class DBStructRetriever(BaseRetriever):
|
|||||||
Args:
|
Args:
|
||||||
top_k (int): top k
|
top_k (int): top k
|
||||||
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
|
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
|
||||||
is_embeddings (bool): Whether to query by embeddings in the vector store, Defaults to True.
|
|
||||||
query_rewrite (bool): query rewrite
|
query_rewrite (bool): query rewrite
|
||||||
rerank (Ranker): rerank
|
rerank (Ranker): rerank
|
||||||
vector_store_connector (VectorStoreConnector): vector store connector
|
vector_store_connector (VectorStoreConnector): vector store connector
|
||||||
code example:
|
code example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
>>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
>>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
||||||
>>> from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
|
>>> from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||||
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
>>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
>>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
>>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
>>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||||
@@ -71,16 +70,18 @@ class DBStructRetriever(BaseRetriever):
|
|||||||
embedding_fn=embedding_fn
|
embedding_fn=embedding_fn
|
||||||
)
|
)
|
||||||
# get db struct retriever
|
# get db struct retriever
|
||||||
retriever = DBStructRetriever(top_k=3, vector_store_connector=vector_connector)
|
retriever = DBSchemaRetriever(top_k=3, vector_store_connector=vector_connector)
|
||||||
chunks = retriever.retrieve("show columns from table")
|
chunks = retriever.retrieve("show columns from table")
|
||||||
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
|
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._top_k = top_k
|
self._top_k = top_k
|
||||||
self._is_embeddings = is_embeddings
|
|
||||||
self._connection = connection
|
self._connection = connection
|
||||||
self._query_rewrite = query_rewrite
|
self._query_rewrite = query_rewrite
|
||||||
self._vector_store_connector = vector_store_connector
|
self._vector_store_connector = vector_store_connector
|
||||||
|
self._need_embeddings = False
|
||||||
|
if self._vector_store_connector:
|
||||||
|
self._need_embeddings = True
|
||||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||||
|
|
||||||
def _retrieve(self, query: str) -> List[Chunk]:
|
def _retrieve(self, query: str) -> List[Chunk]:
|
||||||
@@ -88,7 +89,7 @@ class DBStructRetriever(BaseRetriever):
|
|||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
"""
|
"""
|
||||||
if self._is_embeddings:
|
if self._need_embeddings:
|
||||||
queries = [query]
|
queries = [query]
|
||||||
candidates = [
|
candidates = [
|
||||||
self._vector_store_connector.similar_search(query, self._top_k)
|
self._vector_store_connector.similar_search(query, self._top_k)
|
||||||
@@ -97,8 +98,6 @@ class DBStructRetriever(BaseRetriever):
|
|||||||
candidates = reduce(lambda x, y: x + y, candidates)
|
candidates = reduce(lambda x, y: x + y, candidates)
|
||||||
return candidates
|
return candidates
|
||||||
else:
|
else:
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
|
||||||
|
|
||||||
table_summaries = _parse_db_summary(self._connection)
|
table_summaries = _parse_db_summary(self._connection)
|
||||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||||
|
|
||||||
@@ -115,7 +114,7 @@ class DBStructRetriever(BaseRetriever):
|
|||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
"""
|
"""
|
||||||
if self._is_embeddings:
|
if self._need_embeddings:
|
||||||
queries = [query]
|
queries = [query]
|
||||||
candidates = [self._similarity_search(query) for query in queries]
|
candidates = [self._similarity_search(query) for query in queries]
|
||||||
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
|
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
|
||||||
@@ -145,7 +144,7 @@ class DBStructRetriever(BaseRetriever):
|
|||||||
self._top_k,
|
self._top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _aparse_db_summary(self) -> List[Chunk]:
|
async def _aparse_db_summary(self) -> List[str]:
|
||||||
"""Similar search."""
|
"""Similar search."""
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||||
|
|
@@ -99,7 +99,12 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
"""
|
"""
|
||||||
queries = [query]
|
queries = [query]
|
||||||
if self._query_rewrite:
|
if self._query_rewrite:
|
||||||
new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1)
|
candidates_tasks = [self._similarity_search(query) for query in queries]
|
||||||
|
chunks = await self._run_async_tasks(candidates_tasks)
|
||||||
|
context = "\n".join([chunk.content for chunk in chunks])
|
||||||
|
new_queries = await self._query_rewrite.rewrite(
|
||||||
|
origin_query=query, context=context, nums=1
|
||||||
|
)
|
||||||
queries.extend(new_queries)
|
queries.extend(new_queries)
|
||||||
candidates = [self._similarity_search(query) for query in queries]
|
candidates = [self._similarity_search(query) for query in queries]
|
||||||
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
|
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
|
||||||
@@ -117,7 +122,12 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
"""
|
"""
|
||||||
queries = [query]
|
queries = [query]
|
||||||
if self._query_rewrite:
|
if self._query_rewrite:
|
||||||
new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1)
|
candidates_tasks = [self._similarity_search(query) for query in queries]
|
||||||
|
chunks = await self._run_async_tasks(candidates_tasks)
|
||||||
|
context = "\n".join([chunk.content for chunk in chunks])
|
||||||
|
new_queries = await self._query_rewrite.rewrite(
|
||||||
|
origin_query=query, context=context, nums=1
|
||||||
|
)
|
||||||
queries.extend(new_queries)
|
queries.extend(new_queries)
|
||||||
candidates_with_score = [
|
candidates_with_score = [
|
||||||
self._similarity_search_with_score(query, score_threshold)
|
self._similarity_search_with_score(query, score_threshold)
|
||||||
@@ -137,6 +147,12 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
self._top_k,
|
self._top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _run_async_tasks(self, tasks) -> List[Chunk]:
|
||||||
|
"""Run async tasks."""
|
||||||
|
candidates = await run_async_tasks(tasks=tasks, concurrency_limit=1)
|
||||||
|
candidates = reduce(lambda x, y: x + y, candidates)
|
||||||
|
return candidates
|
||||||
|
|
||||||
async def _similarity_search_with_score(
|
async def _similarity_search_with_score(
|
||||||
self, query, score_threshold
|
self, query, score_threshold
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
|
@@ -2,14 +2,14 @@ from typing import List, Optional
|
|||||||
from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
|
from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
|
||||||
|
|
||||||
REWRITE_PROMPT_TEMPLATE_EN = """
|
REWRITE_PROMPT_TEMPLATE_EN = """
|
||||||
Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'\n":
|
Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'":
|
||||||
"original query:: {original_query}\n"
|
"original query:{original_query}\n"
|
||||||
"queries:\n"
|
"queries:"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
REWRITE_PROMPT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>':
|
REWRITE_PROMPT_TEMPLATE_ZH = """请根据上下文{context}, 将原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>'
|
||||||
"original_query:{original_query}\n"
|
"original_query:{original_query}\n"
|
||||||
"queries:\n"
|
"queries:"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -29,6 +29,7 @@ class QueryRewrite:
|
|||||||
- query: (str), user query
|
- query: (str), user query
|
||||||
- model_name: (str), llm model name
|
- model_name: (str), llm model name
|
||||||
- llm_client: (Optional[LLMClient])
|
- llm_client: (Optional[LLMClient])
|
||||||
|
- language: (Optional[str]), language
|
||||||
"""
|
"""
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._llm_client = llm_client
|
self._llm_client = llm_client
|
||||||
@@ -39,17 +40,22 @@ class QueryRewrite:
|
|||||||
else REWRITE_PROMPT_TEMPLATE_ZH
|
else REWRITE_PROMPT_TEMPLATE_ZH
|
||||||
)
|
)
|
||||||
|
|
||||||
async def rewrite(self, origin_query: str, nums: Optional[int] = 1) -> List[str]:
|
async def rewrite(
|
||||||
|
self, origin_query: str, context: Optional[str], nums: Optional[int] = 1
|
||||||
|
) -> List[str]:
|
||||||
"""query rewrite
|
"""query rewrite
|
||||||
Args:
|
Args:
|
||||||
origin_query: str original query
|
origin_query: str original query
|
||||||
|
context: Optional[str] context
|
||||||
nums: Optional[int] rewrite nums
|
nums: Optional[int] rewrite nums
|
||||||
Returns:
|
Returns:
|
||||||
queries: List[str]
|
queries: List[str]
|
||||||
"""
|
"""
|
||||||
from dbgpt.util.chat_util import run_async_tasks
|
from dbgpt.util.chat_util import run_async_tasks
|
||||||
|
|
||||||
prompt = self._prompt_template.format(original_query=origin_query, nums=nums)
|
prompt = self._prompt_template.format(
|
||||||
|
context=context, original_query=origin_query, nums=nums
|
||||||
|
)
|
||||||
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
|
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
|
||||||
request = ModelRequest(model=self._model_name, messages=messages)
|
request = ModelRequest(model=self._model_name, messages=messages)
|
||||||
tasks = [self._llm_client.generate(request)]
|
tasks = [self._llm_client.generate(request)]
|
||||||
@@ -61,8 +67,12 @@ class QueryRewrite:
|
|||||||
queries,
|
queries,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
print("rewrite queries:", queries)
|
if len(queries) == 0:
|
||||||
return self._parse_llm_output(output=queries[0])
|
print("llm generate no rewrite queries.")
|
||||||
|
return queries
|
||||||
|
new_queries = self._parse_llm_output(output=queries[0])[0:nums]
|
||||||
|
print(f"rewrite queries: {new_queries}")
|
||||||
|
return new_queries
|
||||||
|
|
||||||
def correct(self) -> List[str]:
|
def correct(self) -> List[str]:
|
||||||
pass
|
pass
|
||||||
@@ -81,6 +91,8 @@ class QueryRewrite:
|
|||||||
|
|
||||||
if response.startswith("queries:"):
|
if response.startswith("queries:"):
|
||||||
response = response[len("queries:") :]
|
response = response[len("queries:") :]
|
||||||
|
if response.startswith("queries:"):
|
||||||
|
response = response[len("queries:") :]
|
||||||
|
|
||||||
queries = response.split(",")
|
queries = response.split(",")
|
||||||
if len(queries) == 1:
|
if len(queries) == 1:
|
||||||
@@ -90,6 +102,10 @@ class QueryRewrite:
|
|||||||
if len(queries) == 1:
|
if len(queries) == 1:
|
||||||
queries = response.split("?")
|
queries = response.split("?")
|
||||||
for k in queries:
|
for k in queries:
|
||||||
|
if k.startswith("queries:"):
|
||||||
|
k = k[len("queries:") :]
|
||||||
|
if k.startswith("queries:"):
|
||||||
|
k = response[len("queries:") :]
|
||||||
rk = k
|
rk = k
|
||||||
if lowercase:
|
if lowercase:
|
||||||
rk = rk.lower()
|
rk = rk.lower()
|
||||||
|
@@ -4,7 +4,7 @@ from typing import List
|
|||||||
|
|
||||||
import dbgpt
|
import dbgpt
|
||||||
from dbgpt.rag.chunk import Chunk
|
from dbgpt.rag.chunk import Chunk
|
||||||
from dbgpt.rag.retriever.db_struct import DBStructRetriever
|
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||||
|
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ def mock_vector_store_connector():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
|
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||||
return DBStructRetriever(
|
return DBSchemaRetriever(
|
||||||
connection=mock_db_connection,
|
connection=mock_db_connection,
|
||||||
vector_store_connector=mock_vector_store_connector,
|
vector_store_connector=mock_vector_store_connector,
|
||||||
)
|
)
|
||||||
|
@@ -53,9 +53,9 @@ class DBSummaryClient:
|
|||||||
embedding_fn=self.embeddings,
|
embedding_fn=self.embeddings,
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=vector_store_config,
|
||||||
)
|
)
|
||||||
from dbgpt.rag.retriever.db_struct import DBStructRetriever
|
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||||
|
|
||||||
retriever = DBStructRetriever(
|
retriever = DBSchemaRetriever(
|
||||||
top_k=topk, vector_store_connector=vector_connector
|
top_k=topk, vector_store_connector=vector_connector
|
||||||
)
|
)
|
||||||
table_docs = retriever.retrieve(query)
|
table_docs = retriever.retrieve(query)
|
||||||
@@ -92,9 +92,9 @@ class DBSummaryClient:
|
|||||||
vector_store_config=vector_store_config,
|
vector_store_config=vector_store_config,
|
||||||
)
|
)
|
||||||
if not vector_connector.vector_name_exists():
|
if not vector_connector.vector_name_exists():
|
||||||
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
|
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||||
|
|
||||||
db_assembler = DBStructAssembler.load_from_connection(
|
db_assembler = DBSchemaAssembler.load_from_connection(
|
||||||
connection=db_summary_client.db, vector_store_connector=vector_connector
|
connection=db_summary_client.db, vector_store_connector=vector_connector
|
||||||
)
|
)
|
||||||
if len(db_assembler.get_chunks()) > 0:
|
if len(db_assembler.get_chunks()) > 0:
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
"""Token splitter."""
|
"""Token splitter."""
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from dbgpt.util.global_helper import globals_helper
|
from dbgpt.util.global_helper import globals_helper
|
||||||
from dbgpt.util.splitter_utils import split_by_sep, split_by_char
|
from dbgpt.util.splitter_utils import split_by_sep, split_by_char
|
||||||
|
@@ -7,24 +7,24 @@ from dbgpt.rag.chunk_manager import ChunkParameters, ChunkManager
|
|||||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||||
from dbgpt.rag.knowledge.base import Knowledge, ChunkStrategy
|
from dbgpt.rag.knowledge.base import Knowledge, ChunkStrategy
|
||||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||||
from dbgpt.rag.retriever.db_struct import DBStructRetriever
|
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||||
from dbgpt.serve.rag.assembler.base import BaseAssembler
|
from dbgpt.serve.rag.assembler.base import BaseAssembler
|
||||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
class DBStructAssembler(BaseAssembler):
|
class DBSchemaAssembler(BaseAssembler):
|
||||||
"""DBStructAssembler
|
"""DBSchemaAssembler
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
||||||
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
|
from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
|
||||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
|
|
||||||
connection = SQLiteTempConnect.create_temporary_db()
|
connection = SQLiteTempConnect.create_temporary_db()
|
||||||
assembler = DBStructAssembler.load_from_connection(
|
assembler = DBSchemaAssembler.load_from_connection(
|
||||||
connection=connection,
|
connection=connection,
|
||||||
embedding_model=embedding_model_path,
|
embedding_model=embedding_model_path,
|
||||||
)
|
)
|
||||||
@@ -53,18 +53,21 @@ class DBStructAssembler(BaseAssembler):
|
|||||||
"""
|
"""
|
||||||
if connection is None:
|
if connection is None:
|
||||||
raise ValueError("datasource connection must be provided.")
|
raise ValueError("datasource connection must be provided.")
|
||||||
|
self._connection = connection
|
||||||
|
self._vector_store_connector = vector_store_connector
|
||||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||||
|
|
||||||
|
self._embedding_model = embedding_model
|
||||||
|
if self._embedding_model:
|
||||||
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
|
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
|
||||||
default_model_name=os.getenv("EMBEDDING_MODEL")
|
default_model_name=self._embedding_model
|
||||||
)
|
)
|
||||||
self._connection = connection
|
self.embedding_fn = embedding_factory.create(self._embedding_model)
|
||||||
if embedding_model:
|
if self._vector_store_connector.vector_store_config.embedding_fn is None:
|
||||||
embedding_fn = embedding_factory.create(model_name=embedding_model)
|
self._vector_store_connector.vector_store_config.embedding_fn = (
|
||||||
self._vector_store_connector = (
|
self.embedding_fn
|
||||||
vector_store_connector
|
|
||||||
or VectorStoreConnector.from_default(embedding_fn=embedding_fn)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chunk_parameters=chunk_parameters,
|
chunk_parameters=chunk_parameters,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -79,7 +82,7 @@ class DBStructAssembler(BaseAssembler):
|
|||||||
embedding_model: Optional[str] = None,
|
embedding_model: Optional[str] = None,
|
||||||
embedding_factory: Optional[EmbeddingFactory] = None,
|
embedding_factory: Optional[EmbeddingFactory] = None,
|
||||||
vector_store_connector: Optional[VectorStoreConnector] = None,
|
vector_store_connector: Optional[VectorStoreConnector] = None,
|
||||||
) -> "DBStructAssembler":
|
) -> "DBSchemaAssembler":
|
||||||
"""Load document embedding into vector store from path.
|
"""Load document embedding into vector store from path.
|
||||||
Args:
|
Args:
|
||||||
connection: (RDBMSDatabase) RDBMSDatabase connection.
|
connection: (RDBMSDatabase) RDBMSDatabase connection.
|
||||||
@@ -89,13 +92,9 @@ class DBStructAssembler(BaseAssembler):
|
|||||||
embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
|
embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
|
||||||
vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
|
vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
|
||||||
Returns:
|
Returns:
|
||||||
DBStructAssembler
|
DBSchemaAssembler
|
||||||
"""
|
"""
|
||||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
embedding_factory = embedding_factory
|
||||||
|
|
||||||
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
|
|
||||||
default_model_name=embedding_model or os.getenv("EMBEDDING_MODEL_PATH")
|
|
||||||
)
|
|
||||||
chunk_parameters = chunk_parameters or ChunkParameters(
|
chunk_parameters = chunk_parameters or ChunkParameters(
|
||||||
chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0
|
chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0
|
||||||
)
|
)
|
||||||
@@ -136,14 +135,14 @@ class DBStructAssembler(BaseAssembler):
|
|||||||
def _extract_info(self, chunks) -> List[Chunk]:
|
def _extract_info(self, chunks) -> List[Chunk]:
|
||||||
"""Extract info from chunks."""
|
"""Extract info from chunks."""
|
||||||
|
|
||||||
def as_retriever(self, top_k: Optional[int] = 4) -> DBStructRetriever:
|
def as_retriever(self, top_k: Optional[int] = 4) -> DBSchemaRetriever:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
top_k:(Optional[int]), default 4
|
top_k:(Optional[int]), default 4
|
||||||
Returns:
|
Returns:
|
||||||
DBStructRetriever
|
DBSchemaRetriever
|
||||||
"""
|
"""
|
||||||
return DBStructRetriever(
|
return DBSchemaRetriever(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
connection=self._connection,
|
connection=self._connection,
|
||||||
is_embeddings=True,
|
is_embeddings=True,
|
@@ -46,16 +46,18 @@ class EmbeddingAssembler(BaseAssembler):
|
|||||||
"""
|
"""
|
||||||
if knowledge is None:
|
if knowledge is None:
|
||||||
raise ValueError("knowledge datasource must be provided.")
|
raise ValueError("knowledge datasource must be provided.")
|
||||||
|
self._vector_store_connector = vector_store_connector
|
||||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||||
|
|
||||||
|
self._embedding_model = embedding_model
|
||||||
|
if self._embedding_model:
|
||||||
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
|
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
|
||||||
default_model_name=os.getenv("EMBEDDING_MODEL")
|
default_model_name=self._embedding_model
|
||||||
)
|
)
|
||||||
if embedding_model:
|
self.embedding_fn = embedding_factory.create(self._embedding_model)
|
||||||
embedding_fn = embedding_factory.create(model_name=embedding_model)
|
if self._vector_store_connector.vector_store_config.embedding_fn is None:
|
||||||
self._vector_store_connector = (
|
self._vector_store_connector.vector_store_config.embedding_fn = (
|
||||||
vector_store_connector
|
self.embedding_fn
|
||||||
or VectorStoreConnector.from_default(embedding_fn=embedding_fn)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@@ -57,9 +57,8 @@ class SummaryAssembler(BaseAssembler):
|
|||||||
from dbgpt.rag.extractor.summary import SummaryExtractor
|
from dbgpt.rag.extractor.summary import SummaryExtractor
|
||||||
|
|
||||||
self._extractor = extractor or SummaryExtractor(
|
self._extractor = extractor or SummaryExtractor(
|
||||||
llm_client=self._llm_client, model_name=self._model_name
|
llm_client=self._llm_client, model_name=self._model_name, language=language
|
||||||
)
|
)
|
||||||
self._language = language
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
knowledge=knowledge,
|
knowledge=knowledge,
|
||||||
chunk_parameters=chunk_parameters,
|
chunk_parameters=chunk_parameters,
|
||||||
|
@@ -7,7 +7,7 @@ from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
|||||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||||
from dbgpt.rag.knowledge.base import Knowledge
|
from dbgpt.rag.knowledge.base import Knowledge
|
||||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||||
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
|
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
@@ -66,7 +66,7 @@ def test_load_knowledge(
|
|||||||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
||||||
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
||||||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
||||||
assembler = DBStructAssembler(
|
assembler = DBSchemaAssembler(
|
||||||
connection=mock_db_connection,
|
connection=mock_db_connection,
|
||||||
chunk_parameters=mock_chunk_parameters,
|
chunk_parameters=mock_chunk_parameters,
|
||||||
embedding_factory=mock_embedding_factory,
|
embedding_factory=mock_embedding_factory,
|
||||||
|
0
dbgpt/serve/rag/operators/__init__.py
Normal file
0
dbgpt/serve/rag/operators/__init__.py
Normal file
23
dbgpt/serve/rag/operators/base.py
Normal file
23
dbgpt/serve/rag/operators/base.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from dbgpt.core.awel import MapOperator
|
||||||
|
from dbgpt.core.awel.task.base import IN, OUT
|
||||||
|
|
||||||
|
|
||||||
|
class AssemblerOperator(MapOperator[IN, OUT]):
|
||||||
|
"""The Base Assembler Operator."""
|
||||||
|
|
||||||
|
async def map(self, input_value: IN) -> OUT:
|
||||||
|
"""Map input value to output value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_value (IN): The input value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OUT: The output value.
|
||||||
|
"""
|
||||||
|
return await self.blocking_func_to_async(self.assemble, input_value)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def assemble(self, input_value: IN) -> OUT:
|
||||||
|
"""assemble knowledge for input value."""
|
36
dbgpt/serve/rag/operators/db_schema.py
Normal file
36
dbgpt/serve/rag/operators/db_schema.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from dbgpt.core.awel.task.base import IN
|
||||||
|
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||||
|
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||||
|
from dbgpt.serve.rag.operators.base import AssemblerOperator
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
|
class DBSchemaAssemblerOperator(AssemblerOperator[Any, Any]):
|
||||||
|
"""The DBSchema Assembler Operator.
|
||||||
|
Args:
|
||||||
|
connection (RDBMSDatabase): The connection.
|
||||||
|
chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to None.
|
||||||
|
vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection: RDBMSDatabase = None,
|
||||||
|
vector_store_connector: Optional[VectorStoreConnector] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self._connection = connection
|
||||||
|
self._vector_store_connector = vector_store_connector
|
||||||
|
self._assembler = DBSchemaAssembler.load_from_connection(
|
||||||
|
connection=self._connection,
|
||||||
|
vector_store_connector=self._vector_store_connector,
|
||||||
|
)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def assemble(self, input_value: IN) -> Any:
|
||||||
|
"""assemble knowledge for input value."""
|
||||||
|
if self._vector_store_connector:
|
||||||
|
self._assembler.persist()
|
||||||
|
return self._assembler.get_chunks()
|
44
dbgpt/serve/rag/operators/embedding.py
Normal file
44
dbgpt/serve/rag/operators/embedding.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from dbgpt.core.awel.task.base import IN
|
||||||
|
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||||
|
from dbgpt.rag.knowledge.base import Knowledge
|
||||||
|
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||||
|
from dbgpt.serve.rag.operators.base import AssemblerOperator
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingAssemblerOperator(AssemblerOperator[Any, Any]):
|
||||||
|
"""The Embedding Assembler Operator.
|
||||||
|
Args:
|
||||||
|
knowledge (Knowledge): The knowledge.
|
||||||
|
chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to None.
|
||||||
|
vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
|
||||||
|
chunk_strategy="CHUNK_BY_SIZE"
|
||||||
|
),
|
||||||
|
vector_store_connector: VectorStoreConnector = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
|
||||||
|
vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
|
||||||
|
"""
|
||||||
|
self._chunk_parameters = chunk_parameters
|
||||||
|
self._vector_store_connector = vector_store_connector
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def assemble(self, knowledge: IN) -> Any:
|
||||||
|
"""assemble knowledge for input value."""
|
||||||
|
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||||
|
knowledge=knowledge,
|
||||||
|
chunk_parameters=self._chunk_parameters,
|
||||||
|
vector_store_connector=self._vector_store_connector,
|
||||||
|
)
|
||||||
|
assembler.persist()
|
||||||
|
return assembler.get_chunks()
|
@@ -92,6 +92,11 @@ class VectorStoreConnector:
|
|||||||
"""
|
"""
|
||||||
return self.client.similar_search_with_scores(doc, topk, score_threshold)
|
return self.client.similar_search_with_scores(doc, topk, score_threshold)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vector_store_config(self) -> VectorStoreConfig:
|
||||||
|
"""vector store config."""
|
||||||
|
return self._vector_store_config
|
||||||
|
|
||||||
def vector_name_exists(self):
|
def vector_name_exists(self):
|
||||||
"""is vector store name exist."""
|
"""is vector store name exist."""
|
||||||
return self.client.vector_name_exists()
|
return self.client.vector_name_exists()
|
||||||
|
130
examples/awel/simple_dbschema_retriever_example.py
Normal file
130
examples/awel/simple_dbschema_retriever_example.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import os
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
||||||
|
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||||
|
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
||||||
|
from dbgpt.rag.chunk import Chunk
|
||||||
|
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||||
|
from dbgpt.rag.operator.db_schema import DBSchemaRetrieverOperator
|
||||||
|
from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator
|
||||||
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
"""AWEL: Simple rag db schema embedding operator example
|
||||||
|
|
||||||
|
if you not set vector_store_connector, it will return all tables schema in database.
|
||||||
|
```
|
||||||
|
retriever_task = DBSchemaRetrieverOperator(
|
||||||
|
connection=_create_temporary_connection()
|
||||||
|
)
|
||||||
|
```
|
||||||
|
if you set vector_store_connector, it will recall topk similarity tables schema in database.
|
||||||
|
```
|
||||||
|
retriever_task = DBSchemaRetrieverOperator(
|
||||||
|
connection=_create_temporary_connection()
|
||||||
|
top_k=1,
|
||||||
|
vector_store_connector=vector_store_connector
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
..code-block:: shell
|
||||||
|
curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/dbschema' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{"query": "what is user name?"}'
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _create_vector_connector():
|
||||||
|
"""Create vector connector."""
|
||||||
|
return VectorStoreConnector.from_default(
|
||||||
|
"Chroma",
|
||||||
|
vector_store_config=ChromaVectorConfig(
|
||||||
|
name="vector_name",
|
||||||
|
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||||
|
),
|
||||||
|
embedding_fn=DefaultEmbeddingFactory(
|
||||||
|
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||||
|
).create(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_temporary_connection():
|
||||||
|
"""Create a temporary database connection for testing."""
|
||||||
|
connect = SQLiteTempConnect.create_temporary_db()
|
||||||
|
connect.create_temp_tables(
|
||||||
|
{
|
||||||
|
"user": {
|
||||||
|
"columns": {
|
||||||
|
"id": "INTEGER PRIMARY KEY",
|
||||||
|
"name": "TEXT",
|
||||||
|
"age": "INTEGER",
|
||||||
|
},
|
||||||
|
"data": [
|
||||||
|
(1, "Tom", 10),
|
||||||
|
(2, "Jerry", 16),
|
||||||
|
(3, "Jack", 18),
|
||||||
|
(4, "Alice", 20),
|
||||||
|
(5, "Bob", 22),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return connect
|
||||||
|
|
||||||
|
|
||||||
|
def _join_fn(chunks: List[Chunk], query: str) -> str:
|
||||||
|
print(f"db schema info is {[chunk.content for chunk in chunks]}")
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerReqBody(BaseModel):
|
||||||
|
query: str = Field(..., description="User query")
|
||||||
|
|
||||||
|
|
||||||
|
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def map(self, input_value: TriggerReqBody) -> Dict:
|
||||||
|
params = {
|
||||||
|
"query": input_value.query,
|
||||||
|
}
|
||||||
|
print(f"Receive input value: {input_value}")
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
with DAG("simple_rag_db_schema_example") as dag:
|
||||||
|
trigger = HttpTrigger(
|
||||||
|
"/examples/rag/dbschema", methods="POST", request_body=TriggerReqBody
|
||||||
|
)
|
||||||
|
request_handle_task = RequestHandleOperator()
|
||||||
|
query_operator = MapOperator(lambda request: request["query"])
|
||||||
|
vector_store_connector = _create_vector_connector()
|
||||||
|
assembler_task = DBSchemaAssemblerOperator(
|
||||||
|
connection=_create_temporary_connection(),
|
||||||
|
vector_store_connector=vector_store_connector,
|
||||||
|
)
|
||||||
|
join_operator = JoinOperator(combine_function=_join_fn)
|
||||||
|
retriever_task = DBSchemaRetrieverOperator(
|
||||||
|
connection=_create_temporary_connection(),
|
||||||
|
top_k=1,
|
||||||
|
vector_store_connector=vector_store_connector,
|
||||||
|
)
|
||||||
|
result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
|
||||||
|
trigger >> request_handle_task >> assembler_task >> join_operator
|
||||||
|
trigger >> request_handle_task >> query_operator >> join_operator
|
||||||
|
join_operator >> retriever_task >> result_parse_task
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if dag.leaf_nodes[0].dev_mode:
|
||||||
|
# Development mode, you can run the dag locally for debugging.
|
||||||
|
from dbgpt.core.awel import setup_dev_environment
|
||||||
|
|
||||||
|
setup_dev_environment([dag], port=5555)
|
||||||
|
else:
|
||||||
|
pass
|
86
examples/awel/simple_rag_embedding_example.py
Normal file
86
examples/awel/simple_rag_embedding_example.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
||||||
|
from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
|
||||||
|
from dbgpt.rag.chunk import Chunk
|
||||||
|
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||||
|
from dbgpt.rag.operator.knowledge import KnowledgeOperator
|
||||||
|
from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
|
||||||
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
"""AWEL: Simple rag embedding operator example
|
||||||
|
|
||||||
|
pre-requirements:
|
||||||
|
set your file path in your example code.
|
||||||
|
Examples:
|
||||||
|
..code-block:: shell
|
||||||
|
python examples/awel/simple_rag_embedding_example.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
|
||||||
|
"""context Join function for JoinOperator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_dict (Dict): context dict
|
||||||
|
chunks (List[Chunk]): chunks
|
||||||
|
Returns:
|
||||||
|
Dict: context dict
|
||||||
|
"""
|
||||||
|
context_dict["context"] = "\n".join([chunk.content for chunk in chunks])
|
||||||
|
return context_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _create_vector_connector():
|
||||||
|
"""Create vector connector."""
|
||||||
|
return VectorStoreConnector.from_default(
|
||||||
|
"Chroma",
|
||||||
|
vector_store_config=ChromaVectorConfig(
|
||||||
|
name="vector_name",
|
||||||
|
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||||
|
),
|
||||||
|
embedding_fn=DefaultEmbeddingFactory(
|
||||||
|
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||||
|
).create(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResultOperator(MapOperator):
|
||||||
|
"""The Result Operator."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def map(self, chunks: List) -> str:
|
||||||
|
result = f"embedding success, there are {len(chunks)} chunks."
|
||||||
|
print(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
with DAG("simple_sdk_rag_embedding_example") as dag:
|
||||||
|
knowledge_operator = KnowledgeOperator()
|
||||||
|
vector_connector = _create_vector_connector()
|
||||||
|
input_task = InputOperator(input_source=SimpleCallDataInputSource())
|
||||||
|
file_path_parser = MapOperator(map_function=lambda x: x["file_path"])
|
||||||
|
embedding_operator = EmbeddingAssemblerOperator(
|
||||||
|
vector_store_connector=vector_connector,
|
||||||
|
)
|
||||||
|
output_task = ResultOperator()
|
||||||
|
(
|
||||||
|
input_task
|
||||||
|
>> file_path_parser
|
||||||
|
>> knowledge_operator
|
||||||
|
>> embedding_operator
|
||||||
|
>> output_task
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
input_data = {
|
||||||
|
"data": {
|
||||||
|
"file_path": "docs/docs/awel.md",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output = asyncio.run(output_task.call(call_data=input_data))
|
127
examples/awel/simple_rag_retriever_example.py
Normal file
127
examples/awel/simple_rag_retriever_example.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
||||||
|
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||||
|
from dbgpt.model import OpenAILLMClient
|
||||||
|
from dbgpt.rag.chunk import Chunk
|
||||||
|
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||||
|
from dbgpt.rag.operator.embedding import EmbeddingRetrieverOperator
|
||||||
|
from dbgpt.rag.operator.rerank import RerankOperator
|
||||||
|
from dbgpt.rag.operator.rewrite import QueryRewriteOperator
|
||||||
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
"""AWEL: Simple rag embedding operator example
|
||||||
|
|
||||||
|
pre-requirements:
|
||||||
|
1. install openai python sdk
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install openai
|
||||||
|
```
|
||||||
|
2. set openai key and base
|
||||||
|
```
|
||||||
|
export OPENAI_API_KEY={your_openai_key}
|
||||||
|
export OPENAI_API_BASE={your_openai_base}
|
||||||
|
```
|
||||||
|
3. make sure you have vector store.
|
||||||
|
if there are no data in vector store, please run examples/awel/simple_rag_embedding_example.py
|
||||||
|
|
||||||
|
|
||||||
|
ensure your embedding model in DB-GPT/models/.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
..code-block:: shell
|
||||||
|
DBGPT_SERVER="http://127.0.0.1:5555"
|
||||||
|
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/retrieve \
|
||||||
|
-H "Content-Type: application/json" -d '{
|
||||||
|
"query": "what is awel talk about?"
|
||||||
|
}'
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerReqBody(BaseModel):
|
||||||
|
query: str = Field(..., description="User query")
|
||||||
|
|
||||||
|
|
||||||
|
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def map(self, input_value: TriggerReqBody) -> Dict:
|
||||||
|
params = {
|
||||||
|
"query": input_value.query,
|
||||||
|
}
|
||||||
|
print(f"Receive input value: {input_value}")
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
|
||||||
|
"""context Join function for JoinOperator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_dict (Dict): context dict
|
||||||
|
chunks (List[Chunk]): chunks
|
||||||
|
Returns:
|
||||||
|
Dict: context dict
|
||||||
|
"""
|
||||||
|
context_dict["context"] = "\n".join([chunk.content for chunk in chunks])
|
||||||
|
return context_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _create_vector_connector():
|
||||||
|
"""Create vector connector."""
|
||||||
|
return VectorStoreConnector.from_default(
|
||||||
|
"Chroma",
|
||||||
|
vector_store_config=ChromaVectorConfig(
|
||||||
|
name="vector_name",
|
||||||
|
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||||
|
),
|
||||||
|
embedding_fn=DefaultEmbeddingFactory(
|
||||||
|
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||||
|
).create(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
with DAG("simple_sdk_rag_retriever_example") as dag:
|
||||||
|
vector_connector = _create_vector_connector()
|
||||||
|
trigger = HttpTrigger(
|
||||||
|
"/examples/rag/retrieve", methods="POST", request_body=TriggerReqBody
|
||||||
|
)
|
||||||
|
request_handle_task = RequestHandleOperator()
|
||||||
|
query_parser = MapOperator(map_function=lambda x: x["query"])
|
||||||
|
context_join_operator = JoinOperator(combine_function=_context_join_fn)
|
||||||
|
rewrite_operator = QueryRewriteOperator(llm_client=OpenAILLMClient())
|
||||||
|
retriever_context_operator = EmbeddingRetrieverOperator(
|
||||||
|
top_k=3,
|
||||||
|
vector_store_connector=vector_connector,
|
||||||
|
)
|
||||||
|
retriever_operator = EmbeddingRetrieverOperator(
|
||||||
|
top_k=3,
|
||||||
|
vector_store_connector=vector_connector,
|
||||||
|
)
|
||||||
|
rerank_operator = RerankOperator()
|
||||||
|
model_parse_task = MapOperator(lambda out: out.to_dict())
|
||||||
|
|
||||||
|
trigger >> request_handle_task >> context_join_operator
|
||||||
|
(
|
||||||
|
trigger
|
||||||
|
>> request_handle_task
|
||||||
|
>> query_parser
|
||||||
|
>> retriever_context_operator
|
||||||
|
>> context_join_operator
|
||||||
|
)
|
||||||
|
context_join_operator >> rewrite_operator >> retriever_operator >> rerank_operator
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if dag.leaf_nodes[0].dev_mode:
|
||||||
|
# Development mode, you can run the dag locally for debugging.
|
||||||
|
from dbgpt.core.awel import setup_dev_environment
|
||||||
|
|
||||||
|
setup_dev_environment([dag], port=5555)
|
||||||
|
else:
|
||||||
|
pass
|
74
examples/awel/simple_rag_rewrite_example.py
Normal file
74
examples/awel/simple_rag_rewrite_example.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""AWEL: Simple rag rewrite example
|
||||||
|
|
||||||
|
pre-requirements:
|
||||||
|
1. install openai python sdk
|
||||||
|
```
|
||||||
|
pip install openai
|
||||||
|
```
|
||||||
|
2. set openai key and base
|
||||||
|
```
|
||||||
|
export OPENAI_API_KEY={your_openai_key}
|
||||||
|
export OPENAI_API_BASE={your_openai_base}
|
||||||
|
```
|
||||||
|
or
|
||||||
|
```
|
||||||
|
import os
|
||||||
|
os.environ["OPENAI_API_KEY"] = {your_openai_key}
|
||||||
|
os.environ["OPENAI_API_BASE"] = {your_openai_base}
|
||||||
|
```
|
||||||
|
python examples/awel/simple_rag_rewrite_example.py
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
DBGPT_SERVER="http://127.0.0.1:5000"
|
||||||
|
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/rewrite \
|
||||||
|
-H "Content-Type: application/json" -d '{
|
||||||
|
"query": "compare curry and james",
|
||||||
|
"context":"steve curry and lebron james are nba all-stars"
|
||||||
|
}'
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from dbgpt._private.pydantic import BaseModel, Field
|
||||||
|
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
|
||||||
|
from dbgpt.model import OpenAILLMClient
|
||||||
|
from dbgpt.rag.operator.rewrite import QueryRewriteOperator
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerReqBody(BaseModel):
|
||||||
|
query: str = Field(..., description="User query")
|
||||||
|
context: str = Field(..., description="context")
|
||||||
|
|
||||||
|
|
||||||
|
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def map(self, input_value: TriggerReqBody) -> Dict:
|
||||||
|
params = {
|
||||||
|
"query": input_value.query,
|
||||||
|
"context": input_value.context,
|
||||||
|
}
|
||||||
|
print(f"Receive input value: {input_value}")
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag:
|
||||||
|
trigger = HttpTrigger(
|
||||||
|
"/examples/rag/rewrite", methods="POST", request_body=TriggerReqBody
|
||||||
|
)
|
||||||
|
request_handle_task = RequestHandleOperator()
|
||||||
|
# build query rewrite operator
|
||||||
|
rewrite_task = QueryRewriteOperator(llm_client=OpenAILLMClient(), nums=2)
|
||||||
|
trigger >> request_handle_task >> rewrite_task
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if dag.leaf_nodes[0].dev_mode:
|
||||||
|
# Development mode, you can run the dag locally for debugging.
|
||||||
|
from dbgpt.core.awel import setup_dev_environment
|
||||||
|
|
||||||
|
setup_dev_environment([dag], port=5555)
|
||||||
|
else:
|
||||||
|
pass
|
84
examples/awel/simple_rag_summary_example.py
Normal file
84
examples/awel/simple_rag_summary_example.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""AWEL:
|
||||||
|
This example shows how to use AWEL to build a simple rag summary example.
|
||||||
|
pre-requirements:
|
||||||
|
1. install openai python sdk
|
||||||
|
```
|
||||||
|
pip install openai
|
||||||
|
```
|
||||||
|
2. set openai key and base
|
||||||
|
```
|
||||||
|
export OPENAI_API_KEY={your_openai_key}
|
||||||
|
export OPENAI_API_BASE={your_openai_base}
|
||||||
|
```
|
||||||
|
or
|
||||||
|
```
|
||||||
|
import os
|
||||||
|
os.environ["OPENAI_API_KEY"] = {your_openai_key}
|
||||||
|
os.environ["OPENAI_API_BASE"] = {your_openai_base}
|
||||||
|
```
|
||||||
|
python examples/awel/simple_rag_summary_example.py
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
DBGPT_SERVER="http://127.0.0.1:5000"
|
||||||
|
FILE_PATH="{your_file_path}"
|
||||||
|
curl -X POST http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/summary \
|
||||||
|
-H "Content-Type: application/json" -d '{
|
||||||
|
"file_path": $FILE_PATH
|
||||||
|
}'
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from dbgpt._private.pydantic import BaseModel, Field
|
||||||
|
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
|
||||||
|
from dbgpt.model import OpenAILLMClient
|
||||||
|
from dbgpt.rag.operator.knowledge import KnowledgeOperator
|
||||||
|
from dbgpt.rag.operator.summary import SummaryAssemblerOperator
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerReqBody(BaseModel):
|
||||||
|
file_path: str = Field(..., description="file_path")
|
||||||
|
|
||||||
|
|
||||||
|
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def map(self, input_value: TriggerReqBody) -> Dict:
|
||||||
|
params = {
|
||||||
|
"file_path": input_value.file_path,
|
||||||
|
}
|
||||||
|
print(f"Receive input value: {input_value}")
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag:
|
||||||
|
trigger = HttpTrigger(
|
||||||
|
"/examples/rag/summary", methods="POST", request_body=TriggerReqBody
|
||||||
|
)
|
||||||
|
request_handle_task = RequestHandleOperator()
|
||||||
|
path_operator = MapOperator(lambda request: request["file_path"])
|
||||||
|
# build knowledge operator
|
||||||
|
knowledge_operator = KnowledgeOperator()
|
||||||
|
# build summary assembler operator
|
||||||
|
summary_operator = SummaryAssemblerOperator(
|
||||||
|
llm_client=OpenAILLMClient(), language="en"
|
||||||
|
)
|
||||||
|
(
|
||||||
|
trigger
|
||||||
|
>> request_handle_task
|
||||||
|
>> path_operator
|
||||||
|
>> knowledge_operator
|
||||||
|
>> summary_operator
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if dag.leaf_nodes[0].dev_mode:
|
||||||
|
# Development mode, you can run the dag locally for debugging.
|
||||||
|
from dbgpt.core.awel import setup_dev_environment
|
||||||
|
|
||||||
|
setup_dev_environment([dag], port=5555)
|
||||||
|
else:
|
||||||
|
pass
|
@@ -1,6 +1,9 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
||||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
||||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||||
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
|
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
@@ -13,7 +16,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
|||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
..code-block:: shell
|
..code-block:: shell
|
||||||
python examples/rag/db_struct_rag_example.py
|
python examples/rag/db_schema_rag_example.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -41,28 +44,29 @@ def _create_temporary_connection():
|
|||||||
return connect
|
return connect
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def _create_vector_connector():
|
||||||
connection = _create_temporary_connection()
|
"""Create vector connector."""
|
||||||
|
return VectorStoreConnector.from_default(
|
||||||
embedding_model_path = "{your_embedding_model_path}"
|
|
||||||
vector_persist_path = "{your_persist_path}"
|
|
||||||
embedding_fn = DefaultEmbeddingFactory(
|
|
||||||
default_model_name=embedding_model_path
|
|
||||||
).create()
|
|
||||||
vector_connector = VectorStoreConnector.from_default(
|
|
||||||
"Chroma",
|
"Chroma",
|
||||||
vector_store_config=ChromaVectorConfig(
|
vector_store_config=ChromaVectorConfig(
|
||||||
name="vector_name",
|
name="db_schema_vector_store_name",
|
||||||
persist_path=vector_persist_path,
|
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||||
),
|
),
|
||||||
embedding_fn=embedding_fn,
|
embedding_fn=DefaultEmbeddingFactory(
|
||||||
|
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||||
|
).create(),
|
||||||
)
|
)
|
||||||
assembler = DBStructAssembler.load_from_connection(
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
connection = _create_temporary_connection()
|
||||||
|
vector_connector = _create_vector_connector()
|
||||||
|
assembler = DBSchemaAssembler.load_from_connection(
|
||||||
connection=connection,
|
connection=connection,
|
||||||
vector_store_connector=vector_connector,
|
vector_store_connector=vector_connector,
|
||||||
)
|
)
|
||||||
assembler.persist()
|
assembler.persist()
|
||||||
# get db struct retriever
|
# get db schema retriever
|
||||||
retriever = assembler.as_retriever(top_k=1)
|
retriever = assembler.as_retriever(top_k=1)
|
||||||
chunks = retriever.retrieve("show columns from user")
|
chunks = retriever.retrieve("show columns from user")
|
||||||
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
|
print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")
|
@@ -1,5 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
||||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||||
@@ -20,21 +22,24 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
def _create_vector_connector():
|
||||||
file_path = "./docs/docs/awel.md"
|
"""Create vector connector."""
|
||||||
vector_persist_path = "{your_persist_path}"
|
return VectorStoreConnector.from_default(
|
||||||
embedding_model_path = "{your_embedding_model_path}"
|
|
||||||
knowledge = KnowledgeFactory.from_file_path(file_path)
|
|
||||||
vector_connector = VectorStoreConnector.from_default(
|
|
||||||
"Chroma",
|
"Chroma",
|
||||||
vector_store_config=ChromaVectorConfig(
|
vector_store_config=ChromaVectorConfig(
|
||||||
name="vector_name",
|
name="db_schema_vector_store_name",
|
||||||
persist_path=vector_persist_path,
|
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||||
),
|
),
|
||||||
embedding_fn=DefaultEmbeddingFactory(
|
embedding_fn=DefaultEmbeddingFactory(
|
||||||
default_model_name=embedding_model_path
|
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||||
).create(),
|
).create(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
file_path = "docs/docs/awel.md"
|
||||||
|
knowledge = KnowledgeFactory.from_file_path(file_path)
|
||||||
|
vector_connector = _create_vector_connector()
|
||||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||||
# get embedding assembler
|
# get embedding assembler
|
||||||
assembler = EmbeddingAssembler.load_from_knowledge(
|
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||||
@@ -45,7 +50,7 @@ async def main():
|
|||||||
assembler.persist()
|
assembler.persist()
|
||||||
# get embeddings retriever
|
# get embeddings retriever
|
||||||
retriever = assembler.as_retriever(3)
|
retriever = assembler.as_retriever(3)
|
||||||
chunks = await retriever.aretrieve_with_scores("RAG", 0.3)
|
chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
|
||||||
print(f"embedding rag example results:{chunks}")
|
print(f"embedding rag example results:{chunks}")
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user