feat(RAG): add RAG operators and RAG AWEL examples (#1061)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Aries-ckt
2024-01-13 16:14:48 +08:00
committed by GitHub
parent 99ea6ac1a4
commit a035433170
29 changed files with 1010 additions and 102 deletions

View File

@@ -9,7 +9,7 @@ from dbgpt.util.chat_util import run_async_tasks
SUMMARY_PROMPT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
{context}
答案尽量精确和简单,不要过长长度控制在100字左右
答案尽量精确和简单,不要过长长度控制在100字左右, 注意:请用<中文>来进行总结。
"""
SUMMARY_PROMPT_TEMPLATE_EN = """
@@ -18,6 +18,13 @@ Write a quick summary of the following context:
the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
"""
REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context}\n 请根据你之前推理的内容进行最终的总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。"""
REFINE_SUMMARY_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {context}, We have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1.2.and 3.
"""
class SummaryExtractor(Extractor):
"""Summary Extractor, it can extract document summary."""
@@ -41,6 +48,11 @@ class SummaryExtractor(Extractor):
if language == "en"
else SUMMARY_PROMPT_TEMPLATE_ZH
)
self._refine_prompt_template = (
REFINE_SUMMARY_TEMPLATE_EN
if language == "en"
else REFINE_SUMMARY_TEMPLATE_ZH
)
self._concurrency_limit_with_llm = concurrency_limit_with_llm
self._max_iteration_with_llm = max_iteration_with_llm
self._concurrency_limit_with_llm = concurrency_limit_with_llm
@@ -64,15 +76,23 @@ class SummaryExtractor(Extractor):
texts = [doc.content for doc in chunks]
from dbgpt.util.prompt_util import PromptHelper
# repack chunk into prompt to adapt llm model max context window
prompt_helper = PromptHelper()
texts = prompt_helper.repack(
prompt_template=self._prompt_template, text_chunks=texts
)
if len(texts) == 1:
summary_outs = await self._llm_run_tasks(chunk_texts=texts)
summary_outs = await self._llm_run_tasks(
chunk_texts=texts, prompt_template=self._refine_prompt_template
)
return summary_outs[0]
else:
return await self._mapreduce_extract_summary(docs=texts)
map_reduce_texts = await self._mapreduce_extract_summary(docs=texts)
summary_outs = await self._llm_run_tasks(
chunk_texts=[map_reduce_texts],
prompt_template=self._refine_prompt_template,
)
return summary_outs[0]
def _extract(self, chunks: List[Chunk]) -> str:
"""document extract summary
@@ -98,7 +118,8 @@ class SummaryExtractor(Extractor):
return docs[0]
else:
summary_outs = await self._llm_run_tasks(
chunk_texts=docs[0 : self._max_iteration_with_llm]
chunk_texts=docs[0 : self._max_iteration_with_llm],
prompt_template=self._prompt_template,
)
from dbgpt.util.prompt_util import PromptHelper
@@ -108,10 +129,13 @@ class SummaryExtractor(Extractor):
)
return await self._mapreduce_extract_summary(docs=summary_outs)
async def _llm_run_tasks(self, chunk_texts: List[str]) -> List[str]:
async def _llm_run_tasks(
self, chunk_texts: List[str], prompt_template: str
) -> List[str]:
"""llm run tasks
Args:
chunk_texts: List[str]
prompt_template: str
Returns:
summary_outs: List[str]
"""
@@ -119,7 +143,7 @@ class SummaryExtractor(Extractor):
for chunk_text in chunk_texts:
from dbgpt.core import ModelMessage
prompt = self._prompt_template.format(context=chunk_text)
prompt = prompt_template.format(context=chunk_text)
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
request = ModelRequest(model=self._model_name, messages=messages)
tasks.append(self._llm_client.generate(request))

View File

@@ -0,0 +1,37 @@
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
    """Retrieve database table schemas for a query via ``DBSchemaRetriever``.

    Args:
        top_k (int, optional): Maximum number of schemas to return. Defaults to 4.
        connection (Optional[RDBMSDatabase]): Database connection to read schemas from.
        vector_store_connector (Optional[VectorStoreConnector]): Vector store used
            for similarity search over schema embeddings. Defaults to None.
    """

    def __init__(
        self,
        top_k: int = 4,
        connection: Optional[RDBMSDatabase] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # All retrieval work is delegated to a dedicated DBSchemaRetriever.
        retriever = DBSchemaRetriever(
            top_k=top_k,
            connection=connection,
            vector_store_connector=vector_store_connector,
        )
        self._retriever = retriever

    def retrieve(self, query: IN) -> Any:
        """Return the table schemas matching *query*.

        Args:
            query (IN): The query text.
        """
        result = self._retriever.retrieve(query)
        return result

View File

@@ -0,0 +1,39 @@
from functools import reduce
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
    """Retrieve scored similar chunks from a vector store via ``EmbeddingRetriever``.

    Args:
        top_k (int): Number of candidates to retrieve per query.
        score_threshold (Optional[float]): Minimum similarity score. Defaults to 0.3.
        query_rewrite (Optional[QueryRewrite]): Optional query rewriter.
        rerank (Optional[Ranker]): Optional reranker for the candidates.
        vector_store_connector (Optional[VectorStoreConnector]): Vector store connector.
    """

    def __init__(
        self,
        top_k: int,
        score_threshold: Optional[float] = 0.3,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self._score_threshold = score_threshold
        self._retriever = EmbeddingRetriever(
            top_k=top_k,
            query_rewrite=query_rewrite,
            rerank=rerank,
            vector_store_connector=vector_store_connector,
        )

    def retrieve(self, query: IN) -> Any:
        """Retrieve candidates with scores for one query or a list of queries.

        Args:
            query (IN): A query string, or a list of query strings.

        Returns:
            The scored candidates; results for a list of queries are concatenated.

        Raises:
            TypeError: If *query* is neither a str nor a list.
        """
        if isinstance(query, str):
            return self._retriever.retrieve_with_scores(query, self._score_threshold)
        elif isinstance(query, list):
            candidates = [
                self._retriever.retrieve_with_scores(q, self._score_threshold)
                for q in query
            ]
            # Initializer [] makes an empty query list yield [] instead of
            # reduce() raising "reduce() of empty iterable with no initial value".
            return reduce(lambda x, y: x + y, candidates, [])
        # Previously this fell through and silently returned None.
        raise TypeError(f"Unsupported query type: {type(query)!r}")

View File

@@ -0,0 +1,26 @@
from typing import Any, List, Optional
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.knowledge.base import KnowledgeType, Knowledge
from dbgpt.rag.knowledge.factory import KnowledgeFactory
class KnowledgeOperator(MapOperator[Any, Any]):
    """Operator that loads a datasource into a ``Knowledge`` object."""

    def __init__(
        self, knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, **kwargs
    ):
        """Initialize the knowledge operator.

        Args:
            knowledge_type: (Optional[KnowledgeType]) The knowledge type to create.
        """
        super().__init__(**kwargs)
        self._knowledge_type = knowledge_type

    async def map(self, datasource: IN) -> Knowledge:
        """Create knowledge from *datasource* without blocking the event loop."""
        factory_fn = KnowledgeFactory.create
        return await self.blocking_func_to_async(
            factory_fn, datasource, self._knowledge_type
        )

View File

@@ -0,0 +1,43 @@
from typing import Any, Optional, List
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.rerank import DefaultRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite
class RerankOperator(MapOperator[Any, Any]):
    """Operator that reranks retrieval candidates (fixes the copy-pasted
    "Rewrite Operator" label from the original docstring)."""

    def __init__(
        self,
        topk: Optional[int] = 3,
        algorithm: Optional[str] = "default",
        rank_fn: Optional[callable] = None,
        **kwargs
    ):
        """Initialize the rerank operator.

        Args:
            topk (int): The number of candidates to keep.
            algorithm (Optional[str]): The rerank algorithm name.
            rank_fn (Optional[callable]): The rank function, if any.
        """
        super().__init__(**kwargs)
        # NOTE(review): _algorithm is stored but never consulted below —
        # DefaultRanker is always used regardless of its value; confirm intent.
        self._algorithm = algorithm
        self._rerank = DefaultRanker(topk=topk, rank_fn=rank_fn)

    async def map(self, candidates_with_scores: IN) -> List[Chunk]:
        """Rerank the candidates off the event loop.

        Args:
            candidates_with_scores (IN): The candidates with scores.

        Returns:
            List[Chunk]: The reranked candidates.
        """
        rank_call = self._rerank.rank
        return await self.blocking_func_to_async(rank_call, candidates_with_scores)

View File

@@ -0,0 +1,41 @@
from typing import Any, Optional, List
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.retriever.rewrite import QueryRewrite
class QueryRewriteOperator(MapOperator[Any, Any]):
    """Operator that expands a user query into rewritten search queries."""

    def __init__(
        self,
        llm_client: Optional[LLMClient],
        model_name: Optional[str] = None,
        language: Optional[str] = "en",
        nums: Optional[int] = 1,
        **kwargs
    ):
        """Initialize the query rewrite operator.

        Args:
            llm_client (Optional[LLMClient]): The LLM client used for rewriting.
            model_name (Optional[str]): The model name.
            language (Optional[str]): The prompt language.
            nums (Optional[int]): The number of rewrite results to produce.
        """
        super().__init__(**kwargs)
        self._nums = nums
        self._rewrite = QueryRewrite(
            llm_client=llm_client, model_name=model_name, language=language
        )

    async def map(self, query_context: IN) -> List[str]:
        """Rewrite the query held in *query_context*.

        Args:
            query_context (IN): Mapping carrying "query" and "context" entries.
        """
        original_query = query_context.get("query")
        retrieved_context = query_context.get("context")
        rewritten = await self._rewrite.rewrite(
            origin_query=original_query, context=retrieved_context, nums=self._nums
        )
        return rewritten

View File

@@ -0,0 +1,49 @@
from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel.task.base import IN
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):
    """Operator that generates a document summary via ``SummaryAssembler``."""

    def __init__(
        self,
        llm_client: Optional[LLMClient],
        model_name: Optional[str] = "gpt-3.5-turbo",
        language: Optional[str] = "en",
        max_iteration_with_llm: Optional[int] = 5,
        concurrency_limit_with_llm: Optional[int] = 3,
        **kwargs
    ):
        """
        Init the summary assemble operator.
        Args:
            llm_client: (Optional[LLMClient]) The LLM client.
            model_name: (Optional[str]) The model name.
            language: (Optional[str]) The prompt language.
            max_iteration_with_llm: (Optional[int]) The max iteration with llm.
            concurrency_limit_with_llm: (Optional[int]) The concurrency limit with llm.
        """
        super().__init__(**kwargs)
        self._llm_client = llm_client
        self._model_name = model_name
        self._language = language
        self._max_iteration_with_llm = max_iteration_with_llm
        self._concurrency_limit_with_llm = concurrency_limit_with_llm

    async def map(self, knowledge: IN) -> Any:
        """Assemble the summary.

        Overrides the base ``AssemblerOperator.map`` (which would dispatch to
        ``assemble``) because summary generation is itself asynchronous.
        """
        # A fresh assembler is built per call so each input gets its own state.
        assembler = SummaryAssembler.load_from_knowledge(
            knowledge=knowledge,
            llm_client=self._llm_client,
            model_name=self._model_name,
            language=self._language,
            max_iteration_with_llm=self._max_iteration_with_llm,
            concurrency_limit_with_llm=self._concurrency_limit_with_llm,
        )
        return await assembler.generate_summary()

    def assemble(self, knowledge: IN) -> Any:
        """assemble knowledge for input value.

        Intentionally a no-op: the async ``map`` above bypasses the base
        class's assemble dispatch. NOTE(review): calling this directly
        returns None — confirm that is never relied upon.
        """
        pass

View File

@@ -1,6 +1,7 @@
from functools import reduce
from typing import List, Optional
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
@@ -9,14 +10,13 @@ from dbgpt.rag.retriever.rerank import Ranker, DefaultRanker
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBStructRetriever(BaseRetriever):
"""DBStruct retriever."""
class DBSchemaRetriever(BaseRetriever):
"""DBSchema retriever."""
def __init__(
self,
top_k: int = 4,
connection: Optional[RDBMSDatabase] = None,
is_embeddings: bool = True,
query_rewrite: bool = False,
rerank: Ranker = None,
vector_store_connector: Optional[VectorStoreConnector] = None,
@@ -26,14 +26,13 @@ class DBStructRetriever(BaseRetriever):
Args:
top_k (int): top k
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
is_embeddings (bool): Whether to query by embeddings in the vector store, Defaults to True.
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
vector_store_connector (VectorStoreConnector): vector store connector
code example:
.. code-block:: python
>>> from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
>>> from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
>>> from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
>>> from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
>>> from dbgpt.rag.retriever.embedding import EmbeddingRetriever
@@ -71,16 +70,18 @@ class DBStructRetriever(BaseRetriever):
embedding_fn=embedding_fn
)
# get db struct retriever
retriever = DBStructRetriever(top_k=3, vector_store_connector=vector_connector)
retriever = DBSchemaRetriever(top_k=3, vector_store_connector=vector_connector)
chunks = retriever.retrieve("show columns from table")
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
"""
self._top_k = top_k
self._is_embeddings = is_embeddings
self._connection = connection
self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector
self._need_embeddings = False
if self._vector_store_connector:
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
def _retrieve(self, query: str) -> List[Chunk]:
@@ -88,7 +89,7 @@ class DBStructRetriever(BaseRetriever):
Args:
query (str): query text
"""
if self._is_embeddings:
if self._need_embeddings:
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k)
@@ -97,8 +98,6 @@ class DBStructRetriever(BaseRetriever):
candidates = reduce(lambda x, y: x + y, candidates)
return candidates
else:
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
table_summaries = _parse_db_summary(self._connection)
return [Chunk(content=table_summary) for table_summary in table_summaries]
@@ -115,7 +114,7 @@ class DBStructRetriever(BaseRetriever):
Args:
query (str): query text
"""
if self._is_embeddings:
if self._need_embeddings:
queries = [query]
candidates = [self._similarity_search(query) for query in queries]
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
@@ -145,7 +144,7 @@ class DBStructRetriever(BaseRetriever):
self._top_k,
)
async def _aparse_db_summary(self) -> List[Chunk]:
async def _aparse_db_summary(self) -> List[str]:
"""Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary

View File

@@ -99,7 +99,12 @@ class EmbeddingRetriever(BaseRetriever):
"""
queries = [query]
if self._query_rewrite:
new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1)
candidates_tasks = [self._similarity_search(query) for query in queries]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries)
candidates = [self._similarity_search(query) for query in queries]
candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
@@ -117,7 +122,12 @@ class EmbeddingRetriever(BaseRetriever):
"""
queries = [query]
if self._query_rewrite:
new_queries = await self._query_rewrite.rewrite(origin_query=query, nums=1)
candidates_tasks = [self._similarity_search(query) for query in queries]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries)
candidates_with_score = [
self._similarity_search_with_score(query, score_threshold)
@@ -137,6 +147,12 @@ class EmbeddingRetriever(BaseRetriever):
self._top_k,
)
async def _run_async_tasks(self, tasks) -> List[Chunk]:
"""Run async tasks."""
candidates = await run_async_tasks(tasks=tasks, concurrency_limit=1)
candidates = reduce(lambda x, y: x + y, candidates)
return candidates
async def _similarity_search_with_score(
self, query, score_threshold
) -> List[Chunk]:

View File

@@ -2,14 +2,14 @@ from typing import List, Optional
from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
REWRITE_PROMPT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'\n":
"original query:: {original_query}\n"
"queries:\n"
Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'":
"original query:{original_query}\n"
"queries:"
"""
REWRITE_PROMPT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries<queries>'
"original_query{original_query}\n"
"queries\n"
REWRITE_PROMPT_TEMPLATE_ZH = """请根据上下文{context}, 将原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>'
"original_query:{original_query}\n"
"queries:"
"""
@@ -29,6 +29,7 @@ class QueryRewrite:
- query: (str), user query
- model_name: (str), llm model name
- llm_client: (Optional[LLMClient])
- language: (Optional[str]), language
"""
self._model_name = model_name
self._llm_client = llm_client
@@ -39,17 +40,22 @@ class QueryRewrite:
else REWRITE_PROMPT_TEMPLATE_ZH
)
async def rewrite(self, origin_query: str, nums: Optional[int] = 1) -> List[str]:
async def rewrite(
self, origin_query: str, context: Optional[str], nums: Optional[int] = 1
) -> List[str]:
"""query rewrite
Args:
origin_query: str original query
context: Optional[str] context
nums: Optional[int] rewrite nums
Returns:
queries: List[str]
"""
from dbgpt.util.chat_util import run_async_tasks
prompt = self._prompt_template.format(original_query=origin_query, nums=nums)
prompt = self._prompt_template.format(
context=context, original_query=origin_query, nums=nums
)
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
request = ModelRequest(model=self._model_name, messages=messages)
tasks = [self._llm_client.generate(request)]
@@ -61,8 +67,12 @@ class QueryRewrite:
queries,
)
)
print("rewrite queries:", queries)
return self._parse_llm_output(output=queries[0])
if len(queries) == 0:
print("llm generate no rewrite queries.")
return queries
new_queries = self._parse_llm_output(output=queries[0])[0:nums]
print(f"rewrite queries: {new_queries}")
return new_queries
def correct(self) -> List[str]:
pass
@@ -81,6 +91,8 @@ class QueryRewrite:
if response.startswith("queries:"):
response = response[len("queries:") :]
if response.startswith("queries"):
response = response[len("queries") :]
queries = response.split(",")
if len(queries) == 1:
@@ -90,6 +102,10 @@ class QueryRewrite:
if len(queries) == 1:
queries = response.split("")
for k in queries:
if k.startswith("queries:"):
k = k[len("queries:") :]
if k.startswith("queries"):
k = response[len("queries") :]
rk = k
if lowercase:
rk = rk.lower()

View File

@@ -4,7 +4,7 @@ from typing import List
import dbgpt
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.db_struct import DBStructRetriever
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@@ -22,7 +22,7 @@ def mock_vector_store_connector():
@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
return DBStructRetriever(
return DBSchemaRetriever(
connection=mock_db_connection,
vector_store_connector=mock_vector_store_connector,
)

View File

@@ -53,9 +53,9 @@ class DBSummaryClient:
embedding_fn=self.embeddings,
vector_store_config=vector_store_config,
)
from dbgpt.rag.retriever.db_struct import DBStructRetriever
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
retriever = DBStructRetriever(
retriever = DBSchemaRetriever(
top_k=topk, vector_store_connector=vector_connector
)
table_docs = retriever.retrieve(query)
@@ -92,9 +92,9 @@ class DBSummaryClient:
vector_store_config=vector_store_config,
)
if not vector_connector.vector_name_exists():
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
db_assembler = DBStructAssembler.load_from_connection(
db_assembler = DBSchemaAssembler.load_from_connection(
connection=db_summary_client.db, vector_store_connector=vector_connector
)
if len(db_assembler.get_chunks()) > 0:

View File

@@ -1,7 +1,7 @@
"""Token splitter."""
from typing import Callable, List, Optional
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
from pydantic import BaseModel, Field, PrivateAttr
from dbgpt.util.global_helper import globals_helper
from dbgpt.util.splitter_utils import split_by_sep, split_by_char

View File

@@ -7,24 +7,24 @@ from dbgpt.rag.chunk_manager import ChunkParameters, ChunkManager
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge, ChunkStrategy
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.retriever.db_struct import DBStructRetriever
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.assembler.base import BaseAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBStructAssembler(BaseAssembler):
"""DBStructAssembler
class DBSchemaAssembler(BaseAssembler):
"""DBSchemaAssembler
Example:
.. code-block:: python
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
connection = SQLiteTempConnect.create_temporary_db()
assembler = DBStructAssembler.load_from_connection(
assembler = DBSchemaAssembler.load_from_connection(
connection=connection,
embedding_model=embedding_model_path,
)
@@ -53,18 +53,21 @@ class DBStructAssembler(BaseAssembler):
"""
if connection is None:
raise ValueError("datasource connection must be provided.")
self._connection = connection
self._vector_store_connector = vector_store_connector
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=os.getenv("EMBEDDING_MODEL")
)
self._connection = connection
if embedding_model:
embedding_fn = embedding_factory.create(model_name=embedding_model)
self._vector_store_connector = (
vector_store_connector
or VectorStoreConnector.from_default(embedding_fn=embedding_fn)
)
self._embedding_model = embedding_model
if self._embedding_model:
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=self._embedding_model
)
self.embedding_fn = embedding_factory.create(self._embedding_model)
if self._vector_store_connector.vector_store_config.embedding_fn is None:
self._vector_store_connector.vector_store_config.embedding_fn = (
self.embedding_fn
)
super().__init__(
chunk_parameters=chunk_parameters,
**kwargs,
@@ -79,7 +82,7 @@ class DBStructAssembler(BaseAssembler):
embedding_model: Optional[str] = None,
embedding_factory: Optional[EmbeddingFactory] = None,
vector_store_connector: Optional[VectorStoreConnector] = None,
) -> "DBStructAssembler":
) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path.
Args:
connection: (RDBMSDatabase) RDBMSDatabase connection.
@@ -89,13 +92,9 @@ class DBStructAssembler(BaseAssembler):
embedding_factory: (Optional[EmbeddingFactory]) EmbeddingFactory to use.
vector_store_connector: (Optional[VectorStoreConnector]) VectorStoreConnector to use.
Returns:
DBStructAssembler
DBSchemaAssembler
"""
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=embedding_model or os.getenv("EMBEDDING_MODEL_PATH")
)
embedding_factory = embedding_factory
chunk_parameters = chunk_parameters or ChunkParameters(
chunk_strategy=ChunkStrategy.CHUNK_BY_SIZE.name, chunk_overlap=0
)
@@ -136,14 +135,14 @@ class DBStructAssembler(BaseAssembler):
def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
def as_retriever(self, top_k: Optional[int] = 4) -> DBStructRetriever:
def as_retriever(self, top_k: Optional[int] = 4) -> DBSchemaRetriever:
"""
Args:
top_k:(Optional[int]), default 4
Returns:
DBStructRetriever
DBSchemaRetriever
"""
return DBStructRetriever(
return DBSchemaRetriever(
top_k=top_k,
connection=self._connection,
is_embeddings=True,

View File

@@ -46,17 +46,19 @@ class EmbeddingAssembler(BaseAssembler):
"""
if knowledge is None:
raise ValueError("knowledge datasource must be provided.")
self._vector_store_connector = vector_store_connector
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=os.getenv("EMBEDDING_MODEL")
)
if embedding_model:
embedding_fn = embedding_factory.create(model_name=embedding_model)
self._vector_store_connector = (
vector_store_connector
or VectorStoreConnector.from_default(embedding_fn=embedding_fn)
)
self._embedding_model = embedding_model
if self._embedding_model:
embedding_factory = embedding_factory or DefaultEmbeddingFactory(
default_model_name=self._embedding_model
)
self.embedding_fn = embedding_factory.create(self._embedding_model)
if self._vector_store_connector.vector_store_config.embedding_fn is None:
self._vector_store_connector.vector_store_config.embedding_fn = (
self.embedding_fn
)
super().__init__(
knowledge=knowledge,

View File

@@ -57,9 +57,8 @@ class SummaryAssembler(BaseAssembler):
from dbgpt.rag.extractor.summary import SummaryExtractor
self._extractor = extractor or SummaryExtractor(
llm_client=self._llm_client, model_name=self._model_name
llm_client=self._llm_client, model_name=self._model_name, language=language
)
self._language = language
super().__init__(
knowledge=knowledge,
chunk_parameters=chunk_parameters,

View File

@@ -7,7 +7,7 @@ from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -66,7 +66,7 @@ def test_load_knowledge(
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = DBStructAssembler(
assembler = DBSchemaAssembler(
connection=mock_db_connection,
chunk_parameters=mock_chunk_parameters,
embedding_factory=mock_embedding_factory,

View File

View File

@@ -0,0 +1,23 @@
from abc import abstractmethod
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
class AssemblerOperator(MapOperator[IN, OUT]):
    """Base operator for assemblers: runs the blocking ``assemble`` hook
    asynchronously via the operator framework."""

    async def map(self, input_value: IN) -> OUT:
        """Map input value to output value by delegating to :meth:`assemble`.

        Args:
            input_value (IN): The input value.

        Returns:
            OUT: The output value produced by ``assemble``.
        """
        assemble_fn = self.assemble
        return await self.blocking_func_to_async(assemble_fn, input_value)

    @abstractmethod
    def assemble(self, input_value: IN) -> OUT:
        """Assemble knowledge for the given input value."""

View File

@@ -0,0 +1,36 @@
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class DBSchemaAssemblerOperator(AssemblerOperator[Any, Any]):
    """The DBSchema Assembler Operator.

    Builds a ``DBSchemaAssembler`` from a database connection at construction
    time; ``assemble`` then persists (if a vector store was given) and
    returns the schema chunks.

    Args:
        connection (RDBMSDatabase): The connection.
        vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
    """

    def __init__(
        self,
        connection: Optional[RDBMSDatabase] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        self._connection = connection
        self._vector_store_connector = vector_store_connector
        # The assembler is created eagerly; load_from_connection reads the
        # schema here, before the operator runs.
        self._assembler = DBSchemaAssembler.load_from_connection(
            connection=self._connection,
            vector_store_connector=self._vector_store_connector,
        )
        super().__init__(**kwargs)

    def assemble(self, input_value: IN) -> Any:
        """assemble knowledge for input value.

        Note: *input_value* is ignored; the chunks come from the connection
        captured in ``__init__``.
        """
        if self._vector_store_connector:
            self._assembler.persist()
        return self._assembler.get_chunks()

View File

@@ -0,0 +1,44 @@
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingAssemblerOperator(AssemblerOperator[Any, Any]):
    """The Embedding Assembler Operator.

    Chunks the incoming knowledge, embeds the chunks, persists them in the
    vector store, and returns the chunks.

    Args:
        knowledge (Knowledge): The knowledge (supplied as the operator input).
        chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. Defaults to None.
        vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
    """

    def __init__(
        self,
        chunk_parameters: Optional[ChunkParameters] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        """
        Args:
            chunk_parameters (Optional[ChunkParameters], optional): The chunk
                parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
            vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None.
        """
        # Build the default per instance: the original default argument
        # `= ChunkParameters(...)` was evaluated once at definition time and
        # shared (mutable) across every operator instance.
        if chunk_parameters is None:
            chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
        self._chunk_parameters = chunk_parameters
        self._vector_store_connector = vector_store_connector
        super().__init__(**kwargs)

    def assemble(self, knowledge: IN) -> Any:
        """Assemble, persist, and return the chunks for *knowledge*."""
        assembler = EmbeddingAssembler.load_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=self._chunk_parameters,
            vector_store_connector=self._vector_store_connector,
        )
        assembler.persist()
        return assembler.get_chunks()

View File

@@ -92,6 +92,11 @@ class VectorStoreConnector:
"""
return self.client.similar_search_with_scores(doc, topk, score_threshold)
@property
def vector_store_config(self) -> VectorStoreConfig:
"""vector store config."""
return self._vector_store_config
def vector_name_exists(self):
"""is vector store name exist."""
return self.client.vector_name_exists()

View File

@@ -0,0 +1,130 @@
import os
from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.operator.db_schema import DBSchemaRetrieverOperator
from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""AWEL: Simple rag db schema embedding operator example
if you not set vector_store_connector, it will return all tables schema in database.
```
retriever_task = DBSchemaRetrieverOperator(
connection=_create_temporary_connection()
)
```
if you set vector_store_connector, it will recall topk similarity tables schema in database.
```
retriever_task = DBSchemaRetrieverOperator(
connection=_create_temporary_connection()
top_k=1,
vector_store_connector=vector_store_connector
)
```
Examples:
..code-block:: shell
curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/dbschema' \
--header 'Content-Type: application/json' \
--data '{"query": "what is user name?"}'
"""
def _create_vector_connector():
    """Build a Chroma-backed vector store connector.

    Persists under ``PILOT_PATH/data`` and embeds with the local
    ``text2vec-large-chinese`` model under ``MODEL_PATH``.
    """
    store_config = ChromaVectorConfig(
        name="vector_name",
        persist_path=os.path.join(PILOT_PATH, "data"),
    )
    embedding_fn = DefaultEmbeddingFactory(
        default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    ).create()
    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=store_config,
        embedding_fn=embedding_fn,
    )
def _create_temporary_connection():
    """Create a temporary SQLite connection pre-loaded with a sample table."""
    sample_rows = [
        (1, "Tom", 10),
        (2, "Jerry", 16),
        (3, "Jack", 18),
        (4, "Alice", 20),
        (5, "Bob", 22),
    ]
    user_table = {
        "columns": {
            "id": "INTEGER PRIMARY KEY",
            "name": "TEXT",
            "age": "INTEGER",
        },
        "data": sample_rows,
    }
    connect = SQLiteTempConnect.create_temporary_db()
    connect.create_temp_tables({"user": user_table})
    return connect
def _join_fn(chunks: List[Chunk], query: str) -> str:
    """Print the retrieved schema chunk contents and forward the query.

    Args:
        chunks (List[Chunk]): schema chunks produced by the assembler branch.
        query (str): user query, returned unchanged for the retriever branch.

    Returns:
        str: the original query.
    """
    schema_info = [chunk.content for chunk in chunks]
    print(f"db schema info is {schema_info}")
    return query
class TriggerReqBody(BaseModel):
    """Request body schema for the HTTP trigger."""

    query: str = Field(..., description="User query")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Turn the parsed HTTP request body into a plain parameter dict."""

    def __init__(self, **kwargs):
        """Pass construction arguments straight to the base operator."""
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract the user query from the request body.

        Args:
            input_value (TriggerReqBody): parsed HTTP request body.

        Returns:
            Dict: parameters consumed by downstream operators.
        """
        print(f"Receive input value: {input_value}")
        return {"query": input_value.query}
# DAG wiring: one HTTP trigger feeds both the schema-assembler branch and the
# query branch; the join forwards the query once assembly is done, then the
# retriever returns the most similar table schemas.
with DAG("simple_rag_db_schema_example") as dag:
    trigger = HttpTrigger(
        "/examples/rag/dbschema", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    # Pull the raw query string out of the request parameters.
    query_operator = MapOperator(lambda request: request["query"])
    vector_store_connector = _create_vector_connector()
    # Embed the temporary database's table schemas into the vector store.
    assembler_task = DBSchemaAssemblerOperator(
        connection=_create_temporary_connection(),
        vector_store_connector=vector_store_connector,
    )
    # _join_fn prints the assembled chunks and forwards the query downstream.
    join_operator = JoinOperator(combine_function=_join_fn)
    # Recall the single most similar table schema for the query (top_k=1).
    retriever_task = DBSchemaRetrieverOperator(
        connection=_create_temporary_connection(),
        top_k=1,
        vector_store_connector=vector_store_connector,
    )
    result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
    trigger >> request_handle_task >> assembler_task >> join_operator
    trigger >> request_handle_task >> query_operator >> join_operator
    join_operator >> retriever_task >> result_parse_task
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag], port=5555)
else:
pass

View File

@@ -0,0 +1,86 @@
import asyncio
import os
from typing import Dict, List
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.operator.knowledge import KnowledgeOperator
from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""AWEL: Simple rag embedding operator example
pre-requirements:
set your file path in your example code.
Examples:
..code-block:: shell
python examples/awel/simple_rag_embedding_example.py
"""
def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
    """Join chunk contents into the ``context`` key of the given dict.

    Args:
        context_dict (Dict): dict to receive the joined context.
        chunks (List[Chunk]): retrieved chunks whose contents are joined.

    Returns:
        Dict: the same dict, with ``context`` set.
    """
    contents = (chunk.content for chunk in chunks)
    context_dict["context"] = "\n".join(contents)
    return context_dict
def _create_vector_connector():
    """Build a Chroma-backed vector store connector.

    Persists under ``PILOT_PATH/data`` and embeds with the local
    ``text2vec-large-chinese`` model under ``MODEL_PATH``.
    """
    store_config = ChromaVectorConfig(
        name="vector_name",
        persist_path=os.path.join(PILOT_PATH, "data"),
    )
    embedding_fn = DefaultEmbeddingFactory(
        default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    ).create()
    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=store_config,
        embedding_fn=embedding_fn,
    )
class ResultOperator(MapOperator):
    """Terminal operator that reports how many chunks were embedded."""

    def __init__(self, **kwargs):
        """Pass construction arguments straight to the base operator."""
        super().__init__(**kwargs)

    async def map(self, chunks: List) -> str:
        """Print and return a summary message for the embedded chunks."""
        message = f"embedding success, there are {len(chunks)} chunks."
        print(message)
        return message
# DAG wiring: call data -> file path -> knowledge -> embedding -> result message.
with DAG("simple_sdk_rag_embedding_example") as dag:
    # presumably loads the file at the given path as knowledge — see
    # KnowledgeOperator for the exact loading behavior.
    knowledge_operator = KnowledgeOperator()
    vector_connector = _create_vector_connector()
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    # Pull the file path out of the call-data dict.
    file_path_parser = MapOperator(map_function=lambda x: x["file_path"])
    # Assemble (chunk + embed) the knowledge into the vector store.
    embedding_operator = EmbeddingAssemblerOperator(
        vector_store_connector=vector_connector,
    )
    output_task = ResultOperator()
    (
        input_task
        >> file_path_parser
        >> knowledge_operator
        >> embedding_operator
        >> output_task
    )
if __name__ == "__main__":
input_data = {
"data": {
"file_path": "docs/docs/awel.md",
}
}
output = asyncio.run(output_task.call(call_data=input_data))

View File

@@ -0,0 +1,127 @@
import asyncio
import os
from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.operator.embedding import EmbeddingRetrieverOperator
from dbgpt.rag.operator.rerank import RerankOperator
from dbgpt.rag.operator.rewrite import QueryRewriteOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""AWEL: Simple rag embedding operator example
pre-requirements:
1. install openai python sdk
```
pip install openai
```
2. set openai key and base
```
export OPENAI_API_KEY={your_openai_key}
export OPENAI_API_BASE={your_openai_base}
```
3. make sure you have data in your vector store.
if the vector store is empty, please run examples/awel/simple_rag_embedding_example.py first.
ensure your embedding model is under DB-GPT/models/.
Examples:
..code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5555"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/retrieve \
-H "Content-Type: application/json" -d '{
"query": "what is awel talk about?"
}'
"""
class TriggerReqBody(BaseModel):
    """Request body schema for the HTTP trigger."""

    query: str = Field(..., description="User query")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Turn the parsed HTTP request body into a plain parameter dict."""

    def __init__(self, **kwargs):
        """Pass construction arguments straight to the base operator."""
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract the user query from the request body.

        Args:
            input_value (TriggerReqBody): parsed HTTP request body.

        Returns:
            Dict: parameters consumed by downstream operators.
        """
        print(f"Receive input value: {input_value}")
        return {"query": input_value.query}
def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
    """Join chunk contents into the ``context`` key of the given dict.

    Args:
        context_dict (Dict): dict to receive the joined context.
        chunks (List[Chunk]): retrieved chunks whose contents are joined.

    Returns:
        Dict: the same dict, with ``context`` set.
    """
    contents = (chunk.content for chunk in chunks)
    context_dict["context"] = "\n".join(contents)
    return context_dict
def _create_vector_connector():
    """Build a Chroma-backed vector store connector.

    Persists under ``PILOT_PATH/data`` and embeds with the local
    ``text2vec-large-chinese`` model under ``MODEL_PATH``.
    """
    store_config = ChromaVectorConfig(
        name="vector_name",
        persist_path=os.path.join(PILOT_PATH, "data"),
    )
    embedding_fn = DefaultEmbeddingFactory(
        default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    ).create()
    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=store_config,
        embedding_fn=embedding_fn,
    )
# DAG wiring: retrieve context for the raw query, join it with the request
# params, rewrite the query with the LLM, retrieve again, then rerank.
with DAG("simple_sdk_rag_retriever_example") as dag:
    vector_connector = _create_vector_connector()
    trigger = HttpTrigger(
        "/examples/rag/retrieve", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    # Pull the raw query string out of the request parameters.
    query_parser = MapOperator(map_function=lambda x: x["query"])
    # _context_join_fn merges retrieved chunk contents into the params dict.
    context_join_operator = JoinOperator(combine_function=_context_join_fn)
    rewrite_operator = QueryRewriteOperator(llm_client=OpenAILLMClient())
    # First retrieval pass: fetch context for the original query.
    retriever_context_operator = EmbeddingRetrieverOperator(
        top_k=3,
        vector_store_connector=vector_connector,
    )
    # Second retrieval pass: fed by the rewrite operator's output below.
    retriever_operator = EmbeddingRetrieverOperator(
        top_k=3,
        vector_store_connector=vector_connector,
    )
    rerank_operator = RerankOperator()
    # NOTE(review): model_parse_task is defined but never wired into the DAG
    # below — confirm whether it was meant to follow rerank_operator.
    model_parse_task = MapOperator(lambda out: out.to_dict())
    trigger >> request_handle_task >> context_join_operator
    (
        trigger
        >> request_handle_task
        >> query_parser
        >> retriever_context_operator
        >> context_join_operator
    )
    context_join_operator >> rewrite_operator >> retriever_operator >> rerank_operator
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag], port=5555)
else:
pass

View File

@@ -0,0 +1,74 @@
"""AWEL: Simple rag rewrite example
pre-requirements:
1. install openai python sdk
```
pip install openai
```
2. set openai key and base
```
export OPENAI_API_KEY={your_openai_key}
export OPENAI_API_BASE={your_openai_base}
```
or
```
import os
os.environ["OPENAI_API_KEY"] = {your_openai_key}
os.environ["OPENAI_API_BASE"] = {your_openai_base}
```
python examples/awel/simple_rag_rewrite_example.py
Example:
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5555"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/rewrite \
-H "Content-Type: application/json" -d '{
"query": "compare curry and james",
"context":"steve curry and lebron james are nba all-stars"
}'
"""
from typing import Dict
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.rag.operator.rewrite import QueryRewriteOperator
class TriggerReqBody(BaseModel):
    """Request body schema for the HTTP trigger."""

    query: str = Field(..., description="User query")
    context: str = Field(..., description="context")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Turn the parsed HTTP request body into a plain parameter dict."""

    def __init__(self, **kwargs):
        """Pass construction arguments straight to the base operator."""
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract the user query and context from the request body.

        Args:
            input_value (TriggerReqBody): parsed HTTP request body.

        Returns:
            Dict: parameters consumed by downstream operators.
        """
        print(f"Receive input value: {input_value}")
        return {
            "query": input_value.query,
            "context": input_value.context,
        }
# DAG wiring: HTTP trigger -> request params -> LLM-backed query rewrite.
with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag:
    trigger = HttpTrigger(
        "/examples/rag/rewrite", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    # build query rewrite operator
    # nums=2 — presumably the number of rewritten queries to generate;
    # confirm against QueryRewriteOperator's documentation.
    rewrite_task = QueryRewriteOperator(llm_client=OpenAILLMClient(), nums=2)
    trigger >> request_handle_task >> rewrite_task
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag], port=5555)
else:
pass

View File

@@ -0,0 +1,84 @@
"""AWEL:
This example shows how to use AWEL to build a simple rag summary example.
pre-requirements:
1. install openai python sdk
```
pip install openai
```
2. set openai key and base
```
export OPENAI_API_KEY={your_openai_key}
export OPENAI_API_BASE={your_openai_base}
```
or
```
import os
os.environ["OPENAI_API_KEY"] = {your_openai_key}
os.environ["OPENAI_API_BASE"] = {your_openai_base}
```
python examples/awel/simple_rag_summary_example.py
Example:
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5555"
FILE_PATH="{your_file_path}"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/summary \
-H "Content-Type: application/json" -d "{
\"file_path\": \"$FILE_PATH\"
}"
"""
from typing import Dict
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.rag.operator.knowledge import KnowledgeOperator
from dbgpt.rag.operator.summary import SummaryAssemblerOperator
class TriggerReqBody(BaseModel):
    """Request body schema for the HTTP trigger."""

    file_path: str = Field(..., description="file_path")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Turn the parsed HTTP request body into a plain parameter dict."""

    def __init__(self, **kwargs):
        """Pass construction arguments straight to the base operator."""
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract the file path from the request body.

        Args:
            input_value (TriggerReqBody): parsed HTTP request body.

        Returns:
            Dict: parameters consumed by downstream operators.
        """
        print(f"Receive input value: {input_value}")
        return {"file_path": input_value.file_path}
# DAG wiring: HTTP trigger -> file path -> knowledge loading -> LLM summary.
# Fix: the DAG id was copy-pasted from the rewrite example
# ("dbgpt_awel_simple_rag_rewrite_example"); this is the summary example, and
# a duplicate DAG id would collide with the rewrite example's DAG when both
# examples are registered.
with DAG("dbgpt_awel_simple_rag_summary_example") as dag:
    trigger = HttpTrigger(
        "/examples/rag/summary", methods="POST", request_body=TriggerReqBody
    )
    request_handle_task = RequestHandleOperator()
    # Pull the file path out of the request parameters.
    path_operator = MapOperator(lambda request: request["file_path"])
    # build knowledge operator
    knowledge_operator = KnowledgeOperator()
    # build summary assembler operator
    summary_operator = SummaryAssemblerOperator(
        llm_client=OpenAILLMClient(), language="en"
    )
    (
        trigger
        >> request_handle_task
        >> path_operator
        >> knowledge_operator
        >> summary_operator
    )
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag], port=5555)
else:
pass

View File

@@ -1,6 +1,9 @@
import os
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -13,7 +16,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
Examples:
..code-block:: shell
python examples/rag/db_struct_rag_example.py
python examples/rag/db_schema_rag_example.py
"""
@@ -41,28 +44,29 @@ def _create_temporary_connection():
return connect
if __name__ == "__main__":
connection = _create_temporary_connection()
embedding_model_path = "{your_embedding_model_path}"
vector_persist_path = "{your_persist_path}"
embedding_fn = DefaultEmbeddingFactory(
default_model_name=embedding_model_path
).create()
vector_connector = VectorStoreConnector.from_default(
def _create_vector_connector():
"""Create vector connector."""
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="vector_name",
persist_path=vector_persist_path,
name="db_schema_vector_store_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=embedding_fn,
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
)
assembler = DBStructAssembler.load_from_connection(
if __name__ == "__main__":
connection = _create_temporary_connection()
vector_connector = _create_vector_connector()
assembler = DBSchemaAssembler.load_from_connection(
connection=connection,
vector_store_connector=vector_connector,
)
assembler.persist()
# get db struct retriever
# get db schema retriever
retriever = assembler.as_retriever(top_k=1)
chunks = retriever.retrieve("show columns from user")
print(f"db struct rag example results:{[chunk.content for chunk in chunks]}")
print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")

View File

@@ -1,5 +1,7 @@
import asyncio
import os
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -20,21 +22,24 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""
async def main():
file_path = "./docs/docs/awel.md"
vector_persist_path = "{your_persist_path}"
embedding_model_path = "{your_embedding_model_path}"
knowledge = KnowledgeFactory.from_file_path(file_path)
vector_connector = VectorStoreConnector.from_default(
def _create_vector_connector():
"""Create vector connector."""
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="vector_name",
persist_path=vector_persist_path,
name="db_schema_vector_store_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=embedding_model_path
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
)
async def main():
file_path = "docs/docs/awel.md"
knowledge = KnowledgeFactory.from_file_path(file_path)
vector_connector = _create_vector_connector()
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
# get embedding assembler
assembler = EmbeddingAssembler.load_from_knowledge(
@@ -45,7 +50,7 @@ async def main():
assembler.persist()
# get embeddings retriever
retriever = assembler.as_retriever(3)
chunks = await retriever.aretrieve_with_scores("RAG", 0.3)
chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
print(f"embedding rag example results:{chunks}")