feat(RAG): add metadata properties filters (#1395)

Aries-ckt 2024-04-10 14:33:24 +08:00 committed by GitHub
parent 0f2b46da62
commit 37e7c0151b
26 changed files with 619 additions and 166 deletions

View File

@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

 from dbgpt.core import Document
 from dbgpt.rag.text_splitter.text_splitter import (
@@ -147,16 +147,18 @@ class Knowledge(ABC):
         self,
         path: Optional[str] = None,
         knowledge_type: Optional[KnowledgeType] = None,
-        data_loader: Optional[Any] = None,
+        loader: Optional[Any] = None,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Initialize with Knowledge arguments."""
         self._path = path
         self._type = knowledge_type
-        self._data_loader = data_loader
+        self._loader = loader
+        self._metadata = metadata

     def load(self) -> List[Document]:
-        """Load knowledge from data_loader."""
+        """Load knowledge from data loader."""
         documents = self._load()
         return self._postprocess(documents)
@@ -171,12 +173,12 @@ class Knowledge(ABC):
         return None

     def _postprocess(self, docs: List[Document]) -> List[Document]:
-        """Post process knowledge from data_loader."""
+        """Post process knowledge from data loader."""
         return docs

     @abstractmethod
     def _load(self) -> List[Document]:
-        """Preprocess knowledge from data_loader."""
+        """Preprocess knowledge from data loader."""

     @classmethod
     def support_chunk_strategy(cls) -> List[ChunkStrategy]:

View File

@@ -1,6 +1,6 @@
 """CSV Knowledge."""
 import csv
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from dbgpt.core import Document
 from dbgpt.rag.knowledge.base import (
@@ -21,6 +21,7 @@ class CSVKnowledge(Knowledge):
         source_column: Optional[str] = None,
         encoding: Optional[str] = "utf-8",
         loader: Optional[Any] = None,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create CSV Knowledge with Knowledge arguments.
@@ -32,9 +33,13 @@ class CSVKnowledge(Knowledge):
             encoding(str, optional): csv encoding
             loader(Any, optional): loader
         """
-        self._path = file_path
-        self._type = knowledge_type
-        self._loader = loader
+        super().__init__(
+            path=file_path,
+            knowledge_type=knowledge_type,
+            data_loader=loader,
+            metadata=metadata,
+            **kwargs,
+        )
         self._encoding = encoding
         self._source_column = source_column
@@ -67,6 +72,8 @@ class CSVKnowledge(Knowledge):
                         f"file."
                     )
                 metadata = {"source": source, "row": i}
+                if self._metadata:
+                    metadata.update(self._metadata)  # type: ignore
                 doc = Document(content=content, metadata=metadata)
                 docs.append(doc)
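The per-row metadata merge above means any properties supplied at construction time are stamped onto every Document the knowledge source produces, so they can later be matched by MetadataFilters. A minimal usage sketch; the module path mirrors the package shown in the diff, and the file path and property values are illustrative only:

    from dbgpt.rag.knowledge.csv import CSVKnowledge

    # Hypothetical file path and metadata values, for illustration only.
    knowledge = CSVKnowledge(
        file_path="data/sales.csv",
        metadata={"department": "sales", "tags": ["quarterly", "csv"]},
    )
    docs = knowledge.load()
    # Each Document now carries {"source": ..., "row": i} plus the extra properties.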

View File

@@ -1,5 +1,5 @@
 """Datasource Knowledge."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from dbgpt.core import Document
 from dbgpt.datasource import BaseConnector
@@ -16,26 +16,29 @@ class DatasourceKnowledge(Knowledge):
         connector: BaseConnector,
         summary_template: str = "{table_name}({columns})",
         knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create Datasource Knowledge with Knowledge arguments.

         Args:
-            path(str, optional): file path
+            connector(BaseConnector): connector
+            summary_template(str, optional): summary template
             knowledge_type(KnowledgeType, optional): knowledge type
-            data_loader(Any, optional): loader
+            metadata(Dict[str, Union[str, List[str]], optional): metadata
         """
         self._connector = connector
         self._summary_template = summary_template
-        super().__init__(knowledge_type=knowledge_type, **kwargs)
+        super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)

     def _load(self) -> List[Document]:
         """Load datasource document from data_loader."""
         docs = []
         for table_summary in _parse_db_summary(self._connector, self._summary_template):
-            docs.append(
-                Document(content=table_summary, metadata={"source": "database"})
-            )
+            metadata = {"source": "database"}
+            if self._metadata:
+                metadata.update(self._metadata)  # type: ignore
+            docs.append(Document(content=table_summary, metadata=metadata))
         return docs

     @classmethod

View File

@@ -1,5 +1,5 @@
 """Docx Knowledge."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import docx
@@ -21,6 +21,7 @@ class DocxKnowledge(Knowledge):
         knowledge_type: Any = KnowledgeType.DOCUMENT,
         encoding: Optional[str] = "utf-8",
         loader: Optional[Any] = None,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create Docx Knowledge with Knowledge arguments.
@@ -31,9 +32,13 @@ class DocxKnowledge(Knowledge):
             encoding(str, optional): csv encoding
             loader(Any, optional): loader
         """
-        self._path = file_path
-        self._type = knowledge_type
-        self._loader = loader
+        super().__init__(
+            path=file_path,
+            knowledge_type=knowledge_type,
+            data_loader=loader,
+            metadata=metadata,
+            **kwargs,
+        )
         self._encoding = encoding

     def _load(self) -> List[Document]:
@@ -48,9 +53,10 @@ class DocxKnowledge(Knowledge):
                 para = doc.paragraphs[i]
                 text = para.text
                 content.append(text)
-            docs.append(
-                Document(content="\n".join(content), metadata={"source": self._path})
-            )
+            metadata = {"source": self._path}
+            if self._metadata:
+                metadata.update(self._metadata)  # type: ignore
+            docs.append(Document(content="\n".join(content), metadata=metadata))
             return docs
         return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@@ -1,5 +1,5 @@
 """Knowledge Factory to create knowledge from file path and url."""
-from typing import List, Optional, Type
+from typing import Dict, List, Optional, Type, Union

 from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
 from dbgpt.rag.knowledge.string import StringKnowledge
@@ -13,6 +13,7 @@ class KnowledgeFactory:
         self,
         file_path: Optional[str] = None,
         knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
     ):
         """Create Knowledge Factory with file path and knowledge type.
@@ -22,18 +23,21 @@
         """
         self._file_path = file_path
         self._knowledge_type = knowledge_type
+        self._metadata = metadata

     @classmethod
     def create(
         cls,
         datasource: str = "",
         knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
     ):
         """Create knowledge from file path, url or text.

         Args:
             datasource: path of the file to convert
             knowledge_type: type of knowledge
+            metadata: Optional[Dict[str, Union[str, List[str]]]]

         Examples:
             .. code-block:: python
@@ -52,12 +56,16 @@ class KnowledgeFactory:
         match knowledge_type:
             case KnowledgeType.DOCUMENT:
                 return cls.from_file_path(
-                    file_path=datasource, knowledge_type=knowledge_type
+                    file_path=datasource,
+                    knowledge_type=knowledge_type,
+                    metadata=metadata,
                 )
             case KnowledgeType.URL:
                 return cls.from_url(url=datasource, knowledge_type=knowledge_type)
             case KnowledgeType.TEXT:
-                return cls.from_text(text=datasource, knowledge_type=knowledge_type)
+                return cls.from_text(
+                    text=datasource, knowledge_type=knowledge_type, metadata=metadata
+                )
             case _:
                 raise Exception(f"Unsupported knowledge type '{knowledge_type}'")
@@ -66,6 +74,7 @@ class KnowledgeFactory:
         cls,
         file_path: str = "",
         knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
     ) -> Knowledge:
         """Create knowledge from path.
@@ -82,10 +91,11 @@ class KnowledgeFactory:
                     datasource="path/to/document.pdf",
                     knowledge_type=KnowledgeType.DOCUMENT,
                 )

         """
         factory = cls(file_path=file_path, knowledge_type=knowledge_type)
         return factory._select_document_knowledge(
-            file_path=file_path, knowledge_type=knowledge_type
+            file_path=file_path, knowledge_type=knowledge_type, metadata=metadata
         )

     @staticmethod
@@ -117,6 +127,7 @@ class KnowledgeFactory:
     def from_text(
         text: str = "",
         knowledge_type: KnowledgeType = KnowledgeType.TEXT,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
     ) -> Knowledge:
         """Create knowledge from text.
@@ -127,6 +138,7 @@ class KnowledgeFactory:
         return StringKnowledge(
             text=text,
             knowledge_type=knowledge_type,
+            metadata=metadata,
         )

     def _select_document_knowledge(self, **kwargs):
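With the factory now threading metadata through from_file_path and from_text, callers can attach properties at a single entry point. A rough usage sketch; the import paths follow the modules shown in the diff, and the document path and property values are placeholders:

    from dbgpt.rag.knowledge.base import KnowledgeType
    from dbgpt.rag.knowledge.factory import KnowledgeFactory

    # Placeholder path and metadata keys, for illustration only.
    knowledge = KnowledgeFactory.create(
        datasource="path/to/document.pdf",
        knowledge_type=KnowledgeType.DOCUMENT,
        metadata={"product": "dbgpt", "version": "0.5"},
    )
    documents = knowledge.load()  # every Document carries the extra properties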

View File

@@ -1,5 +1,5 @@
 """HTML Knowledge."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import chardet
@@ -20,6 +20,7 @@ class HTMLKnowledge(Knowledge):
         file_path: Optional[str] = None,
         knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
         loader: Optional[Any] = None,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create HTML Knowledge with Knowledge arguments.
@@ -29,9 +30,13 @@ class HTMLKnowledge(Knowledge):
             knowledge_type(KnowledgeType, optional): knowledge type
             loader(Any, optional): loader
         """
-        self._path = file_path
-        self._type = knowledge_type
-        self._loader = loader
+        super().__init__(
+            path=file_path,
+            knowledge_type=knowledge_type,
+            data_loader=loader,
+            metadata=metadata,
+            **kwargs,
+        )

     def _load(self) -> List[Document]:
         """Load html document from loader."""
@@ -48,6 +53,8 @@ class HTMLKnowledge(Knowledge):
             else:
                 text = raw_text.decode(result["encoding"])
             metadata = {"source": self._path}
+            if self._metadata:
+                metadata.update(self._metadata)  # type: ignore
             return [Document(content=text, metadata=metadata)]
         return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@@ -1,5 +1,5 @@
 """Markdown Knowledge."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from dbgpt.core import Document
 from dbgpt.rag.knowledge.base import (
@@ -19,6 +19,7 @@ class MarkdownKnowledge(Knowledge):
         knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
         encoding: Optional[str] = "utf-8",
         loader: Optional[Any] = None,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create Markdown Knowledge with Knowledge arguments.
@@ -29,9 +30,13 @@ class MarkdownKnowledge(Knowledge):
             encoding(str, optional): csv encoding
             loader(Any, optional): loader
         """
-        self._path = file_path
-        self._type = knowledge_type
-        self._loader = loader
+        super().__init__(
+            path=file_path,
+            knowledge_type=knowledge_type,
+            data_loader=loader,
+            metadata=metadata,
+            **kwargs,
+        )
         self._encoding = encoding

     def _load(self) -> List[Document]:
@@ -44,6 +49,8 @@ class MarkdownKnowledge(Knowledge):
         with open(self._path, encoding=self._encoding, errors="ignore") as f:
             markdown_text = f.read()
             metadata = {"source": self._path}
+            if self._metadata:
+                metadata.update(self._metadata)  # type: ignore
             documents = [Document(content=markdown_text, metadata=metadata)]
             return documents
         return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@@ -1,5 +1,5 @@
 """PDF Knowledge."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from dbgpt.core import Document
 from dbgpt.rag.knowledge.base import (
@@ -19,6 +19,7 @@ class PDFKnowledge(Knowledge):
         knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
         loader: Optional[Any] = None,
         language: Optional[str] = "zh",
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create PDF Knowledge with Knowledge arguments.
@@ -29,9 +30,13 @@ class PDFKnowledge(Knowledge):
             loader(Any, optional): loader
             language(str, optional): language
         """
-        self._path = file_path
-        self._type = knowledge_type
-        self._loader = loader
+        super().__init__(
+            path=file_path,
+            knowledge_type=knowledge_type,
+            data_loader=loader,
+            metadata=metadata,
+            **kwargs,
+        )
         self._language = language

     def _load(self) -> List[Document]:
@@ -65,6 +70,8 @@ class PDFKnowledge(Knowledge):
                 page = "\n".join(cleaned_lines)
                 # cleaned_pages.append(page)
                 metadata = {"source": self._path, "page": page_num}
+                if self._metadata:
+                    metadata.update(self._metadata)  # type: ignore
                 # text = "\f".join(cleaned_pages)
                 document = Document(content=page, metadata=metadata)
                 documents.append(document)

View File

@@ -1,5 +1,5 @@
 """PPTX Knowledge."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from dbgpt.core import Document
 from dbgpt.rag.knowledge.base import (
@@ -19,6 +19,7 @@ class PPTXKnowledge(Knowledge):
         knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
         loader: Optional[Any] = None,
         language: Optional[str] = "zh",
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create PPTX knowledge with PDF Knowledge arguments.
@@ -28,9 +29,13 @@ class PPTXKnowledge(Knowledge):
             knowledge_type:(KnowledgeType) knowledge type
             loader:(Optional[Any]) loader
         """
-        self._path = file_path
-        self._type = knowledge_type
-        self._loader = loader
+        super().__init__(
+            path=file_path,
+            knowledge_type=knowledge_type,
+            data_loader=loader,
+            metadata=metadata,
+            **kwargs,
+        )
         self._language = language

     def _load(self) -> List[Document]:
@@ -47,9 +52,10 @@ class PPTXKnowledge(Knowledge):
                 for shape in slide.shapes:
                     if hasattr(shape, "text") and shape.text:
                         content += shape.text
-                docs.append(
-                    Document(content=content, metadata={"source": slide.slide_id})
-                )
+                metadata = {"source": self._path}
+                if self._metadata:
+                    metadata.update(self._metadata)  # type: ignore
+                docs.append(Document(content=content, metadata=metadata))
             return docs
         return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@@ -1,5 +1,5 @@
 """String Knowledge."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from dbgpt.core import Document
 from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
@@ -14,6 +14,7 @@ class StringKnowledge(Knowledge):
         knowledge_type: KnowledgeType = KnowledgeType.TEXT,
         encoding: Optional[str] = "utf-8",
         loader: Optional[Any] = None,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create String knowledge parameters.
@@ -24,14 +25,20 @@ class StringKnowledge(Knowledge):
             encoding(str): encoding
             loader(Any): loader
         """
+        super().__init__(
+            knowledge_type=knowledge_type,
+            data_loader=loader,
+            metadata=metadata,
+            **kwargs,
+        )
         self._text = text
-        self._type = knowledge_type
-        self._loader = loader
         self._encoding = encoding

     def _load(self) -> List[Document]:
         """Load raw text from loader."""
         metadata = {"source": "raw text"}
+        if self._metadata:
+            metadata.update(self._metadata)  # type: ignore
         docs = [Document(content=self._text, metadata=metadata)]
         return docs

View File

@@ -1,5 +1,5 @@
 """TXT Knowledge."""
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import chardet
@@ -20,6 +20,7 @@ class TXTKnowledge(Knowledge):
         file_path: Optional[str] = None,
         knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
         loader: Optional[Any] = None,
+        metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         **kwargs: Any,
     ) -> None:
         """Create TXT Knowledge with Knowledge arguments.
@@ -29,9 +30,13 @@ class TXTKnowledge(Knowledge):
             knowledge_type(KnowledgeType, optional): knowledge type
             loader(Any, optional): loader
         """
-        self._path = file_path
-        self._type = knowledge_type
-        self._loader = loader
+        super().__init__(
+            path=file_path,
+            knowledge_type=knowledge_type,
+            data_loader=loader,
+            metadata=metadata,
+            **kwargs,
+        )

     def _load(self) -> List[Document]:
         """Load txt document from loader."""
@@ -48,6 +53,8 @@ class TXTKnowledge(Knowledge):
             else:
                 text = raw_text.decode(result["encoding"])
             metadata = {"source": self._path}
+            if self._metadata:
+                metadata.update(self._metadata)  # type: ignore
             return [Document(content=text, metadata=metadata)]
         return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@@ -26,9 +26,9 @@ class URLKnowledge(Knowledge):
             encoding(str, optional): csv encoding
             loader(Any, optional): loader
         """
-        self._path = url or None
-        self._type = knowledge_type
-        self._loader = loader
+        super().__init__(
+            path=url, knowledge_type=knowledge_type, loader=loader, **kwargs
+        )
         self._encoding = encoding
         self._source_column = source_column

View File

@@ -1,9 +1,10 @@
 """Base retriever module."""
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import List
+from typing import List, Optional

 from dbgpt.core import Chunk
+from dbgpt.storage.vector_store.filters import MetadataFilters


 class RetrieverStrategy(str, Enum):
@@ -23,83 +24,112 @@ class RetrieverStrategy(str, Enum):
 class BaseRetriever(ABC):
     """Base retriever."""

-    def retrieve(self, query: str) -> List[Chunk]:
+    def retrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
-            query (str): query text
+            query (str): query text.
+            filters: (Optional[MetadataFilters]) metadata filters.

         Returns:
             List[Chunk]: list of chunks
         """
-        return self._retrieve(query)
+        return self._retrieve(query, filters)

-    async def aretrieve(self, query: str) -> List[Chunk]:
+    async def aretrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
-            query (str): async query text
+            query (str): async query text.
+            filters: (Optional[MetadataFilters]) metadata filters.

         Returns:
             List[Chunk]: list of chunks
         """
-        return await self._aretrieve(query)
+        return await self._aretrieve(query, filters)

-    def retrieve_with_scores(self, query: str, score_threshold: float) -> List[Chunk]:
+    def retrieve_with_scores(
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
-            query (str): query text
-            score_threshold (float): score threshold
+            query (str): query text.
+            score_threshold (float): score threshold.
+            filters: (Optional[MetadataFilters]) metadata filters.

         Returns:
             List[Chunk]: list of chunks
         """
-        return self._retrieve_with_score(query, score_threshold)
+        return self._retrieve_with_score(query, score_threshold, filters)

     async def aretrieve_with_scores(
-        self, query: str, score_threshold: float
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
     ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
             query (str): query text
             score_threshold (float): score threshold
+            filters: (Optional[MetadataFilters]) metadata filters.

         Returns:
             List[Chunk]: list of chunks
         """
-        return await self._aretrieve_with_score(query, score_threshold)
+        return await self._aretrieve_with_score(query, score_threshold, filters)

     @abstractmethod
-    def _retrieve(self, query: str) -> List[Chunk]:
+    def _retrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
             query (str): query text
+            filters: (Optional[MetadataFilters]) metadata filters.

         Returns:
             List[Chunk]: list of chunks
         """

     @abstractmethod
-    async def _aretrieve(self, query: str) -> List[Chunk]:
+    async def _aretrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Async Retrieve knowledge chunks.

         Args:
             query (str): query text
+            filters: (Optional[MetadataFilters]) metadata filters.

         Returns:
             List[Chunk]: list of chunks
         """

     @abstractmethod
-    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
+    def _retrieve_with_score(
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
             query (str): query text
             score_threshold (float): score threshold
+            filters: (Optional[MetadataFilters]) metadata filters.

         Returns:
             List[Chunk]: list of chunks
@@ -107,13 +137,17 @@ class BaseRetriever(ABC):

     @abstractmethod
     async def _aretrieve_with_score(
-        self, query: str, score_threshold: float
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
     ) -> List[Chunk]:
         """Async Retrieve knowledge chunks with score.

         Args:
             query (str): query text
             score_threshold (float): score threshold
+            filters: (Optional[MetadataFilters]) metadata filters.

         Returns:
             List[Chunk]: list of chunks
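Every public entry point now forwards the optional MetadataFilters down to the matching private hook, so a concrete retriever only has to honor the extra argument. A minimal sketch of a custom retriever that respects the new parameter; the in-memory matching logic (equality only, AND semantics) is illustrative and not part of the commit:

    from typing import List, Optional

    from dbgpt.core import Chunk
    from dbgpt.rag.retriever.base import BaseRetriever
    from dbgpt.storage.vector_store.filters import MetadataFilters


    class StaticRetriever(BaseRetriever):
        """Toy retriever over an in-memory chunk list (illustrative only)."""

        def __init__(self, chunks: List[Chunk]):
            self._chunks = chunks

        def _match(self, chunk: Chunk, filters: Optional[MetadataFilters]) -> bool:
            # Simplified: equality checks only, AND semantics.
            if not filters:
                return True
            return all(chunk.metadata.get(f.key) == f.value for f in filters.filters)

        def _retrieve(self, query, filters=None):
            return [c for c in self._chunks if self._match(c, filters)]

        async def _aretrieve(self, query, filters=None):
            return self._retrieve(query, filters)

        def _retrieve_with_score(self, query, score_threshold, filters=None):
            return self._retrieve(query, filters)

        async def _aretrieve_with_score(self, query, score_threshold, filters=None):
            return self._retrieve(query, filters)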

View File

@@ -8,6 +8,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
 from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
 from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.filters import MetadataFilters
 from dbgpt.util.chat_util import run_async_tasks
@@ -93,11 +94,14 @@ class DBSchemaRetriever(BaseRetriever):
             self._need_embeddings = True
         self._rerank = rerank or DefaultRanker(self._top_k)

-    def _retrieve(self, query: str) -> List[Chunk]:
+    def _retrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
             query (str): query text
+            filters: metadata filters.

         Returns:
             List[Chunk]: list of chunks
@@ -105,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever):
         if self._need_embeddings:
             queries = [query]
             candidates = [
-                self._vector_store_connector.similar_search(query, self._top_k)
+                self._vector_store_connector.similar_search(query, self._top_k, filters)
                 for query in queries
             ]
             return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
@@ -115,30 +119,39 @@ class DBSchemaRetriever(BaseRetriever):
             table_summaries = _parse_db_summary(self._connector)
             return [Chunk(content=table_summary) for table_summary in table_summaries]

-    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
+    def _retrieve_with_score(
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
             query (str): query text
             score_threshold (float): score threshold
+            filters: metadata filters.

         Returns:
             List[Chunk]: list of chunks
         """
-        return self._retrieve(query)
+        return self._retrieve(query, filters)

-    async def _aretrieve(self, query: str) -> List[Chunk]:
+    async def _aretrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
             query (str): query text
+            filters: metadata filters.

         Returns:
             List[Chunk]: list of chunks
         """
         if self._need_embeddings:
             queries = [query]
-            candidates = [self._similarity_search(query) for query in queries]
+            candidates = [self._similarity_search(query, filters) for query in queries]
             result_candidates = await run_async_tasks(
                 tasks=candidates, concurrency_limit=1
             )
@@ -154,22 +167,25 @@ class DBSchemaRetriever(BaseRetriever):
             return [Chunk(content=table_summary) for table_summary in table_summaries]

     async def _aretrieve_with_score(
-        self, query: str, score_threshold: float
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
     ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
             query (str): query text
             score_threshold (float): score threshold
+            filters: metadata filters.
         """
-        return await self._aretrieve(query)
+        return await self._aretrieve(query, filters)

-    async def _similarity_search(self, query) -> List[Chunk]:
+    async def _similarity_search(
+        self, query, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Similar search."""
-        return self._vector_store_connector.similar_search(
-            query,
-            self._top_k,
-        )
+        return self._vector_store_connector.similar_search(query, self._top_k, filters)

     async def _aparse_db_summary(self) -> List[str]:
         """Similar search."""

View File

@@ -7,6 +7,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
 from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
 from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.filters import MetadataFilters
 from dbgpt.util.chat_util import run_async_tasks
 from dbgpt.util.tracer import root_tracer
@@ -65,37 +66,44 @@ class EmbeddingRetriever(BaseRetriever):
         self._vector_store_connector = vector_store_connector
         self._rerank = rerank or DefaultRanker(self._top_k)

-    def _retrieve(self, query: str) -> List[Chunk]:
+    def _retrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
             query (str): query text
+            filters: metadata filters.

         Return:
             List[Chunk]: list of chunks
         """
         queries = [query]
         candidates = [
-            self._vector_store_connector.similar_search(query, self._top_k)
+            self._vector_store_connector.similar_search(query, self._top_k, filters)
             for query in queries
         ]
         res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
         return res_candidates

-    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
+    def _retrieve_with_score(
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
             query (str): query text
             score_threshold (float): score threshold
+            filters: metadata filters.

         Return:
             List[Chunk]: list of chunks with score
         """
         queries = [query]
         candidates_with_score = [
             self._vector_store_connector.similar_search_with_scores(
-                query, self._top_k, score_threshold
+                query, self._top_k, score_threshold, filters
             )
             for query in queries
         ]
@@ -105,37 +113,44 @@ class EmbeddingRetriever(BaseRetriever):
         new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
         return new_candidates_with_score

-    async def _aretrieve(self, query: str) -> List[Chunk]:
+    async def _aretrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
-            query (str): query text
+            query (str): query text.
+            filters: metadata filters.

         Return:
             List[Chunk]: list of chunks
         """
         queries = [query]
         if self._query_rewrite:
-            candidates_tasks = [self._similarity_search(query) for query in queries]
+            candidates_tasks = [
+                self._similarity_search(query, filters) for query in queries
+            ]
             chunks = await self._run_async_tasks(candidates_tasks)
             context = "\n".join([chunk.content for chunk in chunks])
             new_queries = await self._query_rewrite.rewrite(
                 origin_query=query, context=context, nums=1
             )
             queries.extend(new_queries)
-        candidates = [self._similarity_search(query) for query in queries]
+        candidates = [self._similarity_search(query, filters) for query in queries]
         new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
         return new_candidates

     async def _aretrieve_with_score(
-        self, query: str, score_threshold: float
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
             query (str): query text
             score_threshold (float): score threshold
+            filters: metadata filters.

         Return:
             List[Chunk]: list of chunks with score
         """
@@ -145,7 +160,9 @@ class EmbeddingRetriever(BaseRetriever):
             "EmbeddingRetriever.query_rewrite.similarity_search",
             metadata={"query": query, "score_threshold": score_threshold},
         ):
-            candidates_tasks = [self._similarity_search(query) for query in queries]
+            candidates_tasks = [
+                self._similarity_search(query, filters) for query in queries
+            ]
             chunks = await self._run_async_tasks(candidates_tasks)
             context = "\n".join([chunk.content for chunk in chunks])
         with root_tracer.start_span(
@@ -162,7 +179,7 @@ class EmbeddingRetriever(BaseRetriever):
             metadata={"query": query, "score_threshold": score_threshold},
         ):
             candidates_with_score = [
-                self._similarity_search_with_score(query, score_threshold)
+                self._similarity_search_with_score(query, score_threshold, filters)
                 for query in queries
             ]
             res_candidates_with_score = await run_async_tasks(
@@ -183,12 +200,11 @@ class EmbeddingRetriever(BaseRetriever):
         new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
         return new_candidates_with_score

-    async def _similarity_search(self, query) -> List[Chunk]:
+    async def _similarity_search(
+        self, query, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Similar search."""
-        return self._vector_store_connector.similar_search(
-            query,
-            self._top_k,
-        )
+        return self._vector_store_connector.similar_search(query, self._top_k, filters)

     async def _run_async_tasks(self, tasks) -> List[Chunk]:
         """Run async tasks."""
@@ -197,9 +213,9 @@ class EmbeddingRetriever(BaseRetriever):
         return cast(List[Chunk], candidates)

     async def _similarity_search_with_score(
-        self, query, score_threshold
+        self, query, score_threshold, filters: Optional[MetadataFilters] = None
     ) -> List[Chunk]:
         """Similar search with score."""
         return self._vector_store_connector.similar_search_with_scores(
-            query, self._top_k, score_threshold
+            query, self._top_k, score_threshold, filters
         )
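End to end, the filters flow from the retriever through the VectorStoreConnector down to the concrete store. A rough usage sketch, assuming an already-configured VectorStoreConnector instance; the query and metadata values are illustrative:

    from dbgpt.rag.retriever.embedding import EmbeddingRetriever
    from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters

    # vector_store_connector is assumed to be an existing VectorStoreConnector.
    retriever = EmbeddingRetriever(top_k=3, vector_store_connector=vector_store_connector)

    # Only consider chunks whose metadata was stamped with department == "sales".
    filters = MetadataFilters(filters=[MetadataFilter(key="department", value="sales")])
    chunks = retriever.retrieve("monthly revenue", filters)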

View File

@@ -36,5 +36,4 @@ def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever
     retrieved_chunks = embedding_retriever._retrieve(query)

-    mock_vector_store_connector.similar_search.assert_called_once_with(query, top_k)
     assert len(retrieved_chunks) == top_k

View File

@@ -457,12 +457,12 @@ class MarkdownHeaderTextSplitter(TextSplitter):
         # Output line-by-line or aggregated into chunks w/ common headers
         if headers_to_split_on is None:
             headers_to_split_on = [
-                ("#", "Header 1"),
-                ("##", "Header 2"),
-                ("###", "Header 3"),
-                ("####", "Header 4"),
-                ("#####", "Header 5"),
-                ("######", "Header 6"),
+                ("#", "Header1"),
+                ("##", "Header2"),
+                ("###", "Header3"),
+                ("####", "Header4"),
+                ("#####", "Header5"),
+                ("######", "Header6"),
             ]
         if filters is None:
             filters = []

View File

@@ -7,6 +7,7 @@ from dbgpt.core import Chunk
 from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
 from dbgpt.rag.retriever.base import BaseRetriever
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.filters import MetadataFilters
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async

 CFG = Config()
@@ -46,52 +47,81 @@ class KnowledgeSpaceRetriever(BaseRetriever):
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()

-    def _retrieve(self, query: str) -> List[Chunk]:
+    def _retrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
-            query (str): query text
+            query (str): query text.
+            filters: (Optional[MetadataFilters]) metadata filters.

         Return:
             List[Chunk]: list of chunks
         """
         candidates = self._vector_store_connector.similar_search(
-            doc=query, topk=self._top_k
+            doc=query, topk=self._top_k, filters=filters
         )
         return candidates

-    def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
+    def _retrieve_with_score(
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
             query (str): query text
             score_threshold (float): score threshold
+            filters: (Optional[MetadataFilters]) metadata filters.

         Return:
             List[Chunk]: list of chunks with score
         """
         candidates_with_score = self._vector_store_connector.similar_search_with_scores(
-            doc=query, topk=self._top_k, score_threshold=score_threshold
+            doc=query,
+            topk=self._top_k,
+            score_threshold=score_threshold,
+            filters=filters,
         )
         return candidates_with_score

-    async def _aretrieve(self, query: str) -> List[Chunk]:
+    async def _aretrieve(
+        self, query: str, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Retrieve knowledge chunks.

         Args:
-            query (str): query text
+            query (str): query text.
+            filters: (Optional[MetadataFilters]) metadata filters.

         Return:
             List[Chunk]: list of chunks
         """
-        candidates = await blocking_func_to_async(self._executor, self._retrieve, query)
+        candidates = await blocking_func_to_async(
+            self._executor, self._retrieve, query, filters
+        )
         return candidates

     async def _aretrieve_with_score(
-        self, query: str, score_threshold: float
+        self,
+        query: str,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
     ) -> List[Chunk]:
         """Retrieve knowledge chunks with score.

         Args:
-            query (str): query text
-            score_threshold (float): score threshold
+            query (str): query text.
+            score_threshold (float): score threshold.
+            filters: (Optional[MetadataFilters]) metadata filters.

         Return:
-            List[Chunk]: list of chunks with score
+            List[Chunk]: list of chunks with score.
         """
         candidates_with_score = await blocking_func_to_async(
-            self._executor, self._retrieve_with_score, query, score_threshold
+            self._executor, self._retrieve_with_score, query, score_threshold, filters
         )
         return candidates_with_score

View File

@@ -4,11 +4,12 @@ import math
 import time
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional
+from typing import Any, List, Optional

 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core import Chunk, Embeddings
 from dbgpt.core.awel.flow import Parameter
+from dbgpt.storage.vector_store.filters import MetadataFilters
 from dbgpt.util.i18n_utils import _

 logger = logging.getLogger(__name__)
@@ -176,13 +177,15 @@ class VectorStoreBase(ABC):
         return ids

     @abstractmethod
-    def similar_search(self, text: str, topk: int) -> List[Chunk]:
+    def similar_search(
+        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Similar search in vector database.

         Args:
             text(str): The query text.
             topk(int): The number of similar documents to return.
+            filters(Optional[MetadataFilters]): metadata filters.

         Return:
             List[Chunk]: The similar documents.
         """
@@ -190,7 +193,11 @@
     @abstractmethod
     def similar_search_with_scores(
-        self, text, topk, score_threshold: float
+        self,
+        text,
+        topk,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
     ) -> List[Chunk]:
         """Similar search with scores in vector database.
@@ -199,6 +206,7 @@
             topk(int): The number of similar documents to return.
             score_threshold(int): score_threshold: Optional, a floating point value
                 between 0 to 1
+            filters(Optional[MetadataFilters]): metadata filters.

         Return:
             List[Chunk]: The similar documents.
         """
@@ -223,6 +231,15 @@
         Args:
             vector_name(str): The name of vector to delete.
         """
+        pass
+
+    def convert_metadata_filters(self, filters: MetadataFilters) -> Any:
+        """Convert metadata filters to vector store filters.
+
+        Args:
+            filters: (Optional[MetadataFilters]) metadata filters.
+        """
+        raise NotImplementedError

     def _normalization_vectors(self, vectors):
         """Return L2-normalization vectors to scale[0,1].

View File

@@ -1,7 +1,7 @@
 """Chroma vector store."""
 import logging
 import os
-from typing import Any, List
+from typing import List, Optional

 from chromadb import PersistentClient
 from chromadb.config import Settings
@@ -13,6 +13,7 @@ from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
 from dbgpt.util.i18n_utils import _

 from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
+from .filters import FilterOperator, MetadataFilters

 logger = logging.getLogger(__name__)
@@ -86,16 +87,23 @@ class ChromaStore(VectorStoreBase):
             collection_metadata=collection_metadata,
         )

-    def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]:
+    def similar_search(
+        self, text, topk, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Search similar documents."""
         logger.info("ChromaStore similar search")
-        lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs)
+        where_filters = self.convert_metadata_filters(filters) if filters else None
+        lc_documents = self.vector_store_client.similarity_search(
+            text, topk, filter=where_filters
+        )
         return [
             Chunk(content=doc.page_content, metadata=doc.metadata)
             for doc in lc_documents
         ]

-    def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
+    def similar_search_with_scores(
+        self, text, topk, score_threshold, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Search similar documents with scores.

         Chroma similar_search_with_score.
@@ -106,11 +114,16 @@ class ChromaStore(VectorStoreBase):
             score_threshold(float): score_threshold: Optional, a floating point value
                 between 0 to 1 to filter the resulting set of retrieved docs,0 is
                 dissimilar, 1 is most similar.
+            filters(MetadataFilters): metadata filters, defaults to None
         """
         logger.info("ChromaStore similar search with scores")
+        where_filters = self.convert_metadata_filters(filters) if filters else None
         docs_and_scores = (
             self.vector_store_client.similarity_search_with_relevance_scores(
-                query=text, k=topk, score_threshold=score_threshold
+                query=text,
+                k=topk,
+                score_threshold=score_threshold,
+                filter=where_filters,
             )
         )
         return [
@@ -152,10 +165,71 @@ class ChromaStore(VectorStoreBase):
         collection = self.vector_store_client._collection
         collection.delete(ids=ids)

+    def convert_metadata_filters(
+        self,
+        filters: MetadataFilters,
+    ) -> dict:
+        """Convert metadata filters to Chroma filters.
+
+        Args:
+            filters(MetadataFilters): metadata filters.
+
+        Returns:
+            dict: Chroma filters.
+        """
+        where_filters = {}
+        filters_list = []
+        condition = filters.condition
+        chroma_condition = f"${condition}"
+        if filters.filters:
+            for filter in filters.filters:
+                if filter.operator:
+                    filters_list.append(
+                        {
+                            filter.key: {
+                                _convert_chroma_filter_operator(
+                                    filter.operator
+                                ): filter.value
+                            }
+                        }
+                    )
+                else:
+                    filters_list.append({filter.key: filter.value})  # type: ignore
+
+        if len(filters_list) == 1:
+            return filters_list[0]
+        elif len(filters_list) > 1:
+            where_filters[chroma_condition] = filters_list
+        return where_filters
+
     def _clean_persist_folder(self):
+        """Clean persist folder."""
         for root, dirs, files in os.walk(self.persist_dir, topdown=False):
             for name in files:
                 os.remove(os.path.join(root, name))
             for name in dirs:
                 os.rmdir(os.path.join(root, name))
         os.rmdir(self.persist_dir)
+
+
+def _convert_chroma_filter_operator(operator: str) -> str:
+    """Convert operator to Chroma where operator.
+
+    Args:
+        operator(str): operator.
+
+    Returns:
+        str: Chroma where operator.
+    """
+    if operator == FilterOperator.EQ:
+        return "$eq"
+    elif operator == FilterOperator.NE:
+        return "$ne"
+    elif operator == FilterOperator.GT:
+        return "$gt"
+    elif operator == FilterOperator.LT:
+        return "$lt"
+    elif operator == FilterOperator.GTE:
+        return "$gte"
+    elif operator == FilterOperator.LTE:
+        return "$lte"
+    else:
+        raise ValueError(f"Chroma Where operator {operator} not supported")
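The conversion produces standard Chroma `where` syntax: a single filter collapses to one clause, while multiple filters are grouped under the `$and`/`$or` condition. A small sketch of the expected mapping; the keys and values are illustrative and the converted dict is only an approximation of what this commit's code emits:

    from dbgpt.storage.vector_store.filters import (
        FilterCondition,
        FilterOperator,
        MetadataFilter,
        MetadataFilters,
    )

    filters = MetadataFilters(
        condition=FilterCondition.AND,
        filters=[
            MetadataFilter(key="department", value="sales"),
            MetadataFilter(key="year", operator=FilterOperator.GTE, value=2023),
        ],
    )
    # chroma_store.convert_metadata_filters(filters) should produce roughly:
    # {"$and": [{"department": {"$eq": "sales"}}, {"year": {"$gte": 2023}}]}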

View File

@@ -13,6 +13,7 @@ from dbgpt.core.awel.flow import (
 )
 from dbgpt.storage import vector_store
 from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
+from dbgpt.storage.vector_store.filters import MetadataFilters
 from dbgpt.util.i18n_utils import _

 connector: Dict[str, Type] = {}
@@ -128,23 +129,29 @@ class VectorStoreConnector:
                 max_threads,
             )

-    def similar_search(self, doc: str, topk: int) -> List[Chunk]:
+    def similar_search(
+        self, doc: str, topk: int, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
         """Similar search in vector database.

         Args:
            - doc: query text
            - topk: topk
+           - filters: metadata filters.

         Return:
            - chunks: chunks.
         """
-        return self.client.similar_search(doc, topk)
+        return self.client.similar_search(doc, topk, filters)

     def similar_search_with_scores(
-        self, doc: str, topk: int, score_threshold: float
+        self,
+        doc: str,
+        topk: int,
+        score_threshold: float,
+        filters: Optional[MetadataFilters] = None,
     ) -> List[Chunk]:
-        """Similar search with scores in vector database.
+        """Similar_search_with_score in vector database.

-        similar_search_with_score in vector database..
         Return docs and relevance scores in the range [0, 1].

         Args:
@@ -153,10 +160,13 @@ class VectorStoreConnector:
             score_threshold(float): score_threshold: Optional, a floating point value
                 between 0 to 1 to filter the resulting set of retrieved docs,0 is
                 dissimilar, 1 is most similar.
+            filters: metadata filters.

         Return:
-            - chunks: chunks.
+            - chunks: Return docs and relevance scores in the range [0, 1].
         """
-        return self.client.similar_search_with_scores(doc, topk, score_threshold)
+        return self.client.similar_search_with_scores(
+            doc, topk, score_threshold, filters
+        )

     @property
     def vector_store_config(self) -> VectorStoreConfig:

View File

@ -0,0 +1,56 @@
"""Vector Store Meta data filters."""
from enum import Enum
from typing import List, Union
from pydantic import BaseModel, Field
class FilterOperator(str, Enum):
"""Meta data filter operator."""
EQ = "=="
GT = ">"
LT = "<"
NE = "!="
GTE = ">="
LTE = "<="
IN = "in"
NIN = "nin"
EXISTS = "exists"
class FilterCondition(str, Enum):
"""Vector Store Meta data filter conditions."""
AND = "and"
OR = "or"
class MetadataFilter(BaseModel):
"""Meta data filter."""
key: str = Field(
...,
description="The key of metadata to filter.",
)
operator: FilterOperator = Field(
default=FilterOperator.EQ,
description="The operator of metadata filter.",
)
value: Union[str, int, float, List[str], List[int], List[float]] = Field(
...,
description="The value of metadata to filter.",
)
class MetadataFilters(BaseModel):
"""Meta data filters."""
condition: FilterCondition = Field(
default=FilterCondition.AND,
description="The condition of metadata filters.",
)
filters: List[MetadataFilter] = Field(
...,
description="The metadata filters.",
)
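Composed directly, the models above support multi-condition filters; a small illustrative example (keys and values are made up):
from dbgpt.storage.vector_store.filters import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

filters = MetadataFilters(
    condition=FilterCondition.OR,
    filters=[
        MetadataFilter(key="category", value="tutorial"),  # operator defaults to EQ
        MetadataFilter(key="tag", operator=FilterOperator.IN, value=["rag", "awel"]),
    ],
)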

View File

@ -14,6 +14,7 @@ from dbgpt.storage.vector_store.base import (
VectorStoreBase, VectorStoreBase,
VectorStoreConfig, VectorStoreConfig,
) )
from dbgpt.storage.vector_store.filters import FilterOperator, MetadataFilters
from dbgpt.util import string_utils from dbgpt.util import string_utils
from dbgpt.util.i18n_utils import _ from dbgpt.util.i18n_utils import _
@ -206,6 +207,7 @@ class MilvusStore(VectorStoreBase):
self.vector_field = milvus_vector_config.get("embedding_field") or "vector" self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
self.text_field = milvus_vector_config.get("text_field") or "content" self.text_field = milvus_vector_config.get("text_field") or "content"
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata" self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
self.props_field = milvus_vector_config.get("props_field") or "props_field"
if (self.username is None) != (self.password is None): if (self.username is None) != (self.password is None):
raise ValueError( raise ValueError(
@ -284,6 +286,7 @@ class MilvusStore(VectorStoreBase):
vector_field = self.vector_field vector_field = self.vector_field
text_field = self.text_field text_field = self.text_field
metadata_field = self.metadata_field metadata_field = self.metadata_field
props_field = self.props_field
# self.text_field = text_field # self.text_field = text_field
collection_name = vector_name collection_name = vector_name
fields = [] fields = []
@ -300,6 +303,7 @@ class MilvusStore(VectorStoreBase):
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535)) fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
fields.append(FieldSchema(props_field, DataType.JSON))
schema = CollectionSchema(fields) schema = CollectionSchema(fields)
# Create the collection # Create the collection
collection = Collection(collection_name, schema) collection = Collection(collection_name, schema)
@ -346,6 +350,7 @@ class MilvusStore(VectorStoreBase):
for d in metadatas: for d in metadatas:
# for key, value in d.items(): # for key, value in d.items():
insert_dict.setdefault("metadata", []).append(json.dumps(d)) insert_dict.setdefault("metadata", []).append(json.dumps(d))
insert_dict.setdefault("props_field", []).append(d)
# Convert dict to list of lists for insertion # Convert dict to list of lists for insertion
insert_list = [insert_dict[x] for x in self.fields] insert_list = [insert_dict[x] for x in self.fields]
# Insert into the collection. # Insert into the collection.
@ -368,7 +373,9 @@ class MilvusStore(VectorStoreBase):
doc_ids = [str(doc_id) for doc_id in doc_ids] doc_ids = [str(doc_id) for doc_id in doc_ids]
return doc_ids return doc_ids
def similar_search(self, text, topk) -> List[Chunk]: def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Perform a search on a query string and return results.""" """Perform a search on a query string and return results."""
from pymilvus import Collection, DataType from pymilvus import Collection, DataType
@ -383,7 +390,9 @@ class MilvusStore(VectorStoreBase):
self.primary_field = x.name self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name self.vector_field = x.name
_, docs_and_scores = self._search(text, topk) # convert to milvus expr filter.
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
_, docs_and_scores = self._search(text, topk, expr=milvus_filter_expr)
return [ return [
Chunk( Chunk(
@ -393,7 +402,13 @@ class MilvusStore(VectorStoreBase):
for doc, _, _ in docs_and_scores for doc, _, _ in docs_and_scores
] ]
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]: def similar_search_with_scores(
self,
text: str,
topk: int,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Perform a search on a query string and return results with score. """Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus For more information about the search parameters, take a look at the pymilvus
@ -401,15 +416,10 @@ class MilvusStore(VectorStoreBase):
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args: Args:
embedding (List[float]): The embedding vector being searched. text (str): The query text.
k (int, optional): The amount of results to return. Defaults to 4. topk (int): The number of similar documents to return.
param (dict): The search params for the specified index. score_threshold (float): Optional, a floating point value between 0 to 1.
Defaults to None. filters (Optional[MetadataFilters]): Optional, metadata filters.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns: Returns:
List[Tuple[Document, float]]: Result doc and score. List[Tuple[Document, float]]: Result doc and score.
""" """
@ -427,7 +437,11 @@ class MilvusStore(VectorStoreBase):
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name self.vector_field = x.name
_, docs_and_scores = self._search(text, topk) # convert to milvus expr filter.
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
_, docs_and_scores = self._search(
query=text, topk=topk, expr=milvus_filter_expr
)
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores): if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
logger.warning( logger.warning(
"similarity score need between" f" 0 and 1, got {docs_and_scores}" "similarity score need between" f" 0 and 1, got {docs_and_scores}"
@ -462,6 +476,20 @@ class MilvusStore(VectorStoreBase):
timeout: Optional[int] = None, timeout: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
): ):
"""Search in vector database.
Args:
query: query text.
k: topk.
param: search params.
expr: search expr.
partition_names: partition names.
round_decimal: round decimal.
timeout: timeout.
**kwargs: kwargs.
Returns:
Tuple[Document, float, int]: Result doc and score.
"""
self.col.load() self.col.load()
# use default index params. # use default index params.
if param is None: if param is None:
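In outline, the expr built from the metadata filters is simply forwarded to pymilvus Collection.search; a simplified sketch of the call made inside _search (not the verbatim implementation, and the query-embedding step is abbreviated):
# Simplified sketch, not the exact code in this file:
res = self.col.search(
    data=[query_embedding],          # the embedded query text
    anns_field=self.vector_field,
    param=param,                     # index-specific search params
    limit=k,
    expr=expr,                       # e.g. "props_field['Header2'] == 'AWEL Design'"
    partition_names=partition_names,
    round_decimal=round_decimal,
    timeout=timeout,
    **kwargs,
)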
@ -495,7 +523,9 @@ class MilvusStore(VectorStoreBase):
result.id, result.id,
) )
) )
if len(ret) == 0:
logger.warning("No relevant docs were retrieved.")
return None, []
return ret[0], ret return ret[0], ret
def vector_name_exists(self): def vector_name_exists(self):
@ -523,6 +553,40 @@ class MilvusStore(VectorStoreBase):
logger.info(f"begin delete milvus ids: {ids}") logger.info(f"begin delete milvus ids: {ids}")
delete_ids = ids.split(",") delete_ids = ids.split(",")
doc_ids = [int(doc_id) for doc_id in delete_ids] doc_ids = [int(doc_id) for doc_id in delete_ids]
delet_expr = f"{self.primary_field} in {doc_ids}" delete_expr = f"{self.primary_field} in {doc_ids}"
self.col.delete(delet_expr) self.col.delete(delete_expr)
return True return True
def convert_metadata_filters(self, filters: MetadataFilters) -> str:
"""Convert filter to milvus filters.
Args:
- filters: metadata filters.
Returns:
- metadata_filter_expr: Milvus filter expression string.
"""
metadata_filters = []
for metadata_filter in filters.filters:
if isinstance(metadata_filter.value, str):
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.EQ} '{metadata_filter.value}'"
)
metadata_filters.append(expr)
elif isinstance(metadata_filter.value, List):
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.IN} {metadata_filter.value}"
)
metadata_filters.append(expr)
else:
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.EQ} {str(metadata_filter.value)}"
)
metadata_filters.append(expr)
if len(metadata_filters) > 1:
metadata_filter_expr = f" {filters.condition} ".join(metadata_filters)
else:
metadata_filter_expr = metadata_filters[0]
return metadata_filter_expr
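Roughly, then, a multi-filter input translates into a single Milvus boolean expression. Illustrative input and expected output (the exact rendering of the operator and condition enums depends on how the str-based enums format):
# filters = MetadataFilters(filters=[
#     MetadataFilter(key="Header2", value="AWEL Design"),
#     MetadataFilter(key="tag", operator=FilterOperator.IN, value=["rag", "awel"]),
# ])
# store.convert_metadata_filters(filters) is expected to yield something like:
#   "props_field['Header2'] == 'AWEL Design' and props_field['tag'] in ['rag', 'awel']"
# which is then passed as `expr` to _search() / Collection.search().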

View File

@ -1,6 +1,6 @@
"""Postgres vector store.""" """Postgres vector store."""
import logging import logging
from typing import Any, List from typing import List, Optional
from dbgpt._private.pydantic import Field from dbgpt._private.pydantic import Field
from dbgpt.core import Chunk from dbgpt.core import Chunk
@ -10,6 +10,7 @@ from dbgpt.storage.vector_store.base import (
VectorStoreBase, VectorStoreBase,
VectorStoreConfig, VectorStoreConfig,
) )
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.i18n_utils import _ from dbgpt.util.i18n_utils import _
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -70,9 +71,11 @@ class PGVectorStore(VectorStoreBase):
connection_string=self.connection_string, connection_string=self.connection_string,
) )
def similar_search(self, text: str, topk: int, **kwargs: Any) -> List[Chunk]: def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Perform similar search in PGVector.""" """Perform similar search in PGVector."""
return self.vector_store_client.similarity_search(text, topk) return self.vector_store_client.similarity_search(text, topk, filters)
def vector_name_exists(self) -> bool: def vector_name_exists(self) -> bool:
"""Check if vector name exists.""" """Check if vector name exists."""

View File

@ -1,7 +1,7 @@
"""Weaviate vector store.""" """Weaviate vector store."""
import logging import logging
import os import os
from typing import List from typing import List, Optional
from dbgpt._private.pydantic import Field from dbgpt._private.pydantic import Field
from dbgpt.core import Chunk from dbgpt.core import Chunk
@ -9,6 +9,7 @@ from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _ from dbgpt.util.i18n_utils import _
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
from .filters import MetadataFilters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -80,7 +81,9 @@ class WeaviateStore(VectorStoreBase):
self.vector_store_client = weaviate.Client(self.weaviate_url) self.vector_store_client = weaviate.Client(self.weaviate_url)
def similar_search(self, text: str, topk: int) -> List[Chunk]: def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Perform similar search in Weaviate.""" """Perform similar search in Weaviate."""
logger.info("Weaviate similar search") logger.info("Weaviate similar search")
# nearText = { # nearText = {

View File

@ -0,0 +1,60 @@
"""Rag Metadata Properties filter example.
pre-requirements:
make sure you have set your embedding model path in your example code.
Examples:
..code-block:: shell
python examples/rag/metadata_filter_example.py
"""
import asyncio
import os
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
def _create_vector_connector():
"""Create vector connector."""
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="example_metadata_filter_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
)
async def main():
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
knowledge = KnowledgeFactory.from_file_path(file_path)
vector_connector = _create_vector_connector()
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")
# get embedding assembler
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
vector_store_connector=vector_connector,
)
assembler.persist()
# get embeddings retriever
retriever = assembler.as_retriever(3)
# create metadata filter
metadata_filter = MetadataFilter(key="Header2", value="AWEL Design")
filters = MetadataFilters(filters=[metadata_filter])
chunks = await retriever.aretrieve_with_scores(
"what is awel talk about", 0.0, filters
)
print(f"embedding rag example results:{chunks}")
if __name__ == "__main__":
asyncio.run(main())
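With the Header2 filter in place, the retriever should only return markdown chunks whose Header2 metadata equals "AWEL Design"; omitting the filters argument falls back to plain top-k similarity search over the whole collection.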