feat(RAG): add RAG operators and RAG AWEL examples (#1061)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Aries-ckt
2024-01-13 16:14:48 +08:00
committed by GitHub
parent 99ea6ac1a4
commit a035433170
29 changed files with 1010 additions and 102 deletions

View File

@@ -9,7 +9,7 @@ from dbgpt.util.chat_util import run_async_tasks
SUMMARY_PROMPT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结: SUMMARY_PROMPT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
{context} {context}
答案尽量精确和简单,不要过长长度控制在100字左右 答案尽量精确和简单,不要过长长度控制在100字左右, 注意:请用<中文>来进行总结。
""" """
SUMMARY_PROMPT_TEMPLATE_EN = """ SUMMARY_PROMPT_TEMPLATE_EN = """
@@ -18,6 +18,13 @@ Write a quick summary of the following context:
the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters. the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
""" """
REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context}\n 请根据你之前推理的内容进行最终的总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。"""
REFINE_SUMMARY_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {context}, We have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1.2.and 3.
"""
class SummaryExtractor(Extractor): class SummaryExtractor(Extractor):
"""Summary Extractor, it can extract document summary.""" """Summary Extractor, it can extract document summary."""
@@ -41,6 +48,11 @@ class SummaryExtractor(Extractor):
if language == "en" if language == "en"
else SUMMARY_PROMPT_TEMPLATE_ZH else SUMMARY_PROMPT_TEMPLATE_ZH
) )
self._refine_prompt_template = (
REFINE_SUMMARY_TEMPLATE_EN
if language == "en"
else REFINE_SUMMARY_TEMPLATE_ZH
)
self._concurrency_limit_with_llm = concurrency_limit_with_llm self._concurrency_limit_with_llm = concurrency_limit_with_llm
self._max_iteration_with_llm = max_iteration_with_llm self._max_iteration_with_llm = max_iteration_with_llm
self._concurrency_limit_with_llm = concurrency_limit_with_llm self._concurrency_limit_with_llm = concurrency_limit_with_llm
@@ -64,15 +76,23 @@ class SummaryExtractor(Extractor):
texts = [doc.content for doc in chunks] texts = [doc.content for doc in chunks]
from dbgpt.util.prompt_util import PromptHelper from dbgpt.util.prompt_util import PromptHelper
# repack chunk into prompt to adapt llm model max context window
prompt_helper = PromptHelper() prompt_helper = PromptHelper()
texts = prompt_helper.repack( texts = prompt_helper.repack(
prompt_template=self._prompt_template, text_chunks=texts prompt_template=self._prompt_template, text_chunks=texts
) )
if len(texts) == 1: if len(texts) == 1:
summary_outs = await self._llm_run_tasks(chunk_texts=texts) summary_outs = await self._llm_run_tasks(
chunk_texts=texts, prompt_template=self._refine_prompt_template
)
return summary_outs[0] return summary_outs[0]
else: else:
return await self._mapreduce_extract_summary(docs=texts) map_reduce_texts = await self._mapreduce_extract_summary(docs=texts)
summary_outs = await self._llm_run_tasks(
chunk_texts=[map_reduce_texts],
prompt_template=self._refine_prompt_template,
)
return summary_outs[0]
def _extract(self, chunks: List[Chunk]) -> str: def _extract(self, chunks: List[Chunk]) -> str:
"""document extract summary """document extract summary
@@ -98,7 +118,8 @@ class SummaryExtractor(Extractor):
return docs[0] return docs[0]
else: else:
summary_outs = await self._llm_run_tasks( summary_outs = await self._llm_run_tasks(
chunk_texts=docs[0 : self._max_iteration_with_llm] chunk_texts=docs[0 : self._max_iteration_with_llm],
prompt_template=self._prompt_template,
) )
from dbgpt.util.prompt_util import PromptHelper from dbgpt.util.prompt_util import PromptHelper
@@ -108,10 +129,13 @@ class SummaryExtractor(Extractor):
) )
return await self._mapreduce_extract_summary(docs=summary_outs) return await self._mapreduce_extract_summary(docs=summary_outs)
async def _llm_run_tasks(self, chunk_texts: List[str]) -> List[str]: async def _llm_run_tasks(
self, chunk_texts: List[str], prompt_template: str
) -> List[str]:
"""llm run tasks """llm run tasks
Args: Args:
chunk_texts: List[str] chunk_texts: List[str]
prompt_template: str
Returns: Returns:
summary_outs: List[str] summary_outs: List[str]
""" """
@@ -119,7 +143,7 @@ class SummaryExtractor(Extractor):
for chunk_text in chunk_texts: for chunk_text in chunk_texts:
from dbgpt.core import ModelMessage from dbgpt.core import ModelMessage
prompt = self._prompt_template.format(context=chunk_text) prompt = prompt_template.format(context=chunk_text)
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)] messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
request = ModelRequest(model=self._model_name, messages=messages) request = ModelRequest(model=self._model_name, messages=messages)
tasks.append(self._llm_client.generate(request)) tasks.append(self._llm_client.generate(request))

View File

@@ -0,0 +1,37 @@
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
    """Retriever operator that fetches database table schemas for a query.

    Args:
        connection (RDBMSDatabase): The datasource connection.
        top_k (int, optional): How many schema chunks to return. Defaults to 4.
        vector_store_connector (VectorStoreConnector, optional): The vector
            store connector used for similarity search. Defaults to None.
    """

    def __init__(
        self,
        top_k: int = 4,
        connection: Optional[RDBMSDatabase] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # All retrieval work is delegated to DBSchemaRetriever.
        self._retriever = DBSchemaRetriever(
            top_k=top_k,
            connection=connection,
            vector_store_connector=vector_store_connector,
        )

    def retrieve(self, query: IN) -> Any:
        """Retrieve the table schemas matching the query.

        Args:
            query (IN): The query text.
        """
        schema_retriever = self._retriever
        return schema_retriever.retrieve(query)

View File

@@ -0,0 +1,39 @@
from functools import reduce
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
    """Retriever operator backed by an ``EmbeddingRetriever``.

    Args:
        top_k (int): The number of candidates to retrieve per query.
        score_threshold (Optional[float]): Minimum similarity score for a
            candidate to be kept. Defaults to 0.3.
        query_rewrite (Optional[QueryRewrite]): Optional query rewriter.
        rerank (Ranker): Optional reranker for the candidates.
        vector_store_connector (VectorStoreConnector): The vector store
            connector used for similarity search.
    """

    def __init__(
        self,
        top_k: int,
        score_threshold: Optional[float] = 0.3,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Ranker = None,
        vector_store_connector: VectorStoreConnector = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self._score_threshold = score_threshold
        self._retriever = EmbeddingRetriever(
            top_k=top_k,
            query_rewrite=query_rewrite,
            rerank=rerank,
            vector_store_connector=vector_store_connector,
        )

    def retrieve(self, query: IN) -> Any:
        """Retrieve scored candidate chunks for one query or a list of queries.

        Args:
            query (IN): A single query string or a list of query strings.

        Returns:
            The scored candidates; for a list of queries the per-query results
            are concatenated into one flat list.

        Raises:
            TypeError: If ``query`` is neither ``str`` nor ``list``.
        """
        if isinstance(query, str):
            return self._retriever.retrieve_with_scores(query, self._score_threshold)
        if isinstance(query, list):
            candidates = [
                self._retriever.retrieve_with_scores(q, self._score_threshold)
                for q in query
            ]
            # Seed reduce with [] so an empty query list yields [] instead of
            # raising "reduce() of empty iterable with no initial value".
            return reduce(lambda x, y: x + y, candidates, [])
        # Fail loudly instead of silently returning None (the original
        # implicit behavior for unsupported query types).
        raise TypeError(f"Unsupported query type: {type(query).__name__}")

View File

@@ -0,0 +1,26 @@
from typing import Any, List, Optional
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge
from dbgpt.rag.knowledge.factory import KnowledgeFactory
class KnowledgeOperator(MapOperator[Any, Any]):
    """Operator that loads a datasource into a ``Knowledge`` object."""

    def __init__(
        self, knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, **kwargs
    ):
        """Initialize the knowledge operator.

        Args:
            knowledge_type: (Optional[KnowledgeType]) The knowledge type.
        """
        super().__init__(**kwargs)
        self._knowledge_type = knowledge_type

    async def map(self, datasource: IN) -> Knowledge:
        """Create knowledge from the datasource, off the event loop."""
        factory_create = KnowledgeFactory.create
        return await self.blocking_func_to_async(
            factory_create, datasource, self._knowledge_type
        )

View File

@@ -0,0 +1,43 @@
from typing import Any, Optional, List
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.rerank import DefaultRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite
class RerankOperator(MapOperator[Any, Any]):
    """Operator that reranks retrieved candidates."""

    def __init__(
        self,
        topk: Optional[int] = 3,
        algorithm: Optional[str] = "default",
        rank_fn: Optional[callable] = None,
        **kwargs
    ):
        """Initialize the rerank operator.

        Args:
            topk (int): The number of the candidates.
            algorithm (Optional[str]): The rerank algorithm name.
            rank_fn (Optional[callable]): The rank function.
        """
        super().__init__(**kwargs)
        # NOTE(review): `algorithm` is stored but never consulted here —
        # DefaultRanker is always used regardless of its value; confirm intent.
        self._algorithm = algorithm
        self._rerank = DefaultRanker(topk=topk, rank_fn=rank_fn)

    async def map(self, candidates_with_scores: IN) -> List[Chunk]:
        """Rerank the candidates, off the event loop.

        Args:
            candidates_with_scores (IN): The candidates with scores.

        Returns:
            List[Chunk]: The reranked candidates.
        """
        rank = self._rerank.rank
        return await self.blocking_func_to_async(rank, candidates_with_scores)

View File

@@ -0,0 +1,41 @@
from typing import Any, Optional, List
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.retriever.rewrite import QueryRewrite
class QueryRewriteOperator(MapOperator[Any, Any]):
    """Operator that rewrites a user query with an LLM."""

    def __init__(
        self,
        llm_client: Optional[LLMClient],
        model_name: Optional[str] = None,
        language: Optional[str] = "en",
        nums: Optional[int] = 1,
        **kwargs
    ):
        """Initialize the query rewrite operator.

        Args:
            llm_client (Optional[LLMClient]): The LLM client.
            model_name (Optional[str]): The model name.
            language (Optional[str]): The prompt language.
            nums (Optional[int]): The number of the rewrite results.
        """
        super().__init__(**kwargs)
        self._nums = nums
        self._rewrite = QueryRewrite(
            llm_client=llm_client,
            model_name=model_name,
            language=language,
        )

    async def map(self, query_context: IN) -> List[str]:
        """Rewrite the query.

        Args:
            query_context (IN): A mapping providing "query" and "context" keys.
        """
        original_query = query_context.get("query")
        retrieved_context = query_context.get("context")
        return await self._rewrite.rewrite(
            origin_query=original_query, context=retrieved_context, nums=self._nums
        )

View File

@@ -0,0 +1,49 @@
from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel.task.base import IN
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):
    """Operator that assembles a document summary via an LLM."""

    def __init__(
        self,
        llm_client: Optional[LLMClient],
        model_name: Optional[str] = "gpt-3.5-turbo",
        language: Optional[str] = "en",
        max_iteration_with_llm: Optional[int] = 5,
        concurrency_limit_with_llm: Optional[int] = 3,
        **kwargs
    ):
        """Initialize the summary assembler operator.

        Args:
            llm_client: (Optional[LLMClient]) The LLM client.
            model_name: (Optional[str]) The model name.
            language: (Optional[str]) The prompt language.
            max_iteration_with_llm: (Optional[int]) The max iteration with llm.
            concurrency_limit_with_llm: (Optional[int]) The concurrency limit with llm.
        """
        super().__init__(**kwargs)
        self._llm_client = llm_client
        self._model_name = model_name
        self._language = language
        self._max_iteration_with_llm = max_iteration_with_llm
        self._concurrency_limit_with_llm = concurrency_limit_with_llm

    async def map(self, knowledge: IN) -> Any:
        """Build a SummaryAssembler from the knowledge and generate the summary."""
        summary_assembler = SummaryAssembler.load_from_knowledge(
            knowledge=knowledge,
            llm_client=self._llm_client,
            model_name=self._model_name,
            language=self._language,
            max_iteration_with_llm=self._max_iteration_with_llm,
            concurrency_limit_with_llm=self._concurrency_limit_with_llm,
        )
        return await summary_assembler.generate_summary()

    def assemble(self, knowledge: IN) -> Any:
        """Assemble knowledge for input value (intentional no-op; `map` drives the work)."""
        pass

View File

@@ -1,6 +1,7 @@
from functools import reduce from functools import reduce
from typing import List, Optional from typing import List, Optional
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.util.chat_util import run_async_tasks from dbgpt.util.chat_util import run_async_tasks
from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk from dbgpt.rag.chunk import Chunk
@@ -9,14 +10,13 @@ from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBStructRetriever(BaseRetriever): class DBSchemaRetriever(BaseRetriever):
"""DBStruct retriever.""" """DBSchema retriever."""
def __init__( def __init__(
self, self,
top_k: int = 4, top_k: int = 4,
connection: Optional[RDBMSDatabase] = None, connection: Optional[RDBMSDatabase] = None,
is_embeddings: bool = True,
query_rewrite: bool = False, query_rewrite: bool = False,
rerank: Ranker = None, rerank: Ranker = None,
vector_store_connector: Optional[VectorStoreConnector] = None, vector_store_connector: Optional[VectorStoreConnector] = None,
@@ -26,14 +26,13 @@ class DBStructRetriever(BaseRetriever):
Args: Args:
top_k (int): top k top_k (int): top k
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection. connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
is_embeddings (bool): Whether to query by embeddings in the vector store, Defaults to True.
query_rewrite (bool): query rewrite query_rewrite (bool): query rewrite
rerank (Ranker): rerank rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector vector_store_connector (VectorStoreConnector): vector store connector
code example: code example:
.. code-block:: python .. code-block:: python
>>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect >>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
>>> from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler >>> from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector >>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
>>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig >>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
>>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever >>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
@@ -71,16 +70,18 @@ class DBStructRetriever(BaseRetriever):
embedding_fn=embedding_fn embedding_fn=embedding_fn
) )
# get db struct retriever # get db struct retriever
retriever = DBStructRetriever(top_k=3, vector_store_connector=vector_connector) retriever = DBSchemaRetriever(top_k=3, vector_store_connector=vector_connector)
chunks = retriever.retrieve("show columns from table") chunks = retriever.retrieve("show columns from table")
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}") print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
""" """
self._top_k = top_k self._top_k = top_k
self._is_embeddings = is_embeddings
self._connection = connection self._connection = connection
self._query_rewrite = query_rewrite self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector self._vector_store_connector = vector_store_connector
self._need_embeddings = False
if self._vector_store_connector:
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k) self._rerank = rerank or DefaultRanker(self._top_k)
def _retrieve(self, query: str) -> List[Chunk]: def _retrieve(self, query: str) -> List[Chunk]:
@@ -88,7 +89,7 @@ class DBStructRetriever(BaseRetriever):
Args: Args:
query (str): query text query (str): query text
""" """
if self._is_embeddings: if self._need_embeddings:
queries = [query] queries = [query]
candidates = [ candidates = [
self._vector_store_connector.similar_search(query, self._top_k) self._vector_store_connector.similar_search(query, self._top_k)
@@ -97,8 +98,6 @@ class DBStructRetriever(BaseRetriever):
candidates = reduce(lambda x, y: x + y, candidates) candidates = reduce(lambda x, y: x + y, candidates)
return candidates return candidates
else: else:
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
table_summaries = _parse_db_summary(self._connection) table_summaries = _parse_db_summary(self._connection)
return [Chunk(content=table_summary) for table_summary in table_summaries] return [Chunk(content=table_summary) for table_summary in table_summaries]
@@ -115,7 +114,7 @@ class DBStructRetriever(BaseRetriever):
Args: Args:
query (str): query text query (str): query text
""" """
if self._is_embeddings: if self._need_embeddings:
queries = [query] queries = [query]
candidates = [self._similarity_search(query) for query in queries] candidates = [self._similarity_search(query) for query in queries]
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1) candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
@@ -145,7 +144,7 @@ class DBStructRetriever(BaseRetriever):
self._top_k, self._top_k,
) )
async def _aparse_db_summary(self) -> List[Chunk]: async def _aparse_db_summary(self) -> List[str]:
"""Similar search.""" """Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

View File

@@ -99,7 +99,12 @@ class EmbeddingRetriever(BaseRetriever):
""" """
queries = [query] queries = [query]
if self._query_rewrite: if self._query_rewrite:
new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1) candidates_tasks = [self._similarity_search(query) for query in queries]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries) queries.extend(new_queries)
candidates = [self._similarity_search(query) for query in queries] candidates = [self._similarity_search(query) for query in queries]
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1) candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
@@ -117,7 +122,12 @@ class EmbeddingRetriever(BaseRetriever):
""" """
queries = [query] queries = [query]
if self._query_rewrite: if self._query_rewrite:
new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1) candidates_tasks = [self._similarity_search(query) for query in queries]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries) queries.extend(new_queries)
candidates_with_score = [ candidates_with_score = [
self._similarity_search_with_score(query, score_threshold) self._similarity_search_with_score(query, score_threshold)
@@ -137,6 +147,12 @@ class EmbeddingRetriever(BaseRetriever):
self._top_k, self._top_k,
) )
async def _run_async_tasks(self, tasks) -> List[Chunk]:
"""Run async tasks."""
candidates = await run_async_tasks(tasks=tasks, concurrency_limit=1)
candidates = reduce(lambda x, y: x + y, candidates)
return candidates
async def _similarity_search_with_score( async def _similarity_search_with_score(
self, query, score_threshold self, query, score_threshold
) -> List[Chunk]: ) -> List[Chunk]:

View File

@@ -2,14 +2,14 @@ from typing import List, Optional
from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
REWRITE_PROMPT_TEMPLATE_EN = """ REWRITE_PROMPT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'\n": Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'":
"original query:: {original_query}\n" "original query:{original_query}\n"
"queries:\n" "queries:"
""" """
REWRITE_PROMPT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries<queries>' REWRITE_PROMPT_TEMPLATE_ZH = """请根据上下文{context}, 将原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>'
"original_query{original_query}\n" "original_query:{original_query}\n"
"queries\n" "queries:"
""" """
@@ -29,6 +29,7 @@ class QueryRewrite:
- query: (str), user query - query: (str), user query
- model_name: (str), llm model name - model_name: (str), llm model name
- llm_client: (Optional[LLMClient]) - llm_client: (Optional[LLMClient])
- language: (Optional[str]), language
""" """
self._model_name = model_name self._model_name = model_name
self._llm_client = llm_client self._llm_client = llm_client
@@ -39,17 +40,22 @@ class QueryRewrite:
else REWRITE_PROMPT_TEMPLATE_ZH else REWRITE_PROMPT_TEMPLATE_ZH
) )
async def rewrite(self, origin_query: str, nums: Optional[int] = 1) -> List[str]: async def rewrite(
self, origin_query: str, context: Optional[str], nums: Optional[int] = 1
) -> List[str]:
"""query rewrite """query rewrite
Args: Args:
origin_query: str original query origin_query: str original query
context: Optional[str] context
nums: Optional[int] rewrite nums nums: Optional[int] rewrite nums
Returns: Returns:
queries: List[str] queries: List[str]
""" """
from dbgpt.util.chat_util import run_async_tasks from dbgpt.util.chat_util import run_async_tasks
prompt = self._prompt_template.format(original_query=origin_query, nums=nums) prompt = self._prompt_template.format(
context=context, original_query=origin_query, nums=nums
)
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)] messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
request = ModelRequest(model=self._model_name, messages=messages) request = ModelRequest(model=self._model_name, messages=messages)
tasks = [self._llm_client.generate(request)] tasks = [self._llm_client.generate(request)]
@@ -61,8 +67,12 @@ class QueryRewrite:
queries, queries,
) )
) )
print("rewrite queries:", queries) if len(queries) == 0:
return self._parse_llm_output(output=queries[0]) print("llm generate no rewrite queries.")
return queries
new_queries = self._parse_llm_output(output=queries[0])[0:nums]
print(f"rewrite queries: {new_queries}")
return new_queries
def correct(self) -> List[str]: def correct(self) -> List[str]:
pass pass
@@ -81,6 +91,8 @@ class QueryRewrite:
if response.startswith("queries:"): if response.startswith("queries:"):
response = response[len("queries:") :] response = response[len("queries:") :]
if response.startswith("queries"):
response = response[len("queries") :]
queries = response.split(",") queries = response.split(",")
if len(queries) == 1: if len(queries) == 1:
@@ -90,6 +102,10 @@ class QueryRewrite:
if len(queries) == 1: if len(queries) == 1:
queries = response.split("") queries = response.split("")
for k in queries: for k in queries:
if k.startswith("queries:"):
k = k[len("queries:") :]
if k.startswith("queries"):
k = response[len("queries") :]
rk = k rk = k
if lowercase: if lowercase:
rk = rk.lower() rk = rk.lower()

View File

@@ -4,7 +4,7 @@ from typing import List
import dbgpt import dbgpt
from dbgpt.rag.chunk import Chunk from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.db_struct import DBStructRetriever from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@@ -22,7 +22,7 @@ def mock_vector_store_connector():
@pytest.fixture @pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector): def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
return DBStructRetriever( return DBSchemaRetriever(
connection=mock_db_connection, connection=mock_db_connection,
vector_store_connector=mock_vector_store_connector, vector_store_connector=mock_vector_store_connector,
) )

View File

@@ -53,9 +53,9 @@ class DBSummaryClient:
embedding_fn=self.embeddings, embedding_fn=self.embeddings,
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )
from dbgpt.rag.retriever.db_struct import DBStructRetriever from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
retriever = DBStructRetriever( retriever = DBSchemaRetriever(
top_k=topk, vector_store_connector=vector_connector top_k=topk, vector_store_connector=vector_connector
) )
table_docs = retriever.retrieve(query) table_docs = retriever.retrieve(query)
@@ -92,9 +92,9 @@ class DBSummaryClient:
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )
if not vector_connector.vector_name_exists(): if not vector_connector.vector_name_exists():
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
db_assembler = DBStructAssembler.load_from_connection( db_assembler = DBSchemaAssembler.load_from_connection(
connection=db_summary_client.db, vector_store_connector=vector_connector connection=db_summary_client.db, vector_store_connector=vector_connector
) )
if len(db_assembler.get_chunks()) > 0: if len(db_assembler.get_chunks()) > 0:

View File

@@ -1,7 +1,7 @@
"""Token splitter.""" """Token splitter."""
from typing import Callable, List, Optional from typing import Callable, List, Optional
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel from pydantic import BaseModel, Field, PrivateAttr
from dbgpt.util.global_helper import globals_helper from dbgpt.util.global_helper import globals_helper
from dbgpt.util.splitter_utils import split_by_sep, split_by_char from dbgpt.util.splitter_utils import split_by_sep, split_by_char

View File

@@ -7,24 +7,24 @@ from dbgpt.rag.chunk_manager import ChunkParameters, ChunkManager
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge, ChunkStrategy from dbgpt.rag.knowledge.base import Knowledge, ChunkStrategy
from dbgpt.rag.knowledge.factory import KnowledgeFactory from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.retriever.db_struct import DBStructRetriever from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.assembler.base import BaseAssembler from dbgpt.serve.rag.assembler.base import BaseAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBStructAssembler(BaseAssembler): class DBSchemaAssembler(BaseAssembler):
"""DBStructAssembler """DBSchemaAssembler
Example: Example:
.. code-block:: python .. code-block:: python
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
connection = SQLiteTempConnect.create_temporary_db() connection = SQLiteTempConnect.create_temporary_db()
assembler = DBStructAssembler.load_from_connection( assembler = DBSchemaAssembler.load_from_connection(
connection=connection, connection=connection,
embedding_model=embedding_model_path, embedding_model=embedding_model_path,
) )
@@ -53,18 +53,21 @@ class DBStructAssembler(BaseAssembler):
""" """
if connection is None: if connection is None:
raise ValueError("datasource connection must be provided.") raise ValueError("datasource connection must be provided.")
self._connection = connection
self._vector_store_connector = vector_store_connector
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
self._embedding_model = embedding_model
if self._embedding_model:
embedding_factory = embedding_factory or DefaultEmbeddingFactory( embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=os.getenv("EMBEDDING_MODEL") default_model_name=self._embedding_model
) )
self._connection = connection self.embedding_fn = embedding_factory.create(self._embedding_model)
if embedding_model: if self._vector_store_connector.vector_store_config.embedding_fn is None:
embedding_fn = embedding_factory.create(model_name=embedding_model) self._vector_store_connector.vector_store_config.embedding_fn = (
self._vector_store_connector = ( self.embedding_fn
vector_store_connector
or VectorStoreConnector.from_default(embedding_fn=embedding_fn)
) )
super().__init__( super().__init__(
chunk_parameters=chunk_parameters, chunk_parameters=chunk_parameters,
**kwargs, **kwargs,
@@ -79,7 +82,7 @@ class DBStructAssembler(BaseAssembler):
embedding_model: Optional[str] = None, embedding_model: Optional[str] = None,
embedding_factory: Optional[EmbeddingFactory] = None, embedding_factory: Optional[EmbeddingFactory] = None,
vector_store_connector: Optional[VectorStoreConnector] = None, vector_store_connector: Optional[VectorStoreConnector] = None,
) -> "DBStructAssembler": ) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path. """Load document embedding into vector store from path.
Args: Args:
connection: (RDBMSDatabase) RDBMSDatabase connection. connection: (RDBMSDatabase) RDBMSDatabase connection.
@@ -89,13 +92,9 @@ class DBStructAssembler(BaseAssembler):
embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use. embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use. vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
Returns: Returns:
DBStructAssembler DBSchemaAssembler
""" """
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory embedding_factory = embedding_factory
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=embedding_model or os.getenv("EMBEDDING_MODEL_PATH")
)
chunk_parameters = chunk_parameters or ChunkParameters( chunk_parameters = chunk_parameters or ChunkParameters(
chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0 chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0
) )
@@ -136,14 +135,14 @@ class DBStructAssembler(BaseAssembler):
def _extract_info(self, chunks) -> List[Chunk]: def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks.""" """Extract info from chunks."""
def as_retriever(self, top_k: Optional[int] = 4) -> DBStructRetriever: def as_retriever(self, top_k: Optional[int] = 4) -> DBSchemaRetriever:
""" """
Args: Args:
top_k:(Optional[int]), default 4 top_k:(Optional[int]), default 4
Returns: Returns:
DBStructRetriever DBSchemaRetriever
""" """
return DBStructRetriever( return DBSchemaRetriever(
top_k=top_k, top_k=top_k,
connection=self._connection, connection=self._connection,
is_embeddings=True, is_embeddings=True,

View File

@@ -46,16 +46,18 @@ class EmbeddingAssembler(BaseAssembler):
""" """
if knowledge is None: if knowledge is None:
raise ValueError("knowledge datasource must be provided.") raise ValueError("knowledge datasource must be provided.")
self._vector_store_connector = vector_store_connector
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
self._embedding_model = embedding_model
if self._embedding_model:
embedding_factory = embedding_factory or DefaultEmbeddingFactory( embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=os.getenv("EMBEDDING_MODEL") default_model_name=self._embedding_model
) )
if embedding_model: self.embedding_fn = embedding_factory.create(self._embedding_model)
embedding_fn = embedding_factory.create(model_name=embedding_model) if self._vector_store_connector.vector_store_config.embedding_fn is None:
self._vector_store_connector = ( self._vector_store_connector.vector_store_config.embedding_fn = (
vector_store_connector self.embedding_fn
or VectorStoreConnector.from_default(embedding_fn=embedding_fn)
) )
super().__init__( super().__init__(

View File

@@ -57,9 +57,8 @@ class SummaryAssembler(BaseAssembler):
from dbgpt.rag.extractor.summary import SummaryExtractor from dbgpt.rag.extractor.summary import SummaryExtractor
self._extractor = extractor or SummaryExtractor( self._extractor = extractor or SummaryExtractor(
llm_client=self._llm_client, model_name=self._model_name llm_client=self._llm_client, model_name=self._model_name, language=language
) )
self._language = language
super().__init__( super().__init__(
knowledge=knowledge, knowledge=knowledge,
chunk_parameters=chunk_parameters, chunk_parameters=chunk_parameters,

View File

@@ -7,7 +7,7 @@ from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -66,7 +66,7 @@ def test_load_knowledge(
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter() mock_chunk_parameters.text_splitter = CharacterTextSplitter()
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = DBStructAssembler( assembler = DBSchemaAssembler(
connection=mock_db_connection, connection=mock_db_connection,
chunk_parameters=mock_chunk_parameters, chunk_parameters=mock_chunk_parameters,
embedding_factory=mock_embedding_factory, embedding_factory=mock_embedding_factory,

View File

View File

@@ -0,0 +1,23 @@
from abc import abstractmethod
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
class AssemblerOperator(MapOperator[IN, OUT]):
    """The Base Assembler Operator.

    Adapts a synchronous :meth:`assemble` implementation to the async AWEL
    ``map`` interface so assembly work does not block the event loop.
    """

    async def map(self, input_value: IN) -> OUT:
        """Map input value to output value.

        Args:
            input_value (IN): The input value.

        Returns:
            OUT: The output value produced by :meth:`assemble`.
        """
        # Delegate the blocking assemble step via blocking_func_to_async
        # (presumably an executor hand-off — confirm in MapOperator).
        return await self.blocking_func_to_async(self.assemble, input_value)

    @abstractmethod
    def assemble(self, input_value: IN) -> OUT:
        """Assemble knowledge for the input value.

        Subclasses implement the synchronous assembly logic here.
        """

View File

@@ -0,0 +1,36 @@
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBSchemaAssemblerOperator(AssemblerOperator[Any, Any]):
    """The DBSchema Assembler Operator.

    Embeds the schema of a relational database via :class:`DBSchemaAssembler`.

    Args:
        connection (RDBMSDatabase): The connection.
        vector_store_connector (VectorStoreConnector, optional): The vector
            store connector. Defaults to None.
    """

    def __init__(
        self,
        connection: Optional[RDBMSDatabase] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        self._connection = connection
        self._vector_store_connector = vector_store_connector
        # NOTE(review): the assembler is built eagerly at operator-construction
        # time, so schema loading happens here, not when the DAG runs — and a
        # None connection would presumably fail here; confirm intended.
        self._assembler = DBSchemaAssembler.load_from_connection(
            connection=self._connection,
            vector_store_connector=self._vector_store_connector,
        )
        super().__init__(**kwargs)

    def assemble(self, input_value: IN) -> Any:
        """Assemble knowledge for the input value.

        Persists the schema chunks into the vector store when a connector is
        configured, then returns the assembled chunks. ``input_value`` is not
        used by this operator.
        """
        if self._vector_store_connector:
            self._assembler.persist()
        return self._assembler.get_chunks()

View File

@@ -0,0 +1,44 @@
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingAssemblerOperator(AssemblerOperator[Any, Any]):
    """The Embedding Assembler Operator.

    Chunks a knowledge source, embeds the chunks and persists them into the
    configured vector store.

    Args:
        chunk_parameters (Optional[ChunkParameters], optional): The chunk
            parameters. Defaults to ``ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")``.
        vector_store_connector (VectorStoreConnector, optional): The vector
            store connector. Defaults to None.
    """

    def __init__(
        self,
        chunk_parameters: Optional[ChunkParameters] = None,
        vector_store_connector: VectorStoreConnector = None,
        **kwargs
    ):
        """
        Args:
            chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
            vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
        """
        # Fix: the original used a ChunkParameters(...) instance as the default
        # argument value, which is created once at import time and shared by
        # every operator instance (classic mutable-default bug). Build a fresh
        # default per instance instead.
        if chunk_parameters is None:
            chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
        self._chunk_parameters = chunk_parameters
        self._vector_store_connector = vector_store_connector
        super().__init__(**kwargs)

    def assemble(self, knowledge: IN) -> Any:
        """Assemble and persist embeddings for the given knowledge.

        Args:
            knowledge (IN): The knowledge source to chunk and embed.

        Returns:
            Any: The list of chunks that were embedded and persisted.
        """
        assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=self._chunk_parameters,
            vector_store_connector=self._vector_store_connector,
        )
        assembler.persist()
        return assembler.get_chunks()

View File

@@ -92,6 +92,11 @@ class VectorStoreConnector:
""" """
return self.client.similar_search_with_scores(doc, topk, score_threshold) return self.client.similar_search_with_scores(doc, topk, score_threshold)
    @property
    def vector_store_config(self) -> VectorStoreConfig:
        """Return the vector store config backing this connector."""
        return self._vector_store_config
def vector_name_exists(self): def vector_name_exists(self):
"""is vector store name exist.""" """is vector store name exist."""
return self.client.vector_name_exists() return self.client.vector_name_exists()

View File

@@ -0,0 +1,130 @@
import os
from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.operator.db_schema import DBSchemaRetrieverOperator
from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""AWEL: Simple rag db schema embedding operator example
If you do not set vector_store_connector, it will return all table schemas in the database.
```
retriever_task = DBSchemaRetrieverOperator(
connection=_create_temporary_connection()
)
```
If you set vector_store_connector, it will recall the top-k most similar table schemas from the database.
```
retriever_task = DBSchemaRetrieverOperator(
connection=_create_temporary_connection()
top_k=1,
vector_store_connector=vector_store_connector
)
```
Examples:
..code-block:: shell
curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/dbschema' \
--header 'Content-Type: application/json' \
--data '{"query": "what is user name?"}'
"""
def _create_vector_connector():
    """Build the Chroma vector store connector used by this example."""
    embedding_fn = DefaultEmbeddingFactory(
        default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    ).create()
    store_config = ChromaVectorConfig(
        name="vector_name",
        persist_path=os.path.join(PILOT_PATH, "data"),
    )
    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=store_config,
        embedding_fn=embedding_fn,
    )
def _create_temporary_connection():
    """Create a temporary database connection for testing."""
    user_table = {
        "columns": {
            "id": "INTEGER PRIMARY KEY",
            "name": "TEXT",
            "age": "INTEGER",
        },
        "data": [
            (1, "Tom", 10),
            (2, "Jerry", 16),
            (3, "Jack", 18),
            (4, "Alice", 20),
            (5, "Bob", 22),
        ],
    }
    connect = SQLiteTempConnect.create_temporary_db()
    connect.create_temp_tables({"user": user_table})
    return connect
def _join_fn(chunks: List[Chunk], query: str) -> str:
    """Log the retrieved schema chunks and pass the user query through."""
    schema_info = [chunk.content for chunk in chunks]
    print(f"db schema info is {schema_info}")
    return query
class TriggerReqBody(BaseModel):
    """Request body for the /examples/rag/dbschema HTTP trigger."""

    # User's natural-language question about the database schema.
    query: str = Field(..., description="User query")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Convert the HTTP trigger request body into a plain parameter dict."""

    # Fix: dropped the boilerplate __init__ that only forwarded **kwargs to
    # super() — MapOperator's own __init__ already does exactly that.

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract the query from the request body.

        Args:
            input_value (TriggerReqBody): The parsed HTTP request body.

        Returns:
            Dict: ``{"query": <user query>}``.
        """
        params = {
            "query": input_value.query,
        }
        print(f"Receive input value: {input_value}")
        return params
with DAG("simple_rag_db_schema_example") as dag:
    # HTTP endpoint feeding the DAG:
    # POST /api/v1/awel/trigger/examples/rag/dbschema
    trigger = HttpTrigger(
        "/examples/rag/dbschema", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    # Pull the raw query string out of the request dict.
    query_operator = MapOperator(lambda request: request["query"])
    vector_store_connector = _create_vector_connector()
    # Embed the temporary database's schema into the vector store.
    assembler_task = DBSchemaAssemblerOperator(
        connection=_create_temporary_connection(),
        vector_store_connector=vector_store_connector,
    )
    # _join_fn prints the assembled schema chunks and forwards the query.
    join_operator = JoinOperator(combine_function=_join_fn)
    # Recall the single most similar table schema for the query.
    retriever_task = DBSchemaRetrieverOperator(
        connection=_create_temporary_connection(),
        top_k=1,
        vector_store_connector=vector_store_connector,
    )
    result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
    # Both branches feed the join; retrieval runs after the join completes.
    trigger >> request_handle_task >> assembler_task >> join_operator
    trigger >> request_handle_task >> query_operator >> join_operator
    join_operator >> retriever_task >> result_parse_task
if __name__ == "__main__":
    if dag.leaf_nodes[0].dev_mode:
        # Development mode, you can run the dag locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)
    else:
        # Production mode: the DAG is loaded and triggered by the DB-GPT
        # webserver, so there is nothing to do here.
        pass

View File

@@ -0,0 +1,86 @@
import asyncio
import os
from typing import Dict, List
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.operator.knowledge import KnowledgeOperator
from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""AWEL: Simple rag embedding operator example
pre-requirements:
set your file path in your example code.
Examples:
..code-block:: shell
python examples/awel/simple_rag_embedding_example.py
"""
def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
    """Context join function for JoinOperator.

    Args:
        context_dict (Dict): context dict
        chunks (List[Chunk]): chunks

    Returns:
        Dict: context dict with the chunk contents joined under "context"
    """
    joined = "\n".join(chunk.content for chunk in chunks)
    context_dict["context"] = joined
    return context_dict
def _create_vector_connector():
    """Create the Chroma-backed vector store connector for this example."""
    store_config = ChromaVectorConfig(
        name="vector_name",
        persist_path=os.path.join(PILOT_PATH, "data"),
    )
    factory = DefaultEmbeddingFactory(
        default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    )
    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=store_config,
        embedding_fn=factory.create(),
    )
class ResultOperator(MapOperator):
    """The Result Operator: report how many chunks were embedded."""

    # Fix: dropped the boilerplate __init__ that only forwarded **kwargs to
    # super() — MapOperator's own __init__ already does exactly that.

    async def map(self, chunks: List) -> str:
        """Format and print the embedding result summary.

        Args:
            chunks (List): The chunks that were embedded.

        Returns:
            str: Human-readable summary message.
        """
        result = f"embedding success, there are {len(chunks)} chunks."
        print(result)
        return result
with DAG("simple_sdk_rag_embedding_example") as dag:
    # Loads the knowledge source (the markdown file) into chunks.
    knowledge_operator = KnowledgeOperator()
    vector_connector = _create_vector_connector()
    # Entry point: receives the call data passed to output_task.call(...).
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    # Pull the file path out of the call-data dict.
    file_path_parser = MapOperator(map_function=lambda x: x["file_path"])
    # Chunk, embed and persist into the vector store.
    embedding_operator = EmbeddingAssemblerOperator(
        vector_store_connector=vector_connector,
    )
    output_task = ResultOperator()
    (
        input_task
        >> file_path_parser
        >> knowledge_operator
        >> embedding_operator
        >> output_task
    )
if __name__ == "__main__":
    # Call data routed into InputOperator; "file_path" selects the document
    # to embed.
    input_data = {
        "data": {
            "file_path": "docs/docs/awel.md",
        }
    }
    output = asyncio.run(output_task.call(call_data=input_data))

View File

@@ -0,0 +1,127 @@
import asyncio
import os
from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.operator.embedding import EmbeddingRetrieverOperator
from dbgpt.rag.operator.rerank import RerankOperator
from dbgpt.rag.operator.rewrite import QueryRewriteOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""AWEL: Simple rag embedding operator example
pre-requirements:
1. install openai python sdk
```
pip install openai
```
2. set openai key and base
```
export OPENAI_API_KEY={your_openai_key}
export OPENAI_API_BASE={your_openai_base}
```
3. make sure you have vector store.
if there are no data in vector store, please run examples/awel/simple_rag_embedding_example.py
ensure your embedding model in DB-GPT/models/.
Examples:
..code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5555"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/retrieve \
-H "Content-Type: application/json" -d '{
"query": "what is awel talk about?"
}'
"""
class TriggerReqBody(BaseModel):
    """Request body for the /examples/rag/retrieve HTTP trigger."""

    # User's natural-language question to retrieve context for.
    query: str = Field(..., description="User query")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Convert the HTTP trigger request body into a plain parameter dict."""

    # Fix: dropped the boilerplate __init__ that only forwarded **kwargs to
    # super() — MapOperator's own __init__ already does exactly that.

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract the query from the request body.

        Args:
            input_value (TriggerReqBody): The parsed HTTP request body.

        Returns:
            Dict: ``{"query": <user query>}``.
        """
        params = {
            "query": input_value.query,
        }
        print(f"Receive input value: {input_value}")
        return params
def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
    """Context join function for JoinOperator.

    Args:
        context_dict (Dict): context dict
        chunks (List[Chunk]): chunks

    Returns:
        Dict: context dict with chunk contents folded under "context"
    """
    contents = [chunk.content for chunk in chunks]
    context_dict["context"] = "\n".join(contents)
    return context_dict
def _create_vector_connector():
    """Create vector connector (Chroma with a local text2vec model)."""
    persist_dir = os.path.join(PILOT_PATH, "data")
    model_path = os.path.join(MODEL_PATH, "text2vec-large-chinese")
    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=ChromaVectorConfig(
            name="vector_name",
            persist_path=persist_dir,
        ),
        embedding_fn=DefaultEmbeddingFactory(
            default_model_name=model_path,
        ).create(),
    )
with DAG("simple_sdk_rag_retriever_example") as dag:
    vector_connector = _create_vector_connector()
    # HTTP endpoint: POST /api/v1/awel/trigger/examples/rag/retrieve
    trigger = HttpTrigger(
        "/examples/rag/retrieve", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    # Extract the raw query string from the request dict.
    query_parser = MapOperator(map_function=lambda x: x["query"])
    # Merges the request dict with the first-pass retrieved chunks.
    context_join_operator = JoinOperator(combine_function=_context_join_fn)
    # Rewrites the query with the LLM using the retrieved context.
    rewrite_operator = QueryRewriteOperator(llm_client=OpenAILLMClient())
    # First retrieval pass: gather context for the query rewrite.
    retriever_context_operator = EmbeddingRetrieverOperator(
        top_k=3,
        vector_store_connector=vector_connector,
    )
    # Second retrieval pass: retrieve with the rewritten query.
    retriever_operator = EmbeddingRetrieverOperator(
        top_k=3,
        vector_store_connector=vector_connector,
    )
    rerank_operator = RerankOperator()
    # NOTE(review): model_parse_task is defined but never wired into the DAG
    # edges below — confirm whether it should follow rerank_operator.
    model_parse_task = MapOperator(lambda out: out.to_dict())
    trigger >> request_handle_task >> context_join_operator
    (
        trigger
        >> request_handle_task
        >> query_parser
        >> retriever_context_operator
        >> context_join_operator
    )
    context_join_operator >> rewrite_operator >> retriever_operator >> rerank_operator
if __name__ == "__main__":
    if dag.leaf_nodes[0].dev_mode:
        # Development mode, you can run the dag locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)
    else:
        # Production mode: the DAG is loaded and triggered by the DB-GPT
        # webserver, so there is nothing to do here.
        pass

View File

@@ -0,0 +1,74 @@
"""AWEL: Simple rag rewrite example
pre-requirements:
1. install openai python sdk
```
pip install openai
```
2. set openai key and base
```
export OPENAI_API_KEY={your_openai_key}
export OPENAI_API_BASE={your_openai_base}
```
or
```
import os
os.environ["OPENAI_API_KEY"] = {your_openai_key}
os.environ["OPENAI_API_BASE"] = {your_openai_base}
```
python examples/awel/simple_rag_rewrite_example.py
Example:
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5000"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/rewrite \
-H "Content-Type: application/json" -d '{
"query": "compare curry and james",
"context":"steve curry and lebron james are nba all-stars"
}'
"""
from typing import Dict
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.rag.operator.rewrite import QueryRewriteOperator
class TriggerReqBody(BaseModel):
    """Request body for the /examples/rag/rewrite HTTP trigger."""

    # User's original question to rewrite.
    query: str = Field(..., description="User query")
    # Background context the rewrite is grounded in.
    context: str = Field(..., description="context")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Convert the HTTP trigger request body into rewrite parameters."""

    # Fix: dropped the boilerplate __init__ that only forwarded **kwargs to
    # super() — MapOperator's own __init__ already does exactly that.

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract query and context from the request body.

        Args:
            input_value (TriggerReqBody): The parsed HTTP request body.

        Returns:
            Dict: ``{"query": ..., "context": ...}`` for the rewrite operator.
        """
        params = {
            "query": input_value.query,
            "context": input_value.context,
        }
        print(f"Receive input value: {input_value}")
        return params
with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag:
    # HTTP endpoint: POST /api/v1/awel/trigger/examples/rag/rewrite
    trigger = HttpTrigger(
        "/examples/rag/rewrite", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    # build query rewrite operator (generates `nums` rewritten queries)
    rewrite_task = QueryRewriteOperator(llm_client=OpenAILLMClient(), nums=2)
    trigger >> request_handle_task >> rewrite_task
if __name__ == "__main__":
    if dag.leaf_nodes[0].dev_mode:
        # Development mode, you can run the dag locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)
    else:
        # Production mode: the DAG is loaded and triggered by the DB-GPT
        # webserver, so there is nothing to do here.
        pass

View File

@@ -0,0 +1,84 @@
"""AWEL:
This example shows how to use AWEL to build a simple rag summary example.
pre-requirements:
1. install openai python sdk
```
pip install openai
```
2. set openai key and base
```
export OPENAI_API_KEY={your_openai_key}
export OPENAI_API_BASE={your_openai_base}
```
or
```
import os
os.environ["OPENAI_API_KEY"] = {your_openai_key}
os.environ["OPENAI_API_BASE"] = {your_openai_base}
```
python examples/awel/simple_rag_summary_example.py
Example:
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5000"
FILE_PATH="{your_file_path}"
curl -X POST http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/summary \
-H "Content-Type: application/json" -d '{
"file_path": $FILE_PATH
}'
"""
from typing import Dict
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.rag.operator.knowledge import KnowledgeOperator
from dbgpt.rag.operator.summary import SummaryAssemblerOperator
class TriggerReqBody(BaseModel):
    """Request body for the /examples/rag/summary HTTP trigger."""

    # Path of the document to summarize.
    file_path: str = Field(..., description="file_path")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Convert the HTTP trigger request body into a parameter dict."""

    # Fix: dropped the boilerplate __init__ that only forwarded **kwargs to
    # super() — MapOperator's own __init__ already does exactly that.

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract the file path from the request body.

        Args:
            input_value (TriggerReqBody): The parsed HTTP request body.

        Returns:
            Dict: ``{"file_path": <document path>}``.
        """
        params = {
            "file_path": input_value.file_path,
        }
        print(f"Receive input value: {input_value}")
        return params
# Fix: the DAG was named "dbgpt_awel_simple_rag_rewrite_example" — a
# copy-paste from the rewrite example that collides with that example's DAG
# registration. Renamed to match this summary example.
with DAG("dbgpt_awel_simple_rag_summary_example") as dag:
    # HTTP endpoint: POST /api/v1/awel/trigger/examples/rag/summary
    trigger = HttpTrigger(
        "/examples/rag/summary", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    # Pull the file path out of the request dict.
    path_operator = MapOperator(lambda request: request["file_path"])
    # build knowledge operator: loads the document into knowledge chunks
    knowledge_operator = KnowledgeOperator()
    # build summary assembler operator: summarizes the chunks via the LLM
    summary_operator = SummaryAssemblerOperator(
        llm_client=OpenAILLMClient(), language="en"
    )
    (
        trigger
        >> request_handle_task
        >> path_operator
        >> knowledge_operator
        >> summary_operator
    )
if __name__ == "__main__":
    if dag.leaf_nodes[0].dev_mode:
        # Development mode, you can run the dag locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)
    else:
        # Production mode: the DAG is loaded and triggered by the DB-GPT
        # webserver, so there is nothing to do here.
        pass

View File

@@ -1,6 +1,9 @@
import os
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -13,7 +16,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
Examples: Examples:
..code-block:: shell ..code-block:: shell
python examples/rag/db_struct_rag_example.py python examples/rag/db_schema_rag_example.py
""" """
@@ -41,28 +44,29 @@ def _create_temporary_connection():
return connect return connect
if __name__ == "__main__": def _create_vector_connector():
connection = _create_temporary_connection() """Create vector connector."""
return VectorStoreConnector.from_default(
embedding_model_path = "{your_embedding_model_path}"
vector_persist_path = "{your_persist_path}"
embedding_fn = DefaultEmbeddingFactory(
default_model_name=embedding_model_path
).create()
vector_connector = VectorStoreConnector.from_default(
"Chroma", "Chroma",
vector_store_config=ChromaVectorConfig( vector_store_config=ChromaVectorConfig(
name="vector_name", name="db_schema_vector_store_name",
persist_path=vector_persist_path, persist_path=os.path.join(PILOT_PATH, "data"),
), ),
embedding_fn=embedding_fn, embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
) )
assembler = DBStructAssembler.load_from_connection(
if __name__ == "__main__":
connection = _create_temporary_connection()
vector_connector = _create_vector_connector()
assembler = DBSchemaAssembler.load_from_connection(
connection=connection, connection=connection,
vector_store_connector=vector_connector, vector_store_connector=vector_connector,
) )
assembler.persist() assembler.persist()
# get db struct retriever # get db schema retriever
retriever = assembler.as_retriever(top_k=1) retriever = assembler.as_retriever(top_k=1)
chunks = retriever.retrieve("show columns from user") chunks = retriever.retrieve("show columns from user")
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}") print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")

View File

@@ -1,5 +1,7 @@
import asyncio import asyncio
import os
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.factory import KnowledgeFactory from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -20,21 +22,24 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
""" """
async def main(): def _create_vector_connector():
file_path = "./docs/docs/awel.md" """Create vector connector."""
vector_persist_path = "{your_persist_path}" return VectorStoreConnector.from_default(
embedding_model_path = "{your_embedding_model_path}"
knowledge = KnowledgeFactory.from_file_path(file_path)
vector_connector = VectorStoreConnector.from_default(
"Chroma", "Chroma",
vector_store_config=ChromaVectorConfig( vector_store_config=ChromaVectorConfig(
name="vector_name", name="db_schema_vector_store_name",
persist_path=vector_persist_path, persist_path=os.path.join(PILOT_PATH, "data"),
), ),
embedding_fn=DefaultEmbeddingFactory( embedding_fn=DefaultEmbeddingFactory(
default_model_name=embedding_model_path default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(), ).create(),
) )
async def main():
file_path = "docs/docs/awel.md"
knowledge = KnowledgeFactory.from_file_path(file_path)
vector_connector = _create_vector_connector()
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
# get embedding assembler # get embedding assembler
assembler = EmbeddingAssembler.load_from_knowledge( assembler = EmbeddingAssembler.load_from_knowledge(
@@ -45,7 +50,7 @@ async def main():
assembler.persist() assembler.persist()
# get embeddings retriever # get embeddings retriever
retriever = assembler.as_retriever(3) retriever = assembler.as_retriever(3)
chunks = await retriever.aretrieve_with_scores("RAG", 0.3) chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
print(f"embedding rag example results:{chunks}") print(f"embedding rag example results:{chunks}")