mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
chore: Fix package name conflict error (#1099)
This commit is contained in:
0
dbgpt/rag/operators/__init__.py
Normal file
0
dbgpt/rag/operators/__init__.py
Normal file
15
dbgpt/rag/operators/datasource.py
Normal file
15
dbgpt/rag/operators/datasource.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
|
||||
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
    """Retrieve a textual summary of an RDBMS datasource.

    Args:
        connection (RDBMSDatabase): The database connection to summarize.
    """

    def __init__(self, connection: RDBMSDatabase, **kwargs):
        """Create a new DatasourceRetrieverOperator."""
        super().__init__(**kwargs)
        self._connection = connection

    def retrieve(self, input_value: Any) -> Any:
        """Return the parsed summary of the wrapped database connection.

        Note: ``input_value`` is ignored; the summary depends only on the
        connection supplied at construction time.
        """
        return _parse_db_summary(self._connection)
|
37
dbgpt/rag/operators/db_schema.py
Normal file
37
dbgpt/rag/operators/db_schema.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from dbgpt.core.awel.task.base import IN
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
    """Operator that retrieves database table schemas.

    All retrieval work is delegated to an internal ``DBSchemaRetriever``.

    Args:
        top_k (int, optional): Number of schema candidates to return.
            Defaults to 4.
        connection (Optional[RDBMSDatabase]): The database connection.
            Defaults to None.
        vector_store_connector (Optional[VectorStoreConnector]): The vector
            store connector used for schema lookup. Defaults to None.
    """

    def __init__(
        self,
        top_k: int = 4,
        connection: Optional[RDBMSDatabase] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Build the underlying retriever once; it is reused for every call.
        self._retriever = DBSchemaRetriever(
            top_k=top_k,
            connection=connection,
            vector_store_connector=vector_store_connector,
        )

    def retrieve(self, query: IN) -> Any:
        """Retrieve the table schemas matching *query*.

        Args:
            query (IN): The retrieval query.
        """
        return self._retriever.retrieve(query)
|
39
dbgpt/rag/operators/embedding.py
Normal file
39
dbgpt/rag/operators/embedding.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from functools import reduce
|
||||
from typing import Any, Optional
|
||||
|
||||
from dbgpt.core.awel.task.base import IN
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.rag.retriever.rerank import Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
    """Retrieve scored document chunks from a vector store by embedding similarity.

    Args:
        top_k (int): Number of candidates to retrieve per query.
        score_threshold (Optional[float]): Minimum similarity score for a
            candidate to be kept. Defaults to 0.3.
        query_rewrite (Optional[QueryRewrite]): Optional query rewriter.
        rerank (Optional[Ranker]): Optional reranker applied to candidates.
        vector_store_connector (Optional[VectorStoreConnector]): The vector
            store backend.
    """

    def __init__(
        self,
        top_k: int,
        score_threshold: Optional[float] = 0.3,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self._score_threshold = score_threshold
        self._retriever = EmbeddingRetriever(
            top_k=top_k,
            query_rewrite=query_rewrite,
            rerank=rerank,
            vector_store_connector=vector_store_connector,
        )

    def retrieve(self, query: IN) -> Any:
        """Retrieve scored candidates for one query or a list of queries.

        Args:
            query (IN): A query string, or a list of query strings whose
                results are concatenated in order.

        Raises:
            TypeError: If ``query`` is neither a string nor a list.
        """
        if isinstance(query, str):
            return self._retriever.retrieve_with_scores(query, self._score_threshold)
        if isinstance(query, list):
            candidates = [
                self._retriever.retrieve_with_scores(q, self._score_threshold)
                for q in query
            ]
            # The empty initializer makes an empty query list yield an empty
            # result instead of the TypeError reduce() would otherwise raise.
            return reduce(lambda x, y: x + y, candidates, [])
        # Previously an unsupported type fell through and returned None
        # silently; fail loudly instead.
        raise TypeError(
            f"query must be a str or a list of str, got {type(query).__name__}"
        )
|
26
dbgpt/rag/operators/knowledge.py
Normal file
26
dbgpt/rag/operators/knowledge.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN
|
||||
from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
|
||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||
|
||||
|
||||
class KnowledgeOperator(MapOperator[Any, Any]):
    """Knowledge Operator: map a datasource to a ``Knowledge`` object."""

    def __init__(
        self, knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, **kwargs
    ):
        """Init the knowledge operator.

        (The original docstring said "query rewrite operator" — a copy-paste
        error from another operator, fixed here.)

        Args:
            knowledge_type: (Optional[KnowledgeType]) The knowledge type.
                Defaults to ``KnowledgeType.DOCUMENT``.
        """
        super().__init__(**kwargs)
        self._knowledge_type = knowledge_type

    async def map(self, datasource: IN) -> Knowledge:
        """Create a ``Knowledge`` object from the datasource.

        Factory creation may block on I/O, so it is run in the executor.
        """
        return await self.blocking_func_to_async(
            KnowledgeFactory.create, datasource, self._knowledge_type
        )
|
43
dbgpt/rag/operators/rerank.py
Normal file
43
dbgpt/rag/operators/rerank.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
|
||||
|
||||
class RerankOperator(MapOperator[Any, Any]):
    """The Rerank Operator.

    Reorders retrieved candidates by relevance. (The original docstring said
    "The Rewrite Operator" — a copy-paste error, fixed here.)
    """

    def __init__(
        self,
        topk: Optional[int] = 3,
        algorithm: Optional[str] = "default",
        rank_fn: Optional[callable] = None,
        **kwargs
    ):
        """Init the rerank operator.

        Args:
            topk (int): The number of candidates to keep.
            algorithm (Optional[str]): The rerank algorithm name.
            rank_fn (Optional[callable]): A custom rank function.
        """
        super().__init__(**kwargs)
        # NOTE(review): _algorithm is stored but never consulted — the ranker
        # is always DefaultRanker regardless of this value. Confirm whether
        # algorithm selection was intended.
        self._algorithm = algorithm
        self._rerank = DefaultRanker(
            topk=topk,
            rank_fn=rank_fn,
        )

    async def map(self, candidates_with_scores: IN) -> List[Chunk]:
        """Rerank the candidates.

        Args:
            candidates_with_scores (IN): The candidates with scores.

        Returns:
            List[Chunk]: The reranked candidates.
        """
        # Ranking may be CPU-bound, so run it in the executor.
        return await self.blocking_func_to_async(
            self._rerank.rank, candidates_with_scores
        )
|
41
dbgpt/rag/operators/rewrite.py
Normal file
41
dbgpt/rag/operators/rewrite.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
|
||||
|
||||
class QueryRewriteOperator(MapOperator[Any, Any]):
    """Operator that expands a query into rewritten variants via an LLM."""

    def __init__(
        self,
        llm_client: Optional[LLMClient],
        model_name: Optional[str] = None,
        language: Optional[str] = "en",
        nums: Optional[int] = 1,
        **kwargs
    ):
        """Init the query rewrite operator.

        Args:
            llm_client (Optional[LLMClient]): The LLM client.
            model_name (Optional[str]): The model name.
            language (Optional[str]): The prompt language. Defaults to "en".
            nums (Optional[int]): How many rewritten queries to produce.
        """
        super().__init__(**kwargs)
        self._nums = nums
        # The rewriter holds the LLM configuration; nums is passed per call.
        self._rewrite = QueryRewrite(
            llm_client=llm_client,
            model_name=model_name,
            language=language,
        )

    async def map(self, query_context: IN) -> List[str]:
        """Rewrite the query using its surrounding context."""
        original_query = query_context.get("query")
        surrounding_context = query_context.get("context")
        return await self._rewrite.rewrite(
            origin_query=original_query, context=surrounding_context, nums=self._nums
        )
|
44
dbgpt/rag/operators/schema_linking.py
Normal file
44
dbgpt/rag/operators/schema_linking.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class SchemaLinkingOperator(MapOperator[Any, Any]):
    """The Schema Linking Operator.

    Links a natural-language query to relevant database table schemas
    using an LLM-backed ``SchemaLinking`` helper.
    """

    def __init__(
        self,
        top_k: int = 5,
        connection: Optional[RDBMSDatabase] = None,
        llm: Optional[LLMClient] = None,
        model_name: Optional[str] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        """Init the schema linking operator.

        Args:
            top_k (int): Number of schema candidates. Defaults to 5.
            connection (Optional[RDBMSDatabase]): The database connection.
            llm (Optional[LLMClient]): The base LLM client.
            model_name (Optional[str]): The model name.
            vector_store_connector (Optional[VectorStoreConnector]): The
                vector store connector.
        """
        super().__init__(**kwargs)

        self._schema_linking = SchemaLinking(
            top_k=top_k,
            connection=connection,
            llm=llm,
            model_name=model_name,
            vector_store_connector=vector_store_connector,
        )

    async def map(self, query: str) -> str:
        """Retrieve table schemas relevant to *query*.

        Args:
            query (str): The query.

        Return:
            str: The linked schema info, stringified.
        """
        linked_schema = await self._schema_linking.schema_linking_with_llm(query)
        return str(linked_schema)
|
49
dbgpt/rag/operators/summary.py
Normal file
49
dbgpt/rag/operators/summary.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel.task.base import IN
|
||||
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
|
||||
from dbgpt.serve.rag.operators.base import AssemblerOperator
|
||||
|
||||
|
||||
class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):
    """Assemble knowledge and generate its summary with an LLM."""

    def __init__(
        self,
        llm_client: Optional[LLMClient],
        model_name: Optional[str] = "gpt-3.5-turbo",
        language: Optional[str] = "en",
        max_iteration_with_llm: Optional[int] = 5,
        concurrency_limit_with_llm: Optional[int] = 3,
        **kwargs
    ):
        """Init the summary assembler operator.

        Args:
            llm_client: (Optional[LLMClient]) The LLM client.
            model_name: (Optional[str]) The model name. Defaults to
                "gpt-3.5-turbo".
            language: (Optional[str]) The prompt language. Defaults to "en".
            max_iteration_with_llm: (Optional[int]) The max iteration count
                when summarizing with the LLM.
            concurrency_limit_with_llm: (Optional[int]) The max number of
                concurrent LLM calls.
        """
        super().__init__(**kwargs)
        # Keep the full configuration; the assembler itself is built lazily
        # per input in map(), since it depends on the incoming knowledge.
        self._llm_client = llm_client
        self._model_name = model_name
        self._language = language
        self._max_iteration_with_llm = max_iteration_with_llm
        self._concurrency_limit_with_llm = concurrency_limit_with_llm

    async def map(self, knowledge: IN) -> Any:
        """Assemble the knowledge and return its generated summary."""
        assembler = SummaryAssembler.load_from_knowledge(
            knowledge=knowledge,
            llm_client=self._llm_client,
            model_name=self._model_name,
            language=self._language,
            max_iteration_with_llm=self._max_iteration_with_llm,
            concurrency_limit_with_llm=self._concurrency_limit_with_llm,
        )
        return await assembler.generate_summary()

    def assemble(self, knowledge: IN) -> Any:
        """No-op: summary assembly happens asynchronously in ``map``."""
        pass
|
Reference in New Issue
Block a user