feat(rag): Support RAG SDK (#1322)

This commit is contained in:
Fangyin Cheng
2024-03-22 15:36:57 +08:00
committed by GitHub
parent e65732d6e4
commit 8a17099dd2
69 changed files with 1332 additions and 558 deletions

View File

@@ -1 +1,11 @@
"""Module of RAG."""
from dbgpt.core import Chunk, Document # noqa: F401
from .chunk_manager import ChunkParameters # noqa: F401
__ALL__ = [
"Chunk",
"Document",
"ChunkParameters",
]

View File

@@ -0,0 +1,16 @@
"""Assembler Module For RAG.
The Assembler is a module that is responsible for assembling the knowledge.
"""
from .base import BaseAssembler # noqa: F401
from .db_schema import DBSchemaAssembler # noqa: F401
from .embedding import EmbeddingAssembler # noqa: F401
from .summary import SummaryAssembler # noqa: F401
__all__ = [
"BaseAssembler",
"DBSchemaAssembler",
"EmbeddingAssembler",
"SummaryAssembler",
]

View File

@@ -0,0 +1,75 @@
"""Base Assembler."""
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from dbgpt.core import Chunk
from dbgpt.util.tracer import root_tracer
from ..chunk_manager import ChunkManager, ChunkParameters
from ..extractor.base import Extractor
from ..knowledge.base import Knowledge
from ..retriever.base import BaseRetriever
class BaseAssembler(ABC):
    """Abstract base for all assemblers.

    An assembler takes a Knowledge datasource, splits it into chunks through a
    ChunkManager, and leaves persistence and retrieval to subclasses.
    """

    def __init__(
        self,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        extractor: Optional[Extractor] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Assembler arguments.

        Args:
            knowledge(Knowledge): Knowledge datasource.
            chunk_parameters(Optional[ChunkParameters]): Parameters that control
                how the knowledge is chunked. A default instance is created
                when omitted.
            extractor(Optional[Extractor]): Extractor to use for summarization.
        """
        self._knowledge = knowledge
        self._chunk_parameters = chunk_parameters or ChunkParameters()
        self._extractor = extractor
        self._chunk_manager = ChunkManager(
            knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
        )
        self._chunks: List[Chunk] = []

        # Best-effort tracing metadata; every field may end up None.
        has_knowledge = bool(self._knowledge)
        trace_metadata = {
            "knowledge_cls": (
                self._knowledge.__class__.__name__ if has_knowledge else None
            ),
            "knowledge_type": (
                self._knowledge.type().value if has_knowledge else None
            ),
            "path": (
                self._knowledge._path
                if has_knowledge and hasattr(self._knowledge, "_path")
                else None
            ),
            "chunk_parameters": self._chunk_parameters.dict(),
        }
        with root_tracer.start_span(
            "BaseAssembler.load_knowledge", metadata=trace_metadata
        ):
            self.load_knowledge(self._knowledge)

    def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None:
        """Load the datasource and split it into chunks.

        Raises:
            ValueError: If no knowledge is given.
        """
        if not knowledge:
            raise ValueError("knowledge must be provided.")
        with root_tracer.start_span("BaseAssembler.knowledge.load"):
            documents = knowledge.load()
        with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
            self._chunks = self._chunk_manager.split(documents)

    @abstractmethod
    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever."""

    @abstractmethod
    def persist(self) -> List[str]:
        """Persist chunks.

        Returns:
            List[str]: List of persisted chunk ids.
        """

    def get_chunks(self) -> List[Chunk]:
        """Return the chunks produced by :meth:`load_knowledge`."""
        return self._chunks

View File

@@ -0,0 +1,135 @@
"""DBSchemaAssembler."""
from typing import Any, List, Optional
from dbgpt.core import Chunk, Embeddings
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever
class DBSchemaAssembler(BaseAssembler):
    """DBSchemaAssembler.

    Loads the schema of a database connection as knowledge, chunks it, and
    persists the chunks into a vector store for later retrieval.

    Example:
        .. code-block:: python

            from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
            from dbgpt.serve.rag.assembler.db_struct import DBSchemaAssembler
            from dbgpt.storage.vector_store.connector import VectorStoreConnector
            from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig

            connection = SQLiteTempConnector.create_temporary_db()
            assembler = DBSchemaAssembler.load_from_connection(
                connector=connection,
                embedding_model=embedding_model_path,
            )
            assembler.persist()
            # get db struct retriever
            retriever = assembler.as_retriever(top_k=3)
    """

    def __init__(
        self,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with DBSchema Assembler arguments.

        Args:
            connector: (BaseConnector) BaseConnector connection.
            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
            chunk_parameters: (Optional[ChunkParameters]) ChunkParameters to use for
                chunking.
            embedding_model: (Optional[str]) Embedding model name; when given and
                ``embeddings`` is not, a default embedding is created from it.
            embeddings: (Optional[Embeddings]) Embeddings to use.
        """
        knowledge = DatasourceKnowledge(connector)
        self._connector = connector
        self._vector_store_connector = vector_store_connector
        self._embedding_model = embedding_model
        # Build embeddings from the model name only when none were supplied.
        if self._embedding_model and not embeddings:
            embeddings = DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            ).create(self._embedding_model)
        # Wire the embedding function into the vector store config, but never
        # overwrite one that is already set.
        if (
            embeddings
            and self._vector_store_connector.vector_store_config.embedding_fn is None
        ):
            self._vector_store_connector.vector_store_config.embedding_fn = embeddings
        # BaseAssembler.__init__ loads and chunks the knowledge.
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_connection(
        cls,
        connector: BaseConnector,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
    ) -> "DBSchemaAssembler":
        """Create an assembler from a database connection.

        The schema is loaded and chunked immediately (in ``__init__``).

        Args:
            connector: (BaseConnector) BaseConnector connection.
            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
            chunk_parameters: (Optional[ChunkParameters]) ChunkParameters to use for
                chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embeddings: (Optional[Embeddings]) Embeddings to use.

        Returns:
            DBSchemaAssembler
        """
        return cls(
            connector=connector,
            vector_store_connector=vector_store_connector,
            embedding_model=embedding_model,
            chunk_parameters=chunk_parameters,
            embeddings=embeddings,
        )

    def get_chunks(self) -> List[Chunk]:
        """Return the chunks produced from the database schema."""
        return self._chunks

    def persist(self) -> List[str]:
        """Persist chunks into vector store.

        Returns:
            List[str]: List of chunk ids.
        """
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks (no-op for this assembler)."""
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> DBSchemaRetriever:
        """Create a DBSchemaRetriever over the same connector and vector store.

        Args:
            top_k(int): default 4.

        Returns:
            DBSchemaRetriever
        """
        return DBSchemaRetriever(
            top_k=top_k,
            connector=self._connector,
            is_embeddings=True,
            vector_store_connector=self._vector_store_connector,
        )

View File

@@ -0,0 +1,124 @@
"""Embedding Assembler."""
from typing import Any, List, Optional
from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.base import Knowledge
from ..retriever.embedding import EmbeddingRetriever
class EmbeddingAssembler(BaseAssembler):
    """Embedding Assembler.

    Loads a knowledge datasource, chunks it, and stores the chunks in a
    vector store.

    Example:
        .. code-block:: python

            from dbgpt.rag.assembler import EmbeddingAssembler

            pdf_path = "path/to/document.pdf"
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            assembler = EmbeddingAssembler.load_from_knowledge(
                knowledge=knowledge,
                embedding_model="text2vec",
            )
    """

    def __init__(
        self,
        knowledge: Knowledge,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with Embedding Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
            chunk_parameters: (Optional[ChunkParameters]) Parameters to use for
                chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embeddings: (Optional[Embeddings]) Embeddings to use.

        Raises:
            ValueError: If no knowledge datasource is given.
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        self._vector_store_connector = vector_store_connector
        self._embedding_model = embedding_model

        # Derive embeddings from the model name only when none were supplied.
        resolved_embeddings = embeddings
        if self._embedding_model and not resolved_embeddings:
            factory = DefaultEmbeddingFactory(
                default_model_name=self._embedding_model
            )
            resolved_embeddings = factory.create(self._embedding_model)

        # Attach the embedding function to the vector store config unless one
        # is already configured.
        store_config = self._vector_store_connector.vector_store_config
        if resolved_embeddings and store_config.embedding_fn is None:
            store_config.embedding_fn = resolved_embeddings

        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        vector_store_connector: VectorStoreConnector,
        chunk_parameters: Optional[ChunkParameters] = None,
        embedding_model: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
    ) -> "EmbeddingAssembler":
        """Create an assembler and load the knowledge immediately.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
            chunk_parameters: (Optional[ChunkParameters]) Parameters to use for
                chunking.
            embedding_model: (Optional[str]) Embedding model to use.
            embeddings: (Optional[Embeddings]) Embeddings to use.

        Returns:
            EmbeddingAssembler
        """
        return cls(
            knowledge=knowledge,
            vector_store_connector=vector_store_connector,
            chunk_parameters=chunk_parameters,
            embedding_model=embedding_model,
            embeddings=embeddings,
        )

    def persist(self) -> List[str]:
        """Write all chunks into the vector store.

        Returns:
            List[str]: List of chunk ids.
        """
        return self._vector_store_connector.load_document(self._chunks)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks (no-op for this assembler)."""
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> EmbeddingRetriever:
        """Build an EmbeddingRetriever over the same vector store.

        Args:
            top_k(int): default 4.

        Returns:
            EmbeddingRetriever
        """
        return EmbeddingRetriever(
            top_k=top_k, vector_store_connector=self._vector_store_connector
        )

View File

@@ -0,0 +1,131 @@
"""Summary Assembler."""
import os
from typing import Any, List, Optional
from dbgpt.core import Chunk, LLMClient
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..extractor.base import Extractor
from ..knowledge.base import Knowledge
from ..retriever.base import BaseRetriever
class SummaryAssembler(BaseAssembler):
    """Summary Assembler.

    Loads knowledge, chunks it, and generates an LLM-based summary of the
    chunks through a summary extractor.

    Example:
        .. code-block:: python

            pdf_path = "../../../DB-GPT/docs/docs/awel.md"
            OPEN_AI_KEY = "{your_api_key}"
            OPEN_AI_BASE = "{your_api_base}"
            llm_client = OpenAILLMClient(api_key=OPEN_AI_KEY, api_base=OPEN_AI_BASE)
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
            assembler = SummaryAssembler.load_from_knowledge(
                knowledge=knowledge,
                chunk_parameters=chunk_parameters,
                llm_client=llm_client,
                model_name="gpt-3.5-turbo",
            )
            summary = await assembler.generate_summary()
    """

    def __init__(
        self,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        extractor: Optional[Extractor] = None,
        language: Optional[str] = "en",
        **kwargs: Any,
    ) -> None:
        """Initialize with Summary Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Parameters to use for
                chunking.
            model_name: (Optional[str]) llm model to use; falls back to the
                ``LLM_MODEL`` environment variable when omitted.
            llm_client: (Optional[LLMClient]) LLMClient to use; required when no
                ``extractor`` is given.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
                When omitted, a SummaryExtractor is built from ``llm_client``
                and ``model_name``.
            language: (Optional[str]) The language of the prompt. Defaults to "en".

        Raises:
            ValueError: If knowledge is missing, or if no extractor is given and
                ``llm_client``/``model_name`` cannot be resolved.
        """
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        model_name = model_name or os.getenv("LLM_MODEL")
        if not extractor:
            # Imported lazily so a custom extractor does not pull in this
            # dependency.
            from ..extractor.summary import SummaryExtractor

            if not llm_client:
                raise ValueError("llm_client must be provided.")
            if not model_name:
                raise ValueError("model_name must be provided.")
            extractor = SummaryExtractor(
                llm_client=llm_client,
                model_name=model_name,
                language=language,
            )
        # Defensive check; by this point an extractor has always been set.
        if not extractor:
            raise ValueError("extractor must be provided.")
        self._extractor: Extractor = extractor
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            extractor=self._extractor,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        chunk_parameters: Optional[ChunkParameters] = None,
        model_name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        extractor: Optional[Extractor] = None,
        language: Optional[str] = "en",
        **kwargs: Any,
    ) -> "SummaryAssembler":
        """Create a SummaryAssembler and load the knowledge immediately.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            chunk_parameters: (Optional[ChunkParameters]) Parameters to use for
                chunking.
            model_name: (Optional[str]) llm model to use.
            llm_client: (Optional[LLMClient]) LLMClient to use.
            extractor: (Optional[Extractor]) Extractor to use for summarization.
            language: (Optional[str]) The language of the prompt. Defaults to "en".

        Returns:
            SummaryAssembler
        """
        return cls(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            model_name=model_name,
            llm_client=llm_client,
            extractor=extractor,
            language=language,
            **kwargs,
        )

    async def generate_summary(self) -> str:
        """Generate a summary of the loaded chunks via the extractor."""
        return await self._extractor.aextract(self._chunks)

    def persist(self) -> List[str]:
        """Persist chunks into store.

        Raises:
            NotImplementedError: Summaries are not persisted by this assembler.
        """
        raise NotImplementedError

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks (no-op for this assembler)."""
        return []

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a retriever.

        Raises:
            NotImplementedError: This assembler does not provide retrieval.
        """
        raise NotImplementedError

View File

View File

@@ -0,0 +1,76 @@
from unittest.mock import MagicMock
import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@pytest.fixture
def mock_db_connection():
    """Create a temporary SQLite database with one small ``user`` table."""
    connect = SQLiteTempConnector.create_temporary_db()
    connect.create_temp_tables(
        {
            "user": {
                "columns": {
                    "id": "INTEGER PRIMARY KEY",
                    "name": "TEXT",
                    "age": "INTEGER",
                },
                "data": [
                    (1, "Tom", 10),
                    (2, "Jerry", 16),
                    (3, "Jack", 18),
                    (4, "Alice", 20),
                    (5, "Bob", 22),
                ],
            }
        }
    )
    return connect


@pytest.fixture
def mock_chunk_parameters():
    """Chunk-parameter mock; strategy/splitter attributes are set per test."""
    return MagicMock(spec=ChunkParameters)


@pytest.fixture
def mock_embedding_factory():
    """Embedding factory mock; ``create()`` returns a MagicMock embedding."""
    return MagicMock(spec=EmbeddingFactory)


@pytest.fixture
def mock_vector_store_connector():
    """Vector store connector mock (no real storage is touched)."""
    return MagicMock(spec=VectorStoreConnector)


@pytest.fixture
def mock_knowledge():
    """Knowledge datasource mock; ``load()`` returns a MagicMock."""
    return MagicMock(spec=Knowledge)


def test_load_knowledge(
    mock_db_connection,
    mock_knowledge,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
):
    """Constructing the assembler from a mocked knowledge yields no chunks."""
    # Use a real character splitter via the user-defined splitter path.
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = EmbeddingAssembler(
        knowledge=mock_knowledge,
        chunk_parameters=mock_chunk_parameters,
        embeddings=mock_embedding_factory.create(),
        vector_store_connector=mock_vector_store_connector,
    )
    assembler.load_knowledge(knowledge=mock_knowledge)
    # The mocked knowledge produces no real documents, so no chunks result.
    assert len(assembler._chunks) == 0

View File

@@ -0,0 +1,68 @@
from unittest.mock import MagicMock
import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@pytest.fixture
def mock_db_connection():
    """Create a temporary SQLite database with one small ``user`` table."""
    connect = SQLiteTempConnector.create_temporary_db()
    connect.create_temp_tables(
        {
            "user": {
                "columns": {
                    "id": "INTEGER PRIMARY KEY",
                    "name": "TEXT",
                    "age": "INTEGER",
                },
                "data": [
                    (1, "Tom", 10),
                    (2, "Jerry", 16),
                    (3, "Jack", 18),
                    (4, "Alice", 20),
                    (5, "Bob", 22),
                ],
            }
        }
    )
    return connect


@pytest.fixture
def mock_chunk_parameters():
    """Chunk-parameter mock; strategy/splitter attributes are set per test."""
    return MagicMock(spec=ChunkParameters)


@pytest.fixture
def mock_embedding_factory():
    """Embedding factory mock; ``create()`` returns a MagicMock embedding."""
    return MagicMock(spec=EmbeddingFactory)


@pytest.fixture
def mock_vector_store_connector():
    """Vector store connector mock (no real storage is touched)."""
    return MagicMock(spec=VectorStoreConnector)


def test_load_knowledge(
    mock_db_connection,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_vector_store_connector,
):
    """The single-table schema is assembled into exactly one chunk."""
    # Use a real character splitter via the user-defined splitter path.
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
    assembler = DBSchemaAssembler(
        connector=mock_db_connection,
        chunk_parameters=mock_chunk_parameters,
        embeddings=mock_embedding_factory.create(),
        vector_store_connector=mock_vector_store_connector,
    )
    # One table in the temp DB -> one schema summary chunk.
    assert len(assembler._chunks) == 1

View File

@@ -1,6 +1,10 @@
"""Module for embedding related classes and functions."""
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory # noqa: F401
from .embedding_factory import ( # noqa: F401
DefaultEmbeddingFactory,
EmbeddingFactory,
WrappedEmbeddingFactory,
)
from .embeddings import ( # noqa: F401
Embeddings,
HuggingFaceBgeEmbeddings,
@@ -21,4 +25,5 @@ __ALL__ = [
"OpenAPIEmbeddings",
"DefaultEmbeddingFactory",
"EmbeddingFactory",
"WrappedEmbeddingFactory",
]

View File

@@ -0,0 +1,32 @@
"""Wraps the third-party language model embeddings to the common interface."""
from typing import TYPE_CHECKING, List
from dbgpt.core import Embeddings
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings as LangChainEmbeddings
class WrappedEmbeddings(Embeddings):
    """Adapter exposing third-party (LangChain) embeddings via the common API."""

    def __init__(self, embeddings: "LangChainEmbeddings") -> None:
        """Create a new WrappedEmbeddings.

        Args:
            embeddings: The LangChain embeddings object to delegate to.
        """
        self._wrapped = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs (delegates to the wrapped implementation)."""
        return self._wrapped.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text (delegates to the wrapped implementation)."""
        return self._wrapped.embed_query(text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs (delegates to the wrapped impl)."""
        return await self._wrapped.aembed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text (delegates to the wrapped impl)."""
        return await self._wrapped.aembed_query(text)

View File

@@ -1,15 +1,14 @@
"""EmbeddingFactory class and DefaultEmbeddingFactory class."""
from __future__ import annotations
import logging
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Type
from typing import Any, Optional, Type
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.rag.embedding.embeddings import HuggingFaceEmbeddings
from dbgpt.core import Embeddings
if TYPE_CHECKING:
from dbgpt.rag.embedding.embeddings import Embeddings
logger = logging.getLogger(__name__)
class EmbeddingFactory(BaseComponent, ABC):
@@ -20,7 +19,7 @@ class EmbeddingFactory(BaseComponent, ABC):
@abstractmethod
def create(
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
) -> "Embeddings":
) -> Embeddings:
"""Create an embedding instance.
Args:
@@ -39,12 +38,19 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
self,
system_app: Optional[SystemApp] = None,
default_model_name: Optional[str] = None,
default_model_path: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Create a new DefaultEmbeddingFactory."""
super().__init__(system_app=system_app)
if not default_model_path:
default_model_path = default_model_name
if not default_model_name:
default_model_name = default_model_path
self._default_model_name = default_model_name
self.kwargs = kwargs
self._default_model_path = default_model_path
self._kwargs = kwargs
self._model = self._load_model()
def init_app(self, system_app):
"""Init the app."""
@@ -52,20 +58,166 @@ class DefaultEmbeddingFactory(EmbeddingFactory):
def create(
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
) -> "Embeddings":
) -> Embeddings:
"""Create an embedding instance.
Args:
model_name (str): The model name.
embedding_cls (Type): The embedding class.
"""
if not model_name:
model_name = self._default_model_name
new_kwargs = {k: v for k, v in self.kwargs.items()}
new_kwargs["model_name"] = model_name
if embedding_cls:
return embedding_cls(**new_kwargs)
else:
return HuggingFaceEmbeddings(**new_kwargs)
raise NotImplementedError
return self._model
def _load_model(self) -> Embeddings:
from dbgpt.model.adapter.embeddings_loader import (
EmbeddingLoader,
_parse_embedding_params,
)
from dbgpt.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters
)
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
model_name=self._default_model_name,
model_path=self._default_model_path,
param_cls=param_cls,
**self._kwargs,
)
logger.info(model_params)
loader = EmbeddingLoader()
# Ignore model_name args
model_name = self._default_model_name or model_params.model_name
if not model_name:
raise ValueError("model_name must be provided.")
return loader.load(model_name, model_params)
@classmethod
def openai(
cls,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
model_name: str = "text-embedding-3-small",
timeout: int = 60,
**kwargs: Any,
) -> Embeddings:
"""Create an OpenAI embeddings.
If api_url and api_key are not provided, we will try to get them from
environment variables.
Args:
api_url (Optional[str], optional): The api url. Defaults to None.
api_key (Optional[str], optional): The api key. Defaults to None.
model_name (str, optional): The model name.
Defaults to "text-embedding-3-small".
timeout (int, optional): The timeout. Defaults to 60.
Returns:
Embeddings: The embeddings instance.
"""
api_url = (
api_url
or os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") + "/embeddings"
)
api_key = api_key or os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("api_key must be provided.")
return cls.remote(
api_url=api_url,
api_key=api_key,
model_name=model_name,
timeout=timeout,
**kwargs,
)
@classmethod
def default(
cls, model_name: str, model_path: Optional[str] = None, **kwargs: Any
) -> Embeddings:
"""Create a default embeddings.
It will try to load the model from the model name or model path.
Args:
model_name (str): The model name.
model_path (Optional[str], optional): The model path. Defaults to None.
if not provided, it will use the model name as the model path to load
the model.
Returns:
Embeddings: The embeddings instance.
"""
return cls(
default_model_name=model_name, default_model_path=model_path, **kwargs
).create()
@classmethod
def remote(
cls,
api_url: str = "http://localhost:8100/api/v1/embeddings",
api_key: Optional[str] = None,
model_name: str = "text2vec",
timeout: int = 60,
**kwargs: Any,
) -> Embeddings:
"""Create a remote embeddings.
Create a remote embeddings which API compatible with the OpenAI's API. So if
your model is compatible with OpenAI's API, you can use this method to create
a remote embeddings.
Args:
api_url (str, optional): The api url. Defaults to
"http://localhost:8100/api/v1/embeddings".
api_key (Optional[str], optional): The api key. Defaults to None.
model_name (str, optional): The model name. Defaults to "text2vec".
timeout (int, optional): The timeout. Defaults to 60.
"""
from .embeddings import OpenAPIEmbeddings
return OpenAPIEmbeddings(
api_url=api_url,
api_key=api_key,
model_name=model_name,
timeout=timeout,
**kwargs,
)
class WrappedEmbeddingFactory(EmbeddingFactory):
    """Factory that always returns a pre-built ``Embeddings`` instance.

    Unlike ``DefaultEmbeddingFactory``, this factory does not load a model; it
    simply serves the embeddings object it was constructed with.
    """

    def __init__(
        self,
        system_app: Optional[SystemApp] = None,
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new WrappedEmbeddingFactory.

        Args:
            system_app: The system app.
            embeddings: The pre-built embeddings instance to serve. Required.

        Raises:
            ValueError: If ``embeddings`` is not provided.
        """
        super().__init__(system_app=system_app)
        if not embeddings:
            raise ValueError("embeddings must be provided.")
        self._model = embeddings

    def init_app(self, system_app):
        """Init the app (nothing to do for this factory)."""
        pass

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> Embeddings:
        """Return the wrapped embedding instance.

        Args:
            model_name (str): The model name (ignored; the wrapped instance is
                returned regardless).
            embedding_cls (Type): Not supported by this factory.

        Raises:
            NotImplementedError: If ``embedding_cls`` is given.
        """
        if embedding_cls:
            raise NotImplementedError
        return self._model

View File

@@ -1,23 +1,50 @@
"""Module Of Knowledge."""
from .base import ChunkStrategy, Knowledge, KnowledgeType # noqa: F401
from .csv import CSVKnowledge # noqa: F401
from .docx import DocxKnowledge # noqa: F401
from .factory import KnowledgeFactory # noqa: F401
from .html import HTMLKnowledge # noqa: F401
from .markdown import MarkdownKnowledge # noqa: F401
from .pdf import PDFKnowledge # noqa: F401
from .pptx import PPTXKnowledge # noqa: F401
from .string import StringKnowledge # noqa: F401
from .txt import TXTKnowledge # noqa: F401
from .url import URLKnowledge # noqa: F401
from typing import Any, Dict
__ALL__ = [
# Cache of already-resolved attributes so each lazy import happens only once.
_MODULE_CACHE: Dict[str, Any] = {}

# Map from public attribute name to the submodule that defines it.
# Hoisted to module level so it is not rebuilt on every attribute access.
_LIBS: Dict[str, str] = {
    "KnowledgeFactory": "factory",
    "Knowledge": "base",
    "KnowledgeType": "base",
    "ChunkStrategy": "base",
    "CSVKnowledge": "csv",
    "DatasourceKnowledge": "datasource",
    "DocxKnowledge": "docx",
    "HTMLKnowledge": "html",
    "MarkdownKnowledge": "markdown",
    "PDFKnowledge": "pdf",
    "PPTXKnowledge": "pptx",
    "StringKnowledge": "string",
    "TXTKnowledge": "txt",
    "URLKnowledge": "url",
}


def __getattr__(name: str) -> Any:
    """Lazily import knowledge classes on first attribute access (PEP 562).

    Args:
        name: The attribute being looked up on this package.

    Raises:
        AttributeError: If ``name`` is not a known lazily-loaded attribute.
    """
    if name in _MODULE_CACHE:
        return _MODULE_CACHE[name]
    if name not in _LIBS:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    import importlib

    module = importlib.import_module("." + _LIBS[name], __name__)
    attr = getattr(module, name)
    _MODULE_CACHE[name] = attr
    return attr
__all__ = [
"KnowledgeFactory",
"Knowledge",
"KnowledgeType",
"ChunkStrategy",
"CSVKnowledge",
"DatasourceKnowledge",
"DocxKnowledge",
"HTMLKnowledge",
"MarkdownKnowledge",

View File

@@ -25,6 +25,7 @@ class DocumentType(Enum):
DOCX = "docx"
TXT = "txt"
HTML = "html"
DATASOURCE = "datasource"
class KnowledgeType(Enum):

View File

@@ -0,0 +1,57 @@
"""Datasource Knowledge."""
from typing import Any, List, Optional
from dbgpt.core import Document
from dbgpt.datasource import BaseConnector
from ..summary.rdbms_db_summary import _parse_db_summary
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
class DatasourceKnowledge(Knowledge):
    """Knowledge backed by a datasource connection's schema summaries."""

    def __init__(
        self,
        connector: BaseConnector,
        summary_template: str = "{table_name}({columns})",
        knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
        **kwargs: Any,
    ) -> None:
        """Create Datasource Knowledge with Knowledge arguments.

        Args:
            connector(BaseConnector): The datasource connection to summarize.
            summary_template(str, optional): Template used to render each table
                summary. Defaults to ``"{table_name}({columns})"``.
            knowledge_type(KnowledgeType, optional): knowledge type
        """
        self._connector = connector
        self._summary_template = summary_template
        super().__init__(knowledge_type=knowledge_type, **kwargs)

    def _load(self) -> List[Document]:
        """Build one document per table summary parsed from the datasource."""
        return [
            Document(content=table_summary, metadata={"source": "database"})
            for table_summary in _parse_db_summary(
                self._connector, self._summary_template
            )
        ]

    @classmethod
    def support_chunk_strategy(cls) -> List[ChunkStrategy]:
        """Return support chunk strategy."""
        return [
            ChunkStrategy.CHUNK_BY_SIZE,
            ChunkStrategy.CHUNK_BY_SEPARATOR,
        ]

    @classmethod
    def type(cls) -> KnowledgeType:
        """Knowledge type of Datasource."""
        return KnowledgeType.DOCUMENT

    @classmethod
    def document_type(cls) -> DocumentType:
        """Return document type."""
        return DocumentType.DATASOURCE

View File

@@ -156,6 +156,7 @@ class KnowledgeFactory:
"""Get all knowledge subclasses."""
from dbgpt.rag.knowledge.base import Knowledge # noqa: F401
from dbgpt.rag.knowledge.csv import CSVKnowledge # noqa: F401
from dbgpt.rag.knowledge.datasource import DatasourceKnowledge # noqa: F401
from dbgpt.rag.knowledge.docx import DocxKnowledge # noqa: F401
from dbgpt.rag.knowledge.html import HTMLKnowledge # noqa: F401
from dbgpt.rag.knowledge.markdown import MarkdownKnowledge # noqa: F401

View File

@@ -1,8 +1,14 @@
"""Module for RAG operators."""
from .datasource import DatasourceRetrieverOperator # noqa: F401
from .db_schema import DBSchemaRetrieverOperator # noqa: F401
from .embedding import EmbeddingRetrieverOperator # noqa: F401
from .db_schema import ( # noqa: F401
DBSchemaAssemblerOperator,
DBSchemaRetrieverOperator,
)
from .embedding import ( # noqa: F401
EmbeddingAssemblerOperator,
EmbeddingRetrieverOperator,
)
from .evaluation import RetrieverEvaluatorOperator # noqa: F401
from .knowledge import KnowledgeOperator # noqa: F401
from .rerank import RerankOperator # noqa: F401
@@ -12,7 +18,9 @@ from .summary import SummaryAssemblerOperator # noqa: F401
__all__ = [
"DatasourceRetrieverOperator",
"DBSchemaRetrieverOperator",
"DBSchemaAssemblerOperator",
"EmbeddingRetrieverOperator",
"EmbeddingAssemblerOperator",
"KnowledgeOperator",
"RerankOperator",
"QueryRewriteOperator",

View File

@@ -0,0 +1,24 @@
"""Base Assembler Operator."""
from abc import abstractmethod
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN, OUT
class AssemblerOperator(MapOperator[IN, OUT]):
    """The Base Assembler Operator.

    Bridges a synchronous :meth:`assemble` implementation into the async AWEL
    operator interface.
    """

    async def map(self, input_value: IN) -> OUT:
        """Map input value to output value.

        Args:
            input_value (IN): The input value.

        Returns:
            OUT: The output value.
        """
        # Run the blocking assemble off the event loop.
        return await self.blocking_func_to_async(self.assemble, input_value)

    @abstractmethod
    def assemble(self, input_value: IN) -> OUT:
        """Assemble knowledge for input value."""

View File

@@ -1,21 +1,21 @@
"""Datasource operator for RDBMS database."""
from typing import Any
from typing import Any, List
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]):
class DatasourceRetrieverOperator(RetrieverOperator[Any, List[str]]):
"""The Datasource Retriever Operator."""
def __init__(self, connection: RDBMSConnector, **kwargs):
def __init__(self, connector: BaseConnector, **kwargs):
"""Create a new DatasourceRetrieverOperator."""
super().__init__(**kwargs)
self._connection = connection
self._connector = connector
def retrieve(self, input_value: Any) -> Any:
def retrieve(self, input_value: Any) -> List[str]:
"""Retrieve the database summary."""
summary = _parse_db_summary(self._connection)
summary = _parse_db_summary(self._connector)
return summary

View File

@@ -1,18 +1,22 @@
"""The DBSchema Retriever Operator."""
from typing import Any, Optional
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.db_schema import DBSchemaAssembler
from ..retriever.db_schema import DBSchemaRetriever
from .assembler import AssemblerOperator
class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
"""The DBSchema Retriever Operator.
Args:
connection (RDBMSConnector): The connection.
connector (BaseConnector): The connection.
top_k (int, optional): The top k. Defaults to 4.
vector_store_connector (VectorStoreConnector, optional): The vector store
connector. Defaults to None.
@@ -22,21 +26,57 @@ class DBSchemaRetrieverOperator(RetrieverOperator[Any, Any]):
self,
vector_store_connector: VectorStoreConnector,
top_k: int = 4,
connection: Optional[RDBMSConnector] = None,
connector: Optional[BaseConnector] = None,
**kwargs
):
"""Create a new DBSchemaRetrieverOperator."""
super().__init__(**kwargs)
self._retriever = DBSchemaRetriever(
top_k=top_k,
connection=connection,
connector=connector,
vector_store_connector=vector_store_connector,
)
def retrieve(self, query: Any) -> Any:
def retrieve(self, query: str) -> List[Chunk]:
"""Retrieve the table schemas.
Args:
query (IN): query.
query (str): The query.
"""
return self._retriever.retrieve(query)
class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
"""The DBSchema Assembler Operator."""
def __init__(
self,
connector: BaseConnector,
vector_store_connector: VectorStoreConnector,
**kwargs
):
"""Create a new DBSchemaAssemblerOperator.
Args:
connector (BaseConnector): The connection.
vector_store_connector (VectorStoreConnector): The vector store connector.
"""
self._vector_store_connector = vector_store_connector
self._connector = connector
super().__init__(**kwargs)
def assemble(self, dummy_value) -> List[Chunk]:
"""Persist the database schema.
Args:
dummy_value: Dummy value, not used.
Returns:
List[Chunk]: The chunks.
"""
assembler = DBSchemaAssembler.load_from_connection(
connector=self._connector,
vector_store_connector=self._vector_store_connector,
)
assembler.persist()
return assembler.get_chunks()

View File

@@ -5,11 +5,16 @@ from typing import List, Optional, Union
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ..assembler.embedding import EmbeddingAssembler
from ..chunk_manager import ChunkParameters
from ..knowledge import Knowledge
from ..retriever.embedding import EmbeddingRetriever
from ..retriever.rerank import Ranker
from ..retriever.rewrite import QueryRewrite
from .assembler import AssemblerOperator
class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
"""The Embedding Retriever Operator."""
@@ -43,3 +48,36 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[C
for q in query
]
return reduce(lambda x, y: x + y, candidates)
class EmbeddingAssemblerOperator(AssemblerOperator[Knowledge, List[Chunk]]):
"""The Embedding Assembler Operator."""
def __init__(
self,
vector_store_connector: VectorStoreConnector,
chunk_parameters: Optional[ChunkParameters] = ChunkParameters(
chunk_strategy="CHUNK_BY_SIZE"
),
**kwargs
):
"""Create a new EmbeddingAssemblerOperator.
Args:
vector_store_connector (VectorStoreConnector): The vector store connector.
chunk_parameters (Optional[ChunkParameters], optional): The chunk
parameters. Defaults to ChunkParameters(chunk_strategy="CHUNK_BY_SIZE").
"""
self._chunk_parameters = chunk_parameters
self._vector_store_connector = vector_store_connector
super().__init__(**kwargs)
def assemble(self, knowledge: Knowledge) -> List[Chunk]:
"""Assemble knowledge for input value."""
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=self._chunk_parameters,
vector_store_connector=self._vector_store_connector,
)
assembler.persist()
return assembler.get_chunks()

View File

@@ -1,6 +1,6 @@
"""Knowledge Operator."""
from typing import Any, Optional
from typing import Optional
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
@@ -14,7 +14,7 @@ from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory
class KnowledgeOperator(MapOperator[Any, Any]):
class KnowledgeOperator(MapOperator[str, Knowledge]):
"""Knowledge Factory Operator."""
metadata = ViewMetadata(
@@ -26,7 +26,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
IOField.build_from(
"knowledge datasource",
"knowledge datasource",
dict,
str,
"knowledge datasource",
)
],
@@ -85,7 +85,7 @@ class KnowledgeOperator(MapOperator[Any, Any]):
self._datasource = datasource
self._knowledge_type = KnowledgeType.get_by_value(knowledge_type)
async def map(self, datasource: Any) -> Knowledge:
async def map(self, datasource: str) -> Knowledge:
"""Create knowledge from datasource."""
if self._datasource:
datasource = self._datasource

View File

@@ -1,12 +1,12 @@
"""The Rerank Operator."""
from typing import Any, List, Optional
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.core.awel import MapOperator
from dbgpt.rag.retriever.rerank import RANK_FUNC, DefaultRanker
class RerankOperator(MapOperator[Any, Any]):
class RerankOperator(MapOperator[List[Chunk], List[Chunk]]):
"""The Rewrite Operator."""
def __init__(

View File

@@ -7,7 +7,7 @@ from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -17,7 +17,7 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
def __init__(
self,
connection: RDBMSConnector,
connector: BaseConnector,
model_name: str,
llm: LLMClient,
top_k: int = 5,
@@ -27,14 +27,14 @@ class SchemaLinkingOperator(MapOperator[Any, Any]):
"""Create the schema linking operator.
Args:
connection (RDBMSConnector): The connection.
connector (BaseConnector): The connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._schema_linking = SchemaLinking(
top_k=top_k,
connection=connection,
connector=connector,
llm=llm,
model_name=model_name,
vector_store_connector=vector_store_connector,

View File

@@ -4,9 +4,9 @@ from typing import Any, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
from dbgpt.rag.assembler.summary import SummaryAssembler
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
from dbgpt.serve.rag.operators.base import AssemblerOperator
from dbgpt.rag.operators.assembler import AssemblerOperator
class SummaryAssemblerOperator(AssemblerOperator[Any, Any]):

View File

@@ -3,7 +3,7 @@ from functools import reduce
from typing import List, Optional, cast
from dbgpt.core import Chunk
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@@ -18,7 +18,7 @@ class DBSchemaRetriever(BaseRetriever):
self,
vector_store_connector: VectorStoreConnector,
top_k: int = 4,
connection: Optional[RDBMSConnector] = None,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
rerank: Optional[Ranker] = None,
**kwargs
@@ -28,7 +28,7 @@ class DBSchemaRetriever(BaseRetriever):
Args:
vector_store_connector (VectorStoreConnector): vector store connector
top_k (int): top k
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
@@ -65,7 +65,7 @@ class DBSchemaRetriever(BaseRetriever):
return connect
connection = _create_temporary_connection()
connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
@@ -76,14 +76,16 @@ class DBSchemaRetriever(BaseRetriever):
)
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3, vector_store_connector=vector_connector
top_k=3,
vector_store_connector=vector_connector,
connector=connector,
)
chunks = retriever.retrieve("show columns from table")
result = [chunk.content for chunk in chunks]
print(f"db struct rag example results:{result}")
"""
self._top_k = top_k
self._connection = connection
self._connector = connector
self._query_rewrite = query_rewrite
self._vector_store_connector = vector_store_connector
self._need_embeddings = False
@@ -108,9 +110,9 @@ class DBSchemaRetriever(BaseRetriever):
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
else:
if not self._connection:
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
table_summaries = _parse_db_summary(self._connection)
table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries]
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
@@ -173,6 +175,6 @@ class DBSchemaRetriever(BaseRetriever):
"""Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
if not self._connection:
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
return _parse_db_summary(self._connection)
return _parse_db_summary(self._connector)

View File

@@ -24,7 +24,7 @@ def mock_vector_store_connector():
@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
return DBSchemaRetriever(
connection=mock_db_connection,
connector=mock_db_connection,
vector_store_connector=mock_vector_store_connector,
)

View File

@@ -10,7 +10,7 @@ from dbgpt.core import (
ModelMessageRoleType,
ModelRequest,
)
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -42,7 +42,7 @@ class SchemaLinking(BaseSchemaLinker):
def __init__(
self,
connection: RDBMSConnector,
connector: BaseConnector,
model_name: str,
llm: LLMClient,
top_k: int = 5,
@@ -52,19 +52,19 @@ class SchemaLinking(BaseSchemaLinker):
"""Create the schema linking instance.
Args:
connection (Optional[RDBMSConnector]): RDBMSConnector connection.
connection (Optional[BaseConnector]): BaseConnector connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._top_k = top_k
self._connection = connection
self._connector = connector
self._llm = llm
self._model_name = model_name
self._vector_store_connector = vector_store_connector
def _schema_linking(self, query: str) -> List:
"""Get all db schema info."""
table_summaries = _parse_db_summary(self._connection)
table_summaries = _parse_db_summary(self._connector)
chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
chunks_content = [chunk.content for chunk in chunks]
return chunks_content

View File

@@ -97,10 +97,10 @@ class DBSummaryClient:
vector_store_config=vector_store_config,
)
if not vector_connector.vector_name_exists():
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
db_assembler = DBSchemaAssembler.load_from_connection(
connection=db_summary_client.db, vector_store_connector=vector_connector
connector=db_summary_client.db, vector_store_connector=vector_connector
)
if len(db_assembler.get_chunks()) > 0:
db_assembler.persist()

View File

@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, List, Optional
from dbgpt._private.config import Config
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource import BaseConnector
from dbgpt.rag.summary.db_summary import DBSummary
if TYPE_CHECKING:
@@ -64,12 +64,12 @@ class RdbmsSummary(DBSummary):
def _parse_db_summary(
conn: RDBMSConnector, summary_template: str = "{table_name}({columns})"
conn: BaseConnector, summary_template: str = "{table_name}({columns})"
) -> List[str]:
"""Get db summary for database.
Args:
conn (RDBMSConnector): database connection
conn (BaseConnector): database connection
summary_template (str): summary template
"""
tables = conn.get_table_names()
@@ -81,12 +81,12 @@ def _parse_db_summary(
def _parse_table_summary(
conn: RDBMSConnector, summary_template: str, table_name: str
conn: BaseConnector, summary_template: str, table_name: str
) -> str:
"""Get table summary for table.
Args:
conn (RDBMSConnector): database connection
conn (BaseConnector): database connection
summary_template (str): summary template
table_name (str): table name