feat(RAG): add metadata properties filters (#1395)

Aries-ckt 2024-04-10 14:33:24 +08:00 committed by GitHub
parent 0f2b46da62
commit 37e7c0151b
26 changed files with 619 additions and 166 deletions
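
At a glance, the change threads an optional metadata dict through every Knowledge subclass and a MetadataFilters argument through the retrievers and vector stores. A minimal sketch of the two new entry points, based on the diffs below (the file path and the "category" property are placeholders; the retriever line assumes any concrete BaseRetriever):

from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters

# Attach custom properties to every document produced by this knowledge source.
knowledge = KnowledgeFactory.from_file_path(
    file_path="path/to/document.md",  # placeholder path
    metadata={"category": "user-manual"},
)

# Later, restrict retrieval to chunks that carry the property.
filters = MetadataFilters(
    filters=[MetadataFilter(key="category", value="user-manual")]
)
# chunks = retriever.retrieve("query text", filters)  # any BaseRetriever in this commit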

View File

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from dbgpt.core import Document
from dbgpt.rag.text_splitter.text_splitter import (
@ -147,16 +147,18 @@ class Knowledge(ABC):
self,
path: Optional[str] = None,
knowledge_type: Optional[KnowledgeType] = None,
data_loader: Optional[Any] = None,
loader: Optional[Any] = None,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Initialize with Knowledge arguments."""
self._path = path
self._type = knowledge_type
self._data_loader = data_loader
self._loader = loader
self._metadata = metadata
def load(self) -> List[Document]:
"""Load knowledge from data_loader."""
"""Load knowledge from data loader."""
documents = self._load()
return self._postprocess(documents)
@ -171,12 +173,12 @@ class Knowledge(ABC):
return None
def _postprocess(self, docs: List[Document]) -> List[Document]:
"""Post process knowledge from data_loader."""
"""Post process knowledge from data loader."""
return docs
@abstractmethod
def _load(self) -> List[Document]:
"""Preprocess knowledge from data_loader."""
"""Preprocess knowledge from data loader."""
@classmethod
def support_chunk_strategy(cls) -> List[ChunkStrategy]:

View File

@ -1,6 +1,6 @@
"""CSV Knowledge."""
import csv
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import Document
from dbgpt.rag.knowledge.base import (
@ -21,6 +21,7 @@ class CSVKnowledge(Knowledge):
source_column: Optional[str] = None,
encoding: Optional[str] = "utf-8",
loader: Optional[Any] = None,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create CSV Knowledge with Knowledge arguments.
@ -32,9 +33,13 @@ class CSVKnowledge(Knowledge):
encoding(str, optional): csv encoding
loader(Any, optional): loader
"""
self._path = file_path
self._type = knowledge_type
self._loader = loader
super().__init__(
path=file_path,
knowledge_type=knowledge_type,
data_loader=loader,
metadata=metadata,
**kwargs,
)
self._encoding = encoding
self._source_column = source_column
@ -67,6 +72,8 @@ class CSVKnowledge(Knowledge):
f"file."
)
metadata = {"source": source, "row": i}
if self._metadata:
metadata.update(self._metadata) # type: ignore
doc = Document(content=content, metadata=metadata)
docs.append(doc)
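
The update call above merges the user-supplied properties into the per-row metadata that CSVKnowledge already emits (keys from self._metadata win on collision); the same pattern is repeated for the datasource, docx, HTML, markdown, PDF, PPTX, string and TXT knowledge classes below. A tiny illustration with hypothetical values:

# Built-in per-row metadata produced for one CSV row, then merged with custom properties.
metadata = {"source": "data/users.csv", "row": 0}
metadata.update({"team": "growth"})  # corresponds to self._metadata passed at construction
# -> {"source": "data/users.csv", "row": 0, "team": "growth"}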

View File

@ -1,5 +1,5 @@
"""Datasource Knowledge."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import Document
from dbgpt.datasource import BaseConnector
@ -16,26 +16,29 @@ class DatasourceKnowledge(Knowledge):
connector: BaseConnector,
summary_template: str = "{table_name}({columns})",
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create Datasource Knowledge with Knowledge arguments.
Args:
path(str, optional): file path
connector(BaseConnector): connector
summary_template(str, optional): summary template
knowledge_type(KnowledgeType, optional): knowledge type
data_loader(Any, optional): loader
metadata(Dict[str, Union[str, List[str]]], optional): metadata
"""
self._connector = connector
self._summary_template = summary_template
super().__init__(knowledge_type=knowledge_type, **kwargs)
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
def _load(self) -> List[Document]:
"""Load datasource document from data_loader."""
docs = []
for table_summary in _parse_db_summary(self._connector, self._summary_template):
docs.append(
Document(content=table_summary, metadata={"source": "database"})
)
metadata = {"source": "database"}
if self._metadata:
metadata.update(self._metadata) # type: ignore
docs.append(Document(content=table_summary, metadata=metadata))
return docs
@classmethod

View File

@ -1,5 +1,5 @@
"""Docx Knowledge."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
import docx
@ -21,6 +21,7 @@ class DocxKnowledge(Knowledge):
knowledge_type: Any = KnowledgeType.DOCUMENT,
encoding: Optional[str] = "utf-8",
loader: Optional[Any] = None,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create Docx Knowledge with Knowledge arguments.
@ -31,9 +32,13 @@ class DocxKnowledge(Knowledge):
encoding(str, optional): csv encoding
loader(Any, optional): loader
"""
self._path = file_path
self._type = knowledge_type
self._loader = loader
super().__init__(
path=file_path,
knowledge_type=knowledge_type,
data_loader=loader,
metadata=metadata,
**kwargs,
)
self._encoding = encoding
def _load(self) -> List[Document]:
@ -48,9 +53,10 @@ class DocxKnowledge(Knowledge):
para = doc.paragraphs[i]
text = para.text
content.append(text)
docs.append(
Document(content="\n".join(content), metadata={"source": self._path})
)
metadata = {"source": self._path}
if self._metadata:
metadata.update(self._metadata) # type: ignore
docs.append(Document(content="\n".join(content), metadata=metadata))
return docs
return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@ -1,5 +1,5 @@
"""Knowledge Factory to create knowledge from file path and url."""
from typing import List, Optional, Type
from typing import Dict, List, Optional, Type, Union
from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
from dbgpt.rag.knowledge.string import StringKnowledge
@ -13,6 +13,7 @@ class KnowledgeFactory:
self,
file_path: Optional[str] = None,
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
):
"""Create Knowledge Factory with file path and knowledge type.
@ -22,18 +23,21 @@ class KnowledgeFactory:
"""
self._file_path = file_path
self._knowledge_type = knowledge_type
self._metadata = metadata
@classmethod
def create(
cls,
datasource: str = "",
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
):
"""Create knowledge from file path, url or text.
Args:
datasource: path of the file to convert
knowledge_type: type of knowledge
metadata: Optional[Dict[str, Union[str, List[str]]]]
Examples:
.. code-block:: python
@ -52,12 +56,16 @@ class KnowledgeFactory:
match knowledge_type:
case KnowledgeType.DOCUMENT:
return cls.from_file_path(
file_path=datasource, knowledge_type=knowledge_type
file_path=datasource,
knowledge_type=knowledge_type,
metadata=metadata,
)
case KnowledgeType.URL:
return cls.from_url(url=datasource, knowledge_type=knowledge_type)
case KnowledgeType.TEXT:
return cls.from_text(text=datasource, knowledge_type=knowledge_type)
return cls.from_text(
text=datasource, knowledge_type=knowledge_type, metadata=metadata
)
case _:
raise Exception(f"Unsupported knowledge type '{knowledge_type}'")
@ -66,6 +74,7 @@ class KnowledgeFactory:
cls,
file_path: str = "",
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
) -> Knowledge:
"""Create knowledge from path.
@ -82,10 +91,11 @@ class KnowledgeFactory:
datasource="path/to/document.pdf",
knowledge_type=KnowledgeType.DOCUMENT,
)
"""
factory = cls(file_path=file_path, knowledge_type=knowledge_type)
return factory._select_document_knowledge(
file_path=file_path, knowledge_type=knowledge_type
file_path=file_path, knowledge_type=knowledge_type, metadata=metadata
)
@staticmethod
@ -117,6 +127,7 @@ class KnowledgeFactory:
def from_text(
text: str = "",
knowledge_type: KnowledgeType = KnowledgeType.TEXT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
) -> Knowledge:
"""Create knowledge from text.
@ -127,6 +138,7 @@ class KnowledgeFactory:
return StringKnowledge(
text=text,
knowledge_type=knowledge_type,
metadata=metadata,
)
def _select_document_knowledge(self, **kwargs):
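
With the factory changes above, callers can pass metadata when creating document or text knowledge (the URL branch is unchanged in this commit). A hedged sketch mirroring the docstring example, with a placeholder path and property:

from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.knowledge.base import KnowledgeType

knowledge = KnowledgeFactory.create(
    datasource="path/to/document.pdf",   # placeholder, as in the docstring above
    knowledge_type=KnowledgeType.DOCUMENT,
    metadata={"department": "support"},  # attached to every loaded Document
)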

View File

@ -1,5 +1,5 @@
"""HTML Knowledge."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
import chardet
@ -20,6 +20,7 @@ class HTMLKnowledge(Knowledge):
file_path: Optional[str] = None,
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
loader: Optional[Any] = None,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create HTML Knowledge with Knowledge arguments.
@ -29,9 +30,13 @@ class HTMLKnowledge(Knowledge):
knowledge_type(KnowledgeType, optional): knowledge type
loader(Any, optional): loader
"""
self._path = file_path
self._type = knowledge_type
self._loader = loader
super().__init__(
path=file_path,
knowledge_type=knowledge_type,
data_loader=loader,
metadata=metadata,
**kwargs,
)
def _load(self) -> List[Document]:
"""Load html document from loader."""
@ -48,6 +53,8 @@ class HTMLKnowledge(Knowledge):
else:
text = raw_text.decode(result["encoding"])
metadata = {"source": self._path}
if self._metadata:
metadata.update(self._metadata) # type: ignore
return [Document(content=text, metadata=metadata)]
return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@ -1,5 +1,5 @@
"""Markdown Knowledge."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import Document
from dbgpt.rag.knowledge.base import (
@ -19,6 +19,7 @@ class MarkdownKnowledge(Knowledge):
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
encoding: Optional[str] = "utf-8",
loader: Optional[Any] = None,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create Markdown Knowledge with Knowledge arguments.
@ -29,9 +30,13 @@ class MarkdownKnowledge(Knowledge):
encoding(str, optional): csv encoding
loader(Any, optional): loader
"""
self._path = file_path
self._type = knowledge_type
self._loader = loader
super().__init__(
path=file_path,
knowledge_type=knowledge_type,
data_loader=loader,
metadata=metadata,
**kwargs,
)
self._encoding = encoding
def _load(self) -> List[Document]:
@ -44,6 +49,8 @@ class MarkdownKnowledge(Knowledge):
with open(self._path, encoding=self._encoding, errors="ignore") as f:
markdown_text = f.read()
metadata = {"source": self._path}
if self._metadata:
metadata.update(self._metadata) # type: ignore
documents = [Document(content=markdown_text, metadata=metadata)]
return documents
return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@ -1,5 +1,5 @@
"""PDF Knowledge."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import Document
from dbgpt.rag.knowledge.base import (
@ -19,6 +19,7 @@ class PDFKnowledge(Knowledge):
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
loader: Optional[Any] = None,
language: Optional[str] = "zh",
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create PDF Knowledge with Knowledge arguments.
@ -29,9 +30,13 @@ class PDFKnowledge(Knowledge):
loader(Any, optional): loader
language(str, optional): language
"""
self._path = file_path
self._type = knowledge_type
self._loader = loader
super().__init__(
path=file_path,
knowledge_type=knowledge_type,
data_loader=loader,
metadata=metadata,
**kwargs,
)
self._language = language
def _load(self) -> List[Document]:
@ -65,6 +70,8 @@ class PDFKnowledge(Knowledge):
page = "\n".join(cleaned_lines)
# cleaned_pages.append(page)
metadata = {"source": self._path, "page": page_num}
if self._metadata:
metadata.update(self._metadata) # type: ignore
# text = "\f".join(cleaned_pages)
document = Document(content=page, metadata=metadata)
documents.append(document)

View File

@ -1,5 +1,5 @@
"""PPTX Knowledge."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import Document
from dbgpt.rag.knowledge.base import (
@ -19,6 +19,7 @@ class PPTXKnowledge(Knowledge):
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
loader: Optional[Any] = None,
language: Optional[str] = "zh",
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create PPTX Knowledge with Knowledge arguments.
@ -28,9 +29,13 @@ class PPTXKnowledge(Knowledge):
knowledge_type:(KnowledgeType) knowledge type
loader:(Optional[Any]) loader
"""
self._path = file_path
self._type = knowledge_type
self._loader = loader
super().__init__(
path=file_path,
knowledge_type=knowledge_type,
data_loader=loader,
metadata=metadata,
**kwargs,
)
self._language = language
def _load(self) -> List[Document]:
@ -47,9 +52,10 @@ class PPTXKnowledge(Knowledge):
for shape in slide.shapes:
if hasattr(shape, "text") and shape.text:
content += shape.text
docs.append(
Document(content=content, metadata={"source": slide.slide_id})
)
metadata = {"source": self._path}
if self._metadata:
metadata.update(self._metadata) # type: ignore
docs.append(Document(content=content, metadata=metadata))
return docs
return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@ -1,5 +1,5 @@
"""String Knowledge."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import Document
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
@ -14,6 +14,7 @@ class StringKnowledge(Knowledge):
knowledge_type: KnowledgeType = KnowledgeType.TEXT,
encoding: Optional[str] = "utf-8",
loader: Optional[Any] = None,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create String knowledge parameters.
@ -24,14 +25,20 @@ class StringKnowledge(Knowledge):
encoding(str): encoding
loader(Any): loader
"""
super().__init__(
knowledge_type=knowledge_type,
data_loader=loader,
metadata=metadata,
**kwargs,
)
self._text = text
self._type = knowledge_type
self._loader = loader
self._encoding = encoding
def _load(self) -> List[Document]:
"""Load raw text from loader."""
metadata = {"source": "raw text"}
if self._metadata:
metadata.update(self._metadata) # type: ignore
docs = [Document(content=self._text, metadata=metadata)]
return docs

View File

@ -1,5 +1,5 @@
"""TXT Knowledge."""
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
import chardet
@ -20,6 +20,7 @@ class TXTKnowledge(Knowledge):
file_path: Optional[str] = None,
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
loader: Optional[Any] = None,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
**kwargs: Any,
) -> None:
"""Create TXT Knowledge with Knowledge arguments.
@ -29,9 +30,13 @@ class TXTKnowledge(Knowledge):
knowledge_type(KnowledgeType, optional): knowledge type
loader(Any, optional): loader
"""
self._path = file_path
self._type = knowledge_type
self._loader = loader
super().__init__(
path=file_path,
knowledge_type=knowledge_type,
data_loader=loader,
metadata=metadata,
**kwargs,
)
def _load(self) -> List[Document]:
"""Load txt document from loader."""
@ -48,6 +53,8 @@ class TXTKnowledge(Knowledge):
else:
text = raw_text.decode(result["encoding"])
metadata = {"source": self._path}
if self._metadata:
metadata.update(self._metadata) # type: ignore
return [Document(content=text, metadata=metadata)]
return [Document.langchain2doc(lc_document) for lc_document in documents]

View File

@ -26,9 +26,9 @@ class URLKnowledge(Knowledge):
encoding(str, optional): csv encoding
loader(Any, optional): loader
"""
self._path = url or None
self._type = knowledge_type
self._loader = loader
super().__init__(
path=url, knowledge_type=knowledge_type, loader=loader, **kwargs
)
self._encoding = encoding
self._source_column = source_column

View File

@ -1,9 +1,10 @@
"""Base retriever module."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import List
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.storage.vector_store.filters import MetadataFilters
class RetrieverStrategy(str, Enum):
@ -23,83 +24,112 @@ class RetrieverStrategy(str, Enum):
class BaseRetriever(ABC):
"""Base retriever."""
def retrieve(self, query: str) -> List[Chunk]:
def retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
query (str): query text.
filters: (Optional[MetadataFilters]) metadata filters.
Returns:
List[Chunk]: list of chunks
"""
return self._retrieve(query)
return self._retrieve(query, filters)
async def aretrieve(self, query: str) -> List[Chunk]:
async def aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): async query text
query (str): async query text.
filters: (Optional[MetadataFilters]) metadata filters.
Returns:
List[Chunk]: list of chunks
"""
return await self._aretrieve(query)
return await self._aretrieve(query, filters)
def retrieve_with_scores(self, query: str, score_threshold: float) -> List[Chunk]:
def retrieve_with_scores(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
query (str): query text.
score_threshold (float): score threshold.
filters: (Optional[MetadataFilters]) metadata filters.
Returns:
List[Chunk]: list of chunks
"""
return self._retrieve_with_score(query, score_threshold)
return self._retrieve_with_score(query, score_threshold, filters)
async def aretrieve_with_scores(
self, query: str, score_threshold: float
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: (Optional[MetadataFilters]) metadata filters.
Returns:
List[Chunk]: list of chunks
"""
return await self._aretrieve_with_score(query, score_threshold)
return await self._aretrieve_with_score(query, score_threshold, filters)
@abstractmethod
def _retrieve(self, query: str) -> List[Chunk]:
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: (Optional[MetadataFilters]) metadata filters.
Returns:
List[Chunk]: list of chunks
"""
@abstractmethod
async def _aretrieve(self, query: str) -> List[Chunk]:
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Async Retrieve knowledge chunks.
Args:
query (str): query text
filters: (Optional[MetadataFilters]) metadata filters.
Returns:
List[Chunk]: list of chunks
"""
@abstractmethod
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
def _retrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: (Optional[MetadataFilters]) metadata filters.
Returns:
List[Chunk]: list of chunks
@ -107,13 +137,17 @@ class BaseRetriever(ABC):
@abstractmethod
async def _aretrieve_with_score(
self, query: str, score_threshold: float
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Async Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: (Optional[MetadataFilters]) metadata filters.
Returns:
List[Chunk]: list of chunks
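
All four public retrieve methods on BaseRetriever now take an optional MetadataFilters argument and forward it to the corresponding private hook, so existing call sites keep working and filtering is opt-in. For example (retriever stands for any concrete BaseRetriever; the key and value are hypothetical):

from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters

filters = MetadataFilters(filters=[MetadataFilter(key="source", value="database")])
# chunks = retriever.retrieve("which tables hold sales data", filters)
# chunks = await retriever.aretrieve_with_scores("which tables hold sales data", 0.3, filters)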

View File

@ -8,6 +8,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
@ -93,11 +94,14 @@ class DBSchemaRetriever(BaseRetriever):
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
def _retrieve(self, query: str) -> List[Chunk]:
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
@ -105,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever):
if self._need_embeddings:
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k)
self._vector_store_connector.similar_search(query, self._top_k, filters)
for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
@ -115,30 +119,39 @@ class DBSchemaRetriever(BaseRetriever):
table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries]
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
def _retrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
"""
return self._retrieve(query)
return self._retrieve(query, filters)
async def _aretrieve(self, query: str) -> List[Chunk]:
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [self._similarity_search(query) for query in queries]
candidates = [self._similarity_search(query, filters) for query in queries]
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
@ -154,22 +167,25 @@ class DBSchemaRetriever(BaseRetriever):
return [Chunk(content=table_summary) for table_summary in table_summaries]
async def _aretrieve_with_score(
self, query: str, score_threshold: float
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
"""
return await self._aretrieve(query)
return await self._aretrieve(query, filters)
async def _similarity_search(self, query) -> List[Chunk]:
async def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
return self._vector_store_connector.similar_search(
query,
self._top_k,
)
return self._vector_store_connector.similar_search(query, self._top_k, filters)
async def _aparse_db_summary(self) -> List[str]:
"""Similar search."""

View File

@ -7,6 +7,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.tracer import root_tracer
@ -65,37 +66,44 @@ class EmbeddingRetriever(BaseRetriever):
self._vector_store_connector = vector_store_connector
self._rerank = rerank or DefaultRanker(self._top_k)
def _retrieve(self, query: str) -> List[Chunk]:
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: metadata filters.
Return:
List[Chunk]: list of chunks
"""
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k)
self._vector_store_connector.similar_search(query, self._top_k, filters)
for query in queries
]
res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
return res_candidates
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
def _retrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
Return:
List[Chunk]: list of chunks with score
"""
queries = [query]
candidates_with_score = [
self._vector_store_connector.similar_search_with_scores(
query, self._top_k, score_threshold
query, self._top_k, score_threshold, filters
)
for query in queries
]
@ -105,37 +113,44 @@ class EmbeddingRetriever(BaseRetriever):
new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
return new_candidates_with_score
async def _aretrieve(self, query: str) -> List[Chunk]:
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
query (str): query text.
filters: metadata filters.
Return:
List[Chunk]: list of chunks
"""
queries = [query]
if self._query_rewrite:
candidates_tasks = [self._similarity_search(query) for query in queries]
candidates_tasks = [
self._similarity_search(query, filters) for query in queries
]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries)
candidates = [self._similarity_search(query) for query in queries]
candidates = [self._similarity_search(query, filters) for query in queries]
new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
return new_candidates
async def _aretrieve_with_score(
self, query: str, score_threshold: float
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
Return:
List[Chunk]: list of chunks with score
"""
@ -145,7 +160,9 @@ class EmbeddingRetriever(BaseRetriever):
"EmbeddingRetriever.query_rewrite.similarity_search",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_tasks = [self._similarity_search(query) for query in queries]
candidates_tasks = [
self._similarity_search(query, filters) for query in queries
]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
with root_tracer.start_span(
@ -162,7 +179,7 @@ class EmbeddingRetriever(BaseRetriever):
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_with_score = [
self._similarity_search_with_score(query, score_threshold)
self._similarity_search_with_score(query, score_threshold, filters)
for query in queries
]
res_candidates_with_score = await run_async_tasks(
@ -183,12 +200,11 @@ class EmbeddingRetriever(BaseRetriever):
new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
return new_candidates_with_score
async def _similarity_search(self, query) -> List[Chunk]:
async def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
return self._vector_store_connector.similar_search(
query,
self._top_k,
)
return self._vector_store_connector.similar_search(query, self._top_k, filters)
async def _run_async_tasks(self, tasks) -> List[Chunk]:
"""Run async tasks."""
@ -197,9 +213,9 @@ class EmbeddingRetriever(BaseRetriever):
return cast(List[Chunk], candidates)
async def _similarity_search_with_score(
self, query, score_threshold
self, query, score_threshold, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search with score."""
return self._vector_store_connector.similar_search_with_scores(
query, self._top_k, score_threshold
query, self._top_k, score_threshold, filters
)

View File

@ -36,5 +36,4 @@ def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever
retrieved_chunks = embedding_retriever._retrieve(query)
mock_vector_store_connector.similar_search.assert_called_once_with(query, top_k)
assert len(retrieved_chunks) == top_k

View File

@ -457,12 +457,12 @@ class MarkdownHeaderTextSplitter(TextSplitter):
# Output line-by-line or aggregated into chunks w/ common headers
if headers_to_split_on is None:
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
("#####", "Header 5"),
("######", "Header 6"),
("#", "Header1"),
("##", "Header2"),
("###", "Header3"),
("####", "Header4"),
("#####", "Header5"),
("######", "Header6"),
]
if filters is None:
filters = []

View File

@ -7,6 +7,7 @@ from dbgpt.core import Chunk
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
CFG = Config()
@ -46,52 +47,81 @@ class KnowledgeSpaceRetriever(BaseRetriever):
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
def _retrieve(self, query: str) -> List[Chunk]:
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
query (str): query text.
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks
"""
candidates = self._vector_store_connector.similar_search(
doc=query, topk=self._top_k
doc=query, topk=self._top_k, filters=filters
)
return candidates
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
def _retrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks with score
"""
candidates_with_score = self._vector_store_connector.similar_search_with_scores(
doc=query, topk=self._top_k, score_threshold=score_threshold
doc=query,
topk=self._top_k,
score_threshold=score_threshold,
filters=filters,
)
return candidates_with_score
async def _aretrieve(self, query: str) -> List[Chunk]:
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
query (str): query text.
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks
"""
candidates = await blocking_func_to_async(self._executor, self._retrieve, query)
candidates = await blocking_func_to_async(
self._executor, self._retrieve, query, filters
)
return candidates
async def _aretrieve_with_score(
self, query: str, score_threshold: float
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
query (str): query text.
score_threshold (float): score threshold.
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks with score
List[Chunk]: list of chunks with score.
"""
candidates_with_score = await blocking_func_to_async(
self._executor, self._retrieve_with_score, query, score_threshold
self._executor, self._retrieve_with_score, query, score_threshold, filters
)
return candidates_with_score

View File

@ -4,11 +4,12 @@ import math
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from typing import Any, List, Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.i18n_utils import _
logger = logging.getLogger(__name__)
@ -176,13 +177,15 @@ class VectorStoreBase(ABC):
return ids
@abstractmethod
def similar_search(self, text: str, topk: int) -> List[Chunk]:
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search in vector database.
Args:
text(str): The query text.
topk(int): The number of similar documents to return.
filters(Optional[MetadataFilters]): metadata filters.
Return:
List[Chunk]: The similar documents.
"""
@ -190,7 +193,11 @@ class VectorStoreBase(ABC):
@abstractmethod
def similar_search_with_scores(
self, text, topk, score_threshold: float
self,
text,
topk,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Similar search with scores in vector database.
@ -199,6 +206,7 @@ class VectorStoreBase(ABC):
topk(int): The number of similar documents to return.
score_threshold(int): score_threshold: Optional, a floating point value
between 0 to 1
filters(Optional[MetadataFilters]): metadata filters.
Return:
List[Chunk]: The similar documents.
"""
@ -223,6 +231,15 @@ class VectorStoreBase(ABC):
Args:
vector_name(str): The name of vector to delete.
"""
pass
def convert_metadata_filters(self, filters: MetadataFilters) -> Any:
"""Convert metadata filters to vector store filters.
Args:
filters: (Optional[MetadataFilters]) metadata filters.
"""
raise NotImplementedError
def _normalization_vectors(self, vectors):
"""Return L2-normalization vectors to scale[0,1].

View File

@ -1,7 +1,7 @@
"""Chroma vector store."""
import logging
import os
from typing import Any, List
from typing import List, Optional
from chromadb import PersistentClient
from chromadb.config import Settings
@ -13,6 +13,7 @@ from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
from .filters import FilterOperator, MetadataFilters
logger = logging.getLogger(__name__)
@ -86,16 +87,23 @@ class ChromaStore(VectorStoreBase):
collection_metadata=collection_metadata,
)
def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]:
def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Search similar documents."""
logger.info("ChromaStore similar search")
lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs)
where_filters = self.convert_metadata_filters(filters) if filters else None
lc_documents = self.vector_store_client.similarity_search(
text, topk, filter=where_filters
)
return [
Chunk(content=doc.page_content, metadata=doc.metadata)
for doc in lc_documents
]
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
def similar_search_with_scores(
self, text, topk, score_threshold, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Search similar documents with scores.
Chroma similar_search_with_score.
@ -106,11 +114,16 @@ class ChromaStore(VectorStoreBase):
score_threshold(float): score_threshold: Optional, a floating point value
between 0 to 1 to filter the resulting set of retrieved docs,0 is
dissimilar, 1 is most similar.
filters(MetadataFilters): metadata filters, defaults to None
"""
logger.info("ChromaStore similar search with scores")
where_filters = self.convert_metadata_filters(filters) if filters else None
docs_and_scores = (
self.vector_store_client.similarity_search_with_relevance_scores(
query=text, k=topk, score_threshold=score_threshold
query=text,
k=topk,
score_threshold=score_threshold,
filter=where_filters,
)
)
return [
@ -152,10 +165,71 @@ class ChromaStore(VectorStoreBase):
collection = self.vector_store_client._collection
collection.delete(ids=ids)
def convert_metadata_filters(
self,
filters: MetadataFilters,
) -> dict:
"""Convert metadata filters to Chroma filters.
Args:
filters(MetadataFilters): metadata filters.
Returns:
dict: Chroma filters.
"""
where_filters = {}
filters_list = []
condition = filters.condition
chroma_condition = f"${condition}"
if filters.filters:
for filter in filters.filters:
if filter.operator:
filters_list.append(
{
filter.key: {
_convert_chroma_filter_operator(
filter.operator
): filter.value
}
}
)
else:
filters_list.append({filter.key: filter.value}) # type: ignore
if len(filters_list) == 1:
return filters_list[0]
elif len(filters_list) > 1:
where_filters[chroma_condition] = filters_list
return where_filters
def _clean_persist_folder(self):
"""Clean persist folder."""
for root, dirs, files in os.walk(self.persist_dir, topdown=False):
for name in files:
os.remove(os.path.join(root, name))
for name in dirs:
os.rmdir(os.path.join(root, name))
os.rmdir(self.persist_dir)
def _convert_chroma_filter_operator(operator: str) -> str:
"""Convert operator to Chroma where operator.
Args:
operator(str): operator.
Returns:
str: Chroma where operator.
"""
if operator == FilterOperator.EQ:
return "$eq"
elif operator == FilterOperator.NE:
return "$ne"
elif operator == FilterOperator.GT:
return "$gt"
elif operator == FilterOperator.LT:
return "$lt"
elif operator == FilterOperator.GTE:
return "$gte"
elif operator == FilterOperator.LTE:
return "$lte"
else:
raise ValueError(f"Chroma Where operator {operator} not supported")
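
convert_metadata_filters builds the where clause that Chroma's similarity_search expects: a single {key: {op: value}} entry when there is one filter, or a {"$and"/"$or": [...]} wrapper when there are several. A sketch of the mapping with hypothetical keys and values:

from dbgpt.storage.vector_store.filters import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

filters = MetadataFilters(
    condition=FilterCondition.AND,
    filters=[
        MetadataFilter(key="category", value="user-manual"),
        MetadataFilter(key="year", operator=FilterOperator.GTE, value=2023),
    ],
)
# ChromaStore.convert_metadata_filters(filters) should yield roughly:
# {"$and": [{"category": {"$eq": "user-manual"}}, {"year": {"$gte": 2023}}]}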

View File

@ -13,6 +13,7 @@ from dbgpt.core.awel.flow import (
)
from dbgpt.storage import vector_store
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.i18n_utils import _
connector: Dict[str, Type] = {}
@ -128,23 +129,29 @@ class VectorStoreConnector:
max_threads,
)
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
def similar_search(
self, doc: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search in vector database.
Args:
- doc: query text
- topk: topk
- filters: metadata filters.
Return:
- chunks: chunks.
"""
return self.client.similar_search(doc, topk)
return self.client.similar_search(doc, topk, filters)
def similar_search_with_scores(
self, doc: str, topk: int, score_threshold: float
self,
doc: str,
topk: int,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Similar search with scores in vector database.
"""Similar_search_with_score in vector database.
similar_search_with_score in vector database..
Return docs and relevance scores in the range [0, 1].
Args:
@ -153,10 +160,13 @@ class VectorStoreConnector:
score_threshold(float): score_threshold: Optional, a floating point value
between 0 to 1 to filter the resulting set of retrieved docs,0 is
dissimilar, 1 is most similar.
filters: metadata filters.
Return:
- chunks: chunks.
- chunks: Return docs and relevance scores in the range [0, 1].
"""
return self.client.similar_search_with_scores(doc, topk, score_threshold)
return self.client.similar_search_with_scores(
doc, topk, score_threshold, filters
)
@property
def vector_store_config(self) -> VectorStoreConfig:

View File

@ -0,0 +1,56 @@
"""Vector Store Meta data filters."""
from enum import Enum
from typing import List, Union
from pydantic import BaseModel, Field
class FilterOperator(str, Enum):
"""Meta data filter operator."""
EQ = "=="
GT = ">"
LT = "<"
NE = "!="
GTE = ">="
LTE = "<="
IN = "in"
NIN = "nin"
EXISTS = "exists"
class FilterCondition(str, Enum):
"""Vector Store Meta data filter conditions."""
AND = "and"
OR = "or"
class MetadataFilter(BaseModel):
"""Meta data filter."""
key: str = Field(
...,
description="The key of metadata to filter.",
)
operator: FilterOperator = Field(
default=FilterOperator.EQ,
description="The operator of metadata filter.",
)
value: Union[str, int, float, List[str], List[int], List[float]] = Field(
...,
description="The value of metadata to filter.",
)
class MetadataFilters(BaseModel):
"""Meta data filters."""
condition: FilterCondition = Field(
default=FilterCondition.AND,
description="The condition of metadata filters.",
)
filters: List[MetadataFilter] = Field(
...,
description="The metadata filters.",
)
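
This new module is the store-agnostic filter vocabulary: a MetadataFilter is one key/operator/value triple and MetadataFilters groups several under an AND or OR condition, with EQ as the default operator. A small construction example with hypothetical field names:

from dbgpt.storage.vector_store.filters import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

# "lang == 'zh' OR lang in ['en', 'fr']" — each store translates this into its own syntax.
filters = MetadataFilters(
    condition=FilterCondition.OR,
    filters=[
        MetadataFilter(key="lang", value="zh"),
        MetadataFilter(key="lang", operator=FilterOperator.IN, value=["en", "fr"]),
    ],
)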

View File

@ -14,6 +14,7 @@ from dbgpt.storage.vector_store.base import (
VectorStoreBase,
VectorStoreConfig,
)
from dbgpt.storage.vector_store.filters import FilterOperator, MetadataFilters
from dbgpt.util import string_utils
from dbgpt.util.i18n_utils import _
@ -206,6 +207,7 @@ class MilvusStore(VectorStoreBase):
self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
self.text_field = milvus_vector_config.get("text_field") or "content"
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
self.props_field = milvus_vector_config.get("props_field") or "props_field"
if (self.username is None) != (self.password is None):
raise ValueError(
@ -284,6 +286,7 @@ class MilvusStore(VectorStoreBase):
vector_field = self.vector_field
text_field = self.text_field
metadata_field = self.metadata_field
props_field = self.props_field
# self.text_field = text_field
collection_name = vector_name
fields = []
@ -300,6 +303,7 @@ class MilvusStore(VectorStoreBase):
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
fields.append(FieldSchema(props_field, DataType.JSON))
schema = CollectionSchema(fields)
# Create the collection
collection = Collection(collection_name, schema)
@ -346,6 +350,7 @@ class MilvusStore(VectorStoreBase):
for d in metadatas:
# for key, value in d.items():
insert_dict.setdefault("metadata", []).append(json.dumps(d))
insert_dict.setdefault("props_field", []).append(d)
# Convert dict to list of lists for insertion
insert_list = [insert_dict[x] for x in self.fields]
# Insert into the collection.
@ -368,7 +373,9 @@ class MilvusStore(VectorStoreBase):
doc_ids = [str(doc_id) for doc_id in doc_ids]
return doc_ids
def similar_search(self, text, topk) -> List[Chunk]:
def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Perform a search on a query string and return results."""
from pymilvus import Collection, DataType
@ -383,7 +390,9 @@ class MilvusStore(VectorStoreBase):
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
_, docs_and_scores = self._search(text, topk)
# convert to milvus expr filter.
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
_, docs_and_scores = self._search(text, topk, expr=milvus_filter_expr)
return [
Chunk(
@ -393,7 +402,13 @@ class MilvusStore(VectorStoreBase):
for doc, _, _ in docs_and_scores
]
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
def similar_search_with_scores(
self,
text: str,
topk: int,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
@ -401,15 +416,10 @@ class MilvusStore(VectorStoreBase):
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
embedding (List[float]): The embedding vector being searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
text (str): The query text.
topk (int): The number of similar documents to return.
score_threshold (float): Optional, a floating point value between 0 to 1.
filters (Optional[MetadataFilters]): Optional, metadata filters.
Returns:
List[Tuple[Document, float]]: Result doc and score.
"""
@ -427,7 +437,11 @@ class MilvusStore(VectorStoreBase):
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
_, docs_and_scores = self._search(text, topk)
# convert to milvus expr filter.
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
_, docs_and_scores = self._search(
query=text, topk=topk, expr=milvus_filter_expr
)
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
logger.warning(
"similarity score need between" f" 0 and 1, got {docs_and_scores}"
@ -462,6 +476,20 @@ class MilvusStore(VectorStoreBase):
timeout: Optional[int] = None,
**kwargs: Any,
):
"""Search in vector database.
Args:
query: query text.
k: topk.
param: search params.
expr: search expr.
partition_names: partition names.
round_decimal: round decimal.
timeout: timeout.
**kwargs: kwargs.
Returns:
Tuple[Document, float, int]: Result doc and score.
"""
self.col.load()
# use default index params.
if param is None:
@ -495,7 +523,9 @@ class MilvusStore(VectorStoreBase):
result.id,
)
)
if len(ret) == 0:
logger.warning("No relevant docs were retrieved.")
return None, []
return ret[0], ret
def vector_name_exists(self):
@ -523,6 +553,40 @@ class MilvusStore(VectorStoreBase):
logger.info(f"begin delete milvus ids: {ids}")
delete_ids = ids.split(",")
doc_ids = [int(doc_id) for doc_id in delete_ids]
delet_expr = f"{self.primary_field} in {doc_ids}"
self.col.delete(delet_expr)
delete_expr = f"{self.primary_field} in {doc_ids}"
self.col.delete(delete_expr)
return True
def convert_metadata_filters(self, filters: MetadataFilters) -> str:
"""Convert filter to milvus filters.
Args:
- filters: metadata filters.
Returns:
- metadata_filters: metadata filters.
"""
metadata_filters = []
for metadata_filter in filters.filters:
if isinstance(metadata_filter.value, str):
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.EQ} '{metadata_filter.value}'"
)
metadata_filters.append(expr)
elif isinstance(metadata_filter.value, List):
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.IN} {metadata_filter.value}"
)
metadata_filters.append(expr)
else:
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.EQ} {str(metadata_filter.value)}"
)
metadata_filters.append(expr)
if len(metadata_filters) > 1:
metadata_filter_expr = f" {filters.condition} ".join(metadata_filters)
else:
metadata_filter_expr = metadata_filters[0]
return metadata_filter_expr
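
For Milvus, the filters are rendered as a boolean expression over the new JSON props_field column and handed to Collection.search(expr=...) by _search above. A hedged sketch of the intended translation (the second, list-valued property is hypothetical):

from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters

filters = MetadataFilters(
    filters=[
        MetadataFilter(key="Header2", value="AWEL Design"),
        MetadataFilter(key="tags", value=["rag", "awel"]),  # hypothetical list-valued property
    ]
)
# store.convert_metadata_filters(filters) is meant to yield an expr string roughly like:
# "props_field['Header2'] == 'AWEL Design' and props_field['tags'] in ['rag', 'awel']"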

View File

@ -1,6 +1,6 @@
"""Postgres vector store."""
import logging
from typing import Any, List
from typing import List, Optional
from dbgpt._private.pydantic import Field
from dbgpt.core import Chunk
@ -10,6 +10,7 @@ from dbgpt.storage.vector_store.base import (
VectorStoreBase,
VectorStoreConfig,
)
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.i18n_utils import _
logger = logging.getLogger(__name__)
@ -70,9 +71,11 @@ class PGVectorStore(VectorStoreBase):
connection_string=self.connection_string,
)
def similar_search(self, text: str, topk: int, **kwargs: Any) -> List[Chunk]:
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Perform similar search in PGVector."""
return self.vector_store_client.similarity_search(text, topk)
return self.vector_store_client.similarity_search(text, topk, filters)
def vector_name_exists(self) -> bool:
"""Check if vector name exists."""

View File

@ -1,7 +1,7 @@
"""Weaviate vector store."""
import logging
import os
from typing import List
from typing import List, Optional
from dbgpt._private.pydantic import Field
from dbgpt.core import Chunk
@ -9,6 +9,7 @@ from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
from .filters import MetadataFilters
logger = logging.getLogger(__name__)
@ -80,7 +81,9 @@ class WeaviateStore(VectorStoreBase):
self.vector_store_client = weaviate.Client(self.weaviate_url)
def similar_search(self, text: str, topk: int) -> List[Chunk]:
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Perform similar search in Weaviate."""
logger.info("Weaviate similar search")
# nearText = {

View File

@ -0,0 +1,60 @@
"""Rag Metadata Properties filter example.
pre-requirements:
make sure you have set your embedding model path in your example code.
Examples:
..code-block:: shell
python examples/rag/metadata_filter_example.py
"""
import asyncio
import os
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
def _create_vector_connector():
"""Create vector connector."""
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="example_metadata_filter_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
)
async def main():
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
knowledge = KnowledgeFactory.from_file_path(file_path)
vector_connector = _create_vector_connector()
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")
# get embedding assembler
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
vector_store_connector=vector_connector,
)
assembler.persist()
# get embeddings retriever
retriever = assembler.as_retriever(3)
# create metadata filter
metadata_filter = MetadataFilter(key="Header2", value="AWEL Design")
filters = MetadataFilters(filters=[metadata_filter])
chunks = await retriever.aretrieve_with_scores(
"what is awel talk about", 0.0, filters
)
print(f"embedding rag example results:{chunks}")
if __name__ == "__main__":
asyncio.run(main())