mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat(RAG):add metadata properties filters (#1395)
This commit is contained in:
parent
0f2b46da62
commit
37e7c0151b
@ -2,7 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.rag.text_splitter.text_splitter import (
|
||||
@ -147,16 +147,18 @@ class Knowledge(ABC):
|
||||
self,
|
||||
path: Optional[str] = None,
|
||||
knowledge_type: Optional[KnowledgeType] = None,
|
||||
data_loader: Optional[Any] = None,
|
||||
loader: Optional[Any] = None,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize with Knowledge arguments."""
|
||||
self._path = path
|
||||
self._type = knowledge_type
|
||||
self._data_loader = data_loader
|
||||
self._loader = loader
|
||||
self._metadata = metadata
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load knowledge from data_loader."""
|
||||
"""Load knowledge from data loader."""
|
||||
documents = self._load()
|
||||
return self._postprocess(documents)
|
||||
|
||||
@ -171,12 +173,12 @@ class Knowledge(ABC):
|
||||
return None
|
||||
|
||||
def _postprocess(self, docs: List[Document]) -> List[Document]:
|
||||
"""Post process knowledge from data_loader."""
|
||||
"""Post process knowledge from data loader."""
|
||||
return docs
|
||||
|
||||
@abstractmethod
|
||||
def _load(self) -> List[Document]:
|
||||
"""Preprocess knowledge from data_loader."""
|
||||
"""Preprocess knowledge from data loader."""
|
||||
|
||||
@classmethod
|
||||
def support_chunk_strategy(cls) -> List[ChunkStrategy]:
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""CSV Knowledge."""
|
||||
import csv
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.rag.knowledge.base import (
|
||||
@ -21,6 +21,7 @@ class CSVKnowledge(Knowledge):
|
||||
source_column: Optional[str] = None,
|
||||
encoding: Optional[str] = "utf-8",
|
||||
loader: Optional[Any] = None,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create CSV Knowledge with Knowledge arguments.
|
||||
@ -32,9 +33,13 @@ class CSVKnowledge(Knowledge):
|
||||
encoding(str, optional): csv encoding
|
||||
loader(Any, optional): loader
|
||||
"""
|
||||
self._path = file_path
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
super().__init__(
|
||||
path=file_path,
|
||||
knowledge_type=knowledge_type,
|
||||
data_loader=loader,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
self._encoding = encoding
|
||||
self._source_column = source_column
|
||||
|
||||
@ -67,6 +72,8 @@ class CSVKnowledge(Knowledge):
|
||||
f"file."
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
doc = Document(content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Datasource Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.datasource import BaseConnector
|
||||
@ -16,26 +16,29 @@ class DatasourceKnowledge(Knowledge):
|
||||
connector: BaseConnector,
|
||||
summary_template: str = "{table_name}({columns})",
|
||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create Datasource Knowledge with Knowledge arguments.
|
||||
|
||||
Args:
|
||||
path(str, optional): file path
|
||||
connector(BaseConnector): connector
|
||||
summary_template(str, optional): summary template
|
||||
knowledge_type(KnowledgeType, optional): knowledge type
|
||||
data_loader(Any, optional): loader
|
||||
metadata(Dict[str, Union[str, List[str]], optional): metadata
|
||||
"""
|
||||
self._connector = connector
|
||||
self._summary_template = summary_template
|
||||
super().__init__(knowledge_type=knowledge_type, **kwargs)
|
||||
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
"""Load datasource document from data_loader."""
|
||||
docs = []
|
||||
for table_summary in _parse_db_summary(self._connector, self._summary_template):
|
||||
docs.append(
|
||||
Document(content=table_summary, metadata={"source": "database"})
|
||||
)
|
||||
metadata = {"source": "database"}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
docs.append(Document(content=table_summary, metadata=metadata))
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Docx Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import docx
|
||||
|
||||
@ -21,6 +21,7 @@ class DocxKnowledge(Knowledge):
|
||||
knowledge_type: Any = KnowledgeType.DOCUMENT,
|
||||
encoding: Optional[str] = "utf-8",
|
||||
loader: Optional[Any] = None,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create Docx Knowledge with Knowledge arguments.
|
||||
@ -31,9 +32,13 @@ class DocxKnowledge(Knowledge):
|
||||
encoding(str, optional): csv encoding
|
||||
loader(Any, optional): loader
|
||||
"""
|
||||
self._path = file_path
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
super().__init__(
|
||||
path=file_path,
|
||||
knowledge_type=knowledge_type,
|
||||
data_loader=loader,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
self._encoding = encoding
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
@ -48,9 +53,10 @@ class DocxKnowledge(Knowledge):
|
||||
para = doc.paragraphs[i]
|
||||
text = para.text
|
||||
content.append(text)
|
||||
docs.append(
|
||||
Document(content="\n".join(content), metadata={"source": self._path})
|
||||
)
|
||||
metadata = {"source": self._path}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
docs.append(Document(content="\n".join(content), metadata=metadata))
|
||||
return docs
|
||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Knowledge Factory to create knowledge from file path and url."""
|
||||
from typing import List, Optional, Type
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
|
||||
from dbgpt.rag.knowledge.string import StringKnowledge
|
||||
@ -13,6 +13,7 @@ class KnowledgeFactory:
|
||||
self,
|
||||
file_path: Optional[str] = None,
|
||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
):
|
||||
"""Create Knowledge Factory with file path and knowledge type.
|
||||
|
||||
@ -22,18 +23,21 @@ class KnowledgeFactory:
|
||||
"""
|
||||
self._file_path = file_path
|
||||
self._knowledge_type = knowledge_type
|
||||
self._metadata = metadata
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
datasource: str = "",
|
||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
):
|
||||
"""Create knowledge from file path, url or text.
|
||||
|
||||
Args:
|
||||
datasource: path of the file to convert
|
||||
knowledge_type: type of knowledge
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]]
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@ -52,12 +56,16 @@ class KnowledgeFactory:
|
||||
match knowledge_type:
|
||||
case KnowledgeType.DOCUMENT:
|
||||
return cls.from_file_path(
|
||||
file_path=datasource, knowledge_type=knowledge_type
|
||||
file_path=datasource,
|
||||
knowledge_type=knowledge_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
case KnowledgeType.URL:
|
||||
return cls.from_url(url=datasource, knowledge_type=knowledge_type)
|
||||
case KnowledgeType.TEXT:
|
||||
return cls.from_text(text=datasource, knowledge_type=knowledge_type)
|
||||
return cls.from_text(
|
||||
text=datasource, knowledge_type=knowledge_type, metadata=metadata
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"Unsupported knowledge type '{knowledge_type}'")
|
||||
|
||||
@ -66,6 +74,7 @@ class KnowledgeFactory:
|
||||
cls,
|
||||
file_path: str = "",
|
||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
) -> Knowledge:
|
||||
"""Create knowledge from path.
|
||||
|
||||
@ -82,10 +91,11 @@ class KnowledgeFactory:
|
||||
datasource="path/to/document.pdf",
|
||||
knowledge_type=KnowledgeType.DOCUMENT,
|
||||
)
|
||||
|
||||
"""
|
||||
factory = cls(file_path=file_path, knowledge_type=knowledge_type)
|
||||
return factory._select_document_knowledge(
|
||||
file_path=file_path, knowledge_type=knowledge_type
|
||||
file_path=file_path, knowledge_type=knowledge_type, metadata=metadata
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -117,6 +127,7 @@ class KnowledgeFactory:
|
||||
def from_text(
|
||||
text: str = "",
|
||||
knowledge_type: KnowledgeType = KnowledgeType.TEXT,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
) -> Knowledge:
|
||||
"""Create knowledge from text.
|
||||
|
||||
@ -127,6 +138,7 @@ class KnowledgeFactory:
|
||||
return StringKnowledge(
|
||||
text=text,
|
||||
knowledge_type=knowledge_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _select_document_knowledge(self, **kwargs):
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""HTML Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import chardet
|
||||
|
||||
@ -20,6 +20,7 @@ class HTMLKnowledge(Knowledge):
|
||||
file_path: Optional[str] = None,
|
||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||
loader: Optional[Any] = None,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create HTML Knowledge with Knowledge arguments.
|
||||
@ -29,9 +30,13 @@ class HTMLKnowledge(Knowledge):
|
||||
knowledge_type(KnowledgeType, optional): knowledge type
|
||||
loader(Any, optional): loader
|
||||
"""
|
||||
self._path = file_path
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
super().__init__(
|
||||
path=file_path,
|
||||
knowledge_type=knowledge_type,
|
||||
data_loader=loader,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
"""Load html document from loader."""
|
||||
@ -48,6 +53,8 @@ class HTMLKnowledge(Knowledge):
|
||||
else:
|
||||
text = raw_text.decode(result["encoding"])
|
||||
metadata = {"source": self._path}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
return [Document(content=text, metadata=metadata)]
|
||||
|
||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Markdown Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.rag.knowledge.base import (
|
||||
@ -19,6 +19,7 @@ class MarkdownKnowledge(Knowledge):
|
||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||
encoding: Optional[str] = "utf-8",
|
||||
loader: Optional[Any] = None,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create Markdown Knowledge with Knowledge arguments.
|
||||
@ -29,9 +30,13 @@ class MarkdownKnowledge(Knowledge):
|
||||
encoding(str, optional): csv encoding
|
||||
loader(Any, optional): loader
|
||||
"""
|
||||
self._path = file_path
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
super().__init__(
|
||||
path=file_path,
|
||||
knowledge_type=knowledge_type,
|
||||
data_loader=loader,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
self._encoding = encoding
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
@ -44,6 +49,8 @@ class MarkdownKnowledge(Knowledge):
|
||||
with open(self._path, encoding=self._encoding, errors="ignore") as f:
|
||||
markdown_text = f.read()
|
||||
metadata = {"source": self._path}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
documents = [Document(content=markdown_text, metadata=metadata)]
|
||||
return documents
|
||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""PDF Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.rag.knowledge.base import (
|
||||
@ -19,6 +19,7 @@ class PDFKnowledge(Knowledge):
|
||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||
loader: Optional[Any] = None,
|
||||
language: Optional[str] = "zh",
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create PDF Knowledge with Knowledge arguments.
|
||||
@ -29,9 +30,13 @@ class PDFKnowledge(Knowledge):
|
||||
loader(Any, optional): loader
|
||||
language(str, optional): language
|
||||
"""
|
||||
self._path = file_path
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
super().__init__(
|
||||
path=file_path,
|
||||
knowledge_type=knowledge_type,
|
||||
data_loader=loader,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
self._language = language
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
@ -65,6 +70,8 @@ class PDFKnowledge(Knowledge):
|
||||
page = "\n".join(cleaned_lines)
|
||||
# cleaned_pages.append(page)
|
||||
metadata = {"source": self._path, "page": page_num}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
# text = "\f".join(cleaned_pages)
|
||||
document = Document(content=page, metadata=metadata)
|
||||
documents.append(document)
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""PPTX Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.rag.knowledge.base import (
|
||||
@ -19,6 +19,7 @@ class PPTXKnowledge(Knowledge):
|
||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||
loader: Optional[Any] = None,
|
||||
language: Optional[str] = "zh",
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create PPTX knowledge with PDF Knowledge arguments.
|
||||
@ -28,9 +29,13 @@ class PPTXKnowledge(Knowledge):
|
||||
knowledge_type:(KnowledgeType) knowledge type
|
||||
loader:(Optional[Any]) loader
|
||||
"""
|
||||
self._path = file_path
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
super().__init__(
|
||||
path=file_path,
|
||||
knowledge_type=knowledge_type,
|
||||
data_loader=loader,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
self._language = language
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
@ -47,9 +52,10 @@ class PPTXKnowledge(Knowledge):
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text") and shape.text:
|
||||
content += shape.text
|
||||
docs.append(
|
||||
Document(content=content, metadata={"source": slide.slide_id})
|
||||
)
|
||||
metadata = {"source": self._path}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
docs.append(Document(content=content, metadata=metadata))
|
||||
return docs
|
||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""String Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
|
||||
@ -14,6 +14,7 @@ class StringKnowledge(Knowledge):
|
||||
knowledge_type: KnowledgeType = KnowledgeType.TEXT,
|
||||
encoding: Optional[str] = "utf-8",
|
||||
loader: Optional[Any] = None,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create String knowledge parameters.
|
||||
@ -24,14 +25,20 @@ class StringKnowledge(Knowledge):
|
||||
encoding(str): encoding
|
||||
loader(Any): loader
|
||||
"""
|
||||
super().__init__(
|
||||
knowledge_type=knowledge_type,
|
||||
data_loader=loader,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
self._text = text
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
self._encoding = encoding
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
"""Load raw text from loader."""
|
||||
metadata = {"source": "raw text"}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
docs = [Document(content=self._text, metadata=metadata)]
|
||||
return docs
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""TXT Knowledge."""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import chardet
|
||||
|
||||
@ -20,6 +20,7 @@ class TXTKnowledge(Knowledge):
|
||||
file_path: Optional[str] = None,
|
||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||
loader: Optional[Any] = None,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create TXT Knowledge with Knowledge arguments.
|
||||
@ -29,9 +30,13 @@ class TXTKnowledge(Knowledge):
|
||||
knowledge_type(KnowledgeType, optional): knowledge type
|
||||
loader(Any, optional): loader
|
||||
"""
|
||||
self._path = file_path
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
super().__init__(
|
||||
path=file_path,
|
||||
knowledge_type=knowledge_type,
|
||||
data_loader=loader,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
"""Load txt document from loader."""
|
||||
@ -48,6 +53,8 @@ class TXTKnowledge(Knowledge):
|
||||
else:
|
||||
text = raw_text.decode(result["encoding"])
|
||||
metadata = {"source": self._path}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
return [Document(content=text, metadata=metadata)]
|
||||
|
||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||
|
@ -26,9 +26,9 @@ class URLKnowledge(Knowledge):
|
||||
encoding(str, optional): csv encoding
|
||||
loader(Any, optional): loader
|
||||
"""
|
||||
self._path = url or None
|
||||
self._type = knowledge_type
|
||||
self._loader = loader
|
||||
super().__init__(
|
||||
path=url, knowledge_type=knowledge_type, loader=loader, **kwargs
|
||||
)
|
||||
self._encoding = encoding
|
||||
self._source_column = source_column
|
||||
|
||||
|
@ -1,9 +1,10 @@
|
||||
"""Base retriever module."""
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
|
||||
|
||||
class RetrieverStrategy(str, Enum):
|
||||
@ -23,83 +24,112 @@ class RetrieverStrategy(str, Enum):
|
||||
class BaseRetriever(ABC):
|
||||
"""Base retriever."""
|
||||
|
||||
def retrieve(self, query: str) -> List[Chunk]:
|
||||
def retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
query (str): query text.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
return self._retrieve(query)
|
||||
return self._retrieve(query, filters)
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Chunk]:
|
||||
async def aretrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): async query text
|
||||
query (str): async query text.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
return await self._aretrieve(query)
|
||||
return await self._aretrieve(query, filters)
|
||||
|
||||
def retrieve_with_scores(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
def retrieve_with_scores(
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
query (str): query text.
|
||||
score_threshold (float): score threshold.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
return self._retrieve_with_score(query, score_threshold)
|
||||
return self._retrieve_with_score(query, score_threshold, filters)
|
||||
|
||||
async def aretrieve_with_scores(
|
||||
self, query: str, score_threshold: float
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
return await self._aretrieve_with_score(query, score_threshold)
|
||||
return await self._aretrieve_with_score(query, score_threshold, filters)
|
||||
|
||||
@abstractmethod
|
||||
def _retrieve(self, query: str) -> List[Chunk]:
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
||||
async def _aretrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Async Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
def _retrieve_with_score(
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
@ -107,13 +137,17 @@ class BaseRetriever(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def _aretrieve_with_score(
|
||||
self, query: str, score_threshold: float
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Async Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
|
@ -8,6 +8,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
|
||||
|
||||
@ -93,11 +94,14 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self._need_embeddings = True
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
|
||||
def _retrieve(self, query: str) -> List[Chunk]:
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
filters: metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
@ -105,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._vector_store_connector.similar_search(query, self._top_k)
|
||||
self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
for query in queries
|
||||
]
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
@ -115,30 +119,39 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
table_summaries = _parse_db_summary(self._connector)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
def _retrieve_with_score(
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
filters: metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
return self._retrieve(query)
|
||||
return self._retrieve(query, filters)
|
||||
|
||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
||||
async def _aretrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
filters: metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [self._similarity_search(query) for query in queries]
|
||||
candidates = [self._similarity_search(query, filters) for query in queries]
|
||||
result_candidates = await run_async_tasks(
|
||||
tasks=candidates, concurrency_limit=1
|
||||
)
|
||||
@ -154,22 +167,25 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
async def _aretrieve_with_score(
|
||||
self, query: str, score_threshold: float
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
filters: metadata filters.
|
||||
"""
|
||||
return await self._aretrieve(query)
|
||||
return await self._aretrieve(query, filters)
|
||||
|
||||
async def _similarity_search(self, query) -> List[Chunk]:
|
||||
async def _similarity_search(
|
||||
self, query, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search."""
|
||||
return self._vector_store_connector.similar_search(
|
||||
query,
|
||||
self._top_k,
|
||||
)
|
||||
return self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
|
||||
async def _aparse_db_summary(self) -> List[str]:
|
||||
"""Similar search."""
|
||||
|
@ -7,6 +7,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
from dbgpt.util.tracer import root_tracer
|
||||
|
||||
@ -65,37 +66,44 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
self._vector_store_connector = vector_store_connector
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
|
||||
def _retrieve(self, query: str) -> List[Chunk]:
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
|
||||
filters: metadata filters.
|
||||
Return:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._vector_store_connector.similar_search(query, self._top_k)
|
||||
self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
for query in queries
|
||||
]
|
||||
res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
return res_candidates
|
||||
|
||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
def _retrieve_with_score(
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
|
||||
filters: metadata filters.
|
||||
Return:
|
||||
List[Chunk]: list of chunks with score
|
||||
"""
|
||||
queries = [query]
|
||||
candidates_with_score = [
|
||||
self._vector_store_connector.similar_search_with_scores(
|
||||
query, self._top_k, score_threshold
|
||||
query, self._top_k, score_threshold, filters
|
||||
)
|
||||
for query in queries
|
||||
]
|
||||
@ -105,37 +113,44 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
|
||||
return new_candidates_with_score
|
||||
|
||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
||||
async def _aretrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
|
||||
query (str): query text.
|
||||
filters: metadata filters.
|
||||
Return:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
queries = [query]
|
||||
if self._query_rewrite:
|
||||
candidates_tasks = [self._similarity_search(query) for query in queries]
|
||||
candidates_tasks = [
|
||||
self._similarity_search(query, filters) for query in queries
|
||||
]
|
||||
chunks = await self._run_async_tasks(candidates_tasks)
|
||||
context = "\n".join([chunk.content for chunk in chunks])
|
||||
new_queries = await self._query_rewrite.rewrite(
|
||||
origin_query=query, context=context, nums=1
|
||||
)
|
||||
queries.extend(new_queries)
|
||||
candidates = [self._similarity_search(query) for query in queries]
|
||||
candidates = [self._similarity_search(query, filters) for query in queries]
|
||||
new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
|
||||
return new_candidates
|
||||
|
||||
async def _aretrieve_with_score(
|
||||
self, query: str, score_threshold: float
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
|
||||
filters: metadata filters.
|
||||
Return:
|
||||
List[Chunk]: list of chunks with score
|
||||
"""
|
||||
@ -145,7 +160,9 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
"EmbeddingRetriever.query_rewrite.similarity_search",
|
||||
metadata={"query": query, "score_threshold": score_threshold},
|
||||
):
|
||||
candidates_tasks = [self._similarity_search(query) for query in queries]
|
||||
candidates_tasks = [
|
||||
self._similarity_search(query, filters) for query in queries
|
||||
]
|
||||
chunks = await self._run_async_tasks(candidates_tasks)
|
||||
context = "\n".join([chunk.content for chunk in chunks])
|
||||
with root_tracer.start_span(
|
||||
@ -162,7 +179,7 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
metadata={"query": query, "score_threshold": score_threshold},
|
||||
):
|
||||
candidates_with_score = [
|
||||
self._similarity_search_with_score(query, score_threshold)
|
||||
self._similarity_search_with_score(query, score_threshold, filters)
|
||||
for query in queries
|
||||
]
|
||||
res_candidates_with_score = await run_async_tasks(
|
||||
@ -183,12 +200,11 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
|
||||
return new_candidates_with_score
|
||||
|
||||
async def _similarity_search(self, query) -> List[Chunk]:
|
||||
async def _similarity_search(
|
||||
self, query, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search."""
|
||||
return self._vector_store_connector.similar_search(
|
||||
query,
|
||||
self._top_k,
|
||||
)
|
||||
return self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
|
||||
async def _run_async_tasks(self, tasks) -> List[Chunk]:
|
||||
"""Run async tasks."""
|
||||
@ -197,9 +213,9 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
return cast(List[Chunk], candidates)
|
||||
|
||||
async def _similarity_search_with_score(
|
||||
self, query, score_threshold
|
||||
self, query, score_threshold, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search with score."""
|
||||
return self._vector_store_connector.similar_search_with_scores(
|
||||
query, self._top_k, score_threshold
|
||||
query, self._top_k, score_threshold, filters
|
||||
)
|
||||
|
@ -36,5 +36,4 @@ def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever
|
||||
|
||||
retrieved_chunks = embedding_retriever._retrieve(query)
|
||||
|
||||
mock_vector_store_connector.similar_search.assert_called_once_with(query, top_k)
|
||||
assert len(retrieved_chunks) == top_k
|
||||
|
@ -457,12 +457,12 @@ class MarkdownHeaderTextSplitter(TextSplitter):
|
||||
# Output line-by-line or aggregated into chunks w/ common headers
|
||||
if headers_to_split_on is None:
|
||||
headers_to_split_on = [
|
||||
("#", "Header 1"),
|
||||
("##", "Header 2"),
|
||||
("###", "Header 3"),
|
||||
("####", "Header 4"),
|
||||
("#####", "Header 5"),
|
||||
("######", "Header 6"),
|
||||
("#", "Header1"),
|
||||
("##", "Header2"),
|
||||
("###", "Header3"),
|
||||
("####", "Header4"),
|
||||
("#####", "Header5"),
|
||||
("######", "Header6"),
|
||||
]
|
||||
if filters is None:
|
||||
filters = []
|
||||
|
@ -7,6 +7,7 @@ from dbgpt.core import Chunk
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
|
||||
CFG = Config()
|
||||
@ -46,52 +47,81 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
|
||||
def _retrieve(self, query: str) -> List[Chunk]:
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
query (str): query text.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Return:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
candidates = self._vector_store_connector.similar_search(
|
||||
doc=query, topk=self._top_k
|
||||
doc=query, topk=self._top_k, filters=filters
|
||||
)
|
||||
return candidates
|
||||
|
||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
def _retrieve_with_score(
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Return:
|
||||
List[Chunk]: list of chunks with score
|
||||
"""
|
||||
candidates_with_score = self._vector_store_connector.similar_search_with_scores(
|
||||
doc=query, topk=self._top_k, score_threshold=score_threshold
|
||||
doc=query,
|
||||
topk=self._top_k,
|
||||
score_threshold=score_threshold,
|
||||
filters=filters,
|
||||
)
|
||||
return candidates_with_score
|
||||
|
||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
||||
async def _aretrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
query (str): query text.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Return:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
candidates = await blocking_func_to_async(self._executor, self._retrieve, query)
|
||||
candidates = await blocking_func_to_async(
|
||||
self._executor, self._retrieve, query, filters
|
||||
)
|
||||
return candidates
|
||||
|
||||
async def _aretrieve_with_score(
|
||||
self, query: str, score_threshold: float
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
query (str): query text.
|
||||
score_threshold (float): score threshold.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Return:
|
||||
List[Chunk]: list of chunks with score
|
||||
List[Chunk]: list of chunks with score.
|
||||
"""
|
||||
candidates_with_score = await blocking_func_to_async(
|
||||
self._executor, self._retrieve_with_score, query, score_threshold
|
||||
self._executor, self._retrieve_with_score, query, score_threshold, filters
|
||||
)
|
||||
return candidates_with_score
|
||||
|
@ -4,11 +4,12 @@ import math
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -176,13 +177,15 @@ class VectorStoreBase(ABC):
|
||||
return ids
|
||||
|
||||
@abstractmethod
|
||||
def similar_search(self, text: str, topk: int) -> List[Chunk]:
|
||||
def similar_search(
|
||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search in vector database.
|
||||
|
||||
Args:
|
||||
text(str): The query text.
|
||||
topk(int): The number of similar documents to return.
|
||||
|
||||
filters(Optional[MetadataFilters]): metadata filters.
|
||||
Return:
|
||||
List[Chunk]: The similar documents.
|
||||
"""
|
||||
@ -190,7 +193,11 @@ class VectorStoreBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def similar_search_with_scores(
|
||||
self, text, topk, score_threshold: float
|
||||
self,
|
||||
text,
|
||||
topk,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Similar search with scores in vector database.
|
||||
|
||||
@ -199,6 +206,7 @@ class VectorStoreBase(ABC):
|
||||
topk(int): The number of similar documents to return.
|
||||
score_threshold(int): score_threshold: Optional, a floating point value
|
||||
between 0 to 1
|
||||
filters(Optional[MetadataFilters]): metadata filters.
|
||||
Return:
|
||||
List[Chunk]: The similar documents.
|
||||
"""
|
||||
@ -223,6 +231,15 @@ class VectorStoreBase(ABC):
|
||||
Args:
|
||||
vector_name(str): The name of vector to delete.
|
||||
"""
|
||||
pass
|
||||
|
||||
def convert_metadata_filters(self, filters: MetadataFilters) -> Any:
|
||||
"""Convert metadata filters to vector store filters.
|
||||
|
||||
Args:
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _normalization_vectors(self, vectors):
|
||||
"""Return L2-normalization vectors to scale[0,1].
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Chroma vector store."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
from typing import List, Optional
|
||||
|
||||
from chromadb import PersistentClient
|
||||
from chromadb.config import Settings
|
||||
@ -13,6 +13,7 @@ from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
|
||||
from .filters import FilterOperator, MetadataFilters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -86,16 +87,23 @@ class ChromaStore(VectorStoreBase):
|
||||
collection_metadata=collection_metadata,
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]:
|
||||
def similar_search(
|
||||
self, text, topk, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Search similar documents."""
|
||||
logger.info("ChromaStore similar search")
|
||||
lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs)
|
||||
where_filters = self.convert_metadata_filters(filters) if filters else None
|
||||
lc_documents = self.vector_store_client.similarity_search(
|
||||
text, topk, filter=where_filters
|
||||
)
|
||||
return [
|
||||
Chunk(content=doc.page_content, metadata=doc.metadata)
|
||||
for doc in lc_documents
|
||||
]
|
||||
|
||||
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
|
||||
def similar_search_with_scores(
|
||||
self, text, topk, score_threshold, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Search similar documents with scores.
|
||||
|
||||
Chroma similar_search_with_score.
|
||||
@ -106,11 +114,16 @@ class ChromaStore(VectorStoreBase):
|
||||
score_threshold(float): score_threshold: Optional, a floating point value
|
||||
between 0 to 1 to filter the resulting set of retrieved docs,0 is
|
||||
dissimilar, 1 is most similar.
|
||||
filters(MetadataFilters): metadata filters, defaults to None
|
||||
"""
|
||||
logger.info("ChromaStore similar search with scores")
|
||||
where_filters = self.convert_metadata_filters(filters) if filters else None
|
||||
docs_and_scores = (
|
||||
self.vector_store_client.similarity_search_with_relevance_scores(
|
||||
query=text, k=topk, score_threshold=score_threshold
|
||||
query=text,
|
||||
k=topk,
|
||||
score_threshold=score_threshold,
|
||||
filter=where_filters,
|
||||
)
|
||||
)
|
||||
return [
|
||||
@ -152,10 +165,71 @@ class ChromaStore(VectorStoreBase):
|
||||
collection = self.vector_store_client._collection
|
||||
collection.delete(ids=ids)
|
||||
|
||||
def convert_metadata_filters(
|
||||
self,
|
||||
filters: MetadataFilters,
|
||||
) -> dict:
|
||||
"""Convert metadata filters to Chroma filters.
|
||||
|
||||
Args:
|
||||
filters(MetadataFilters): metadata filters.
|
||||
Returns:
|
||||
dict: Chroma filters.
|
||||
"""
|
||||
where_filters = {}
|
||||
filters_list = []
|
||||
condition = filters.condition
|
||||
chroma_condition = f"${condition}"
|
||||
if filters.filters:
|
||||
for filter in filters.filters:
|
||||
if filter.operator:
|
||||
filters_list.append(
|
||||
{
|
||||
filter.key: {
|
||||
_convert_chroma_filter_operator(
|
||||
filter.operator
|
||||
): filter.value
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
filters_list.append({filter.key: filter.value}) # type: ignore
|
||||
|
||||
if len(filters_list) == 1:
|
||||
return filters_list[0]
|
||||
elif len(filters_list) > 1:
|
||||
where_filters[chroma_condition] = filters_list
|
||||
return where_filters
|
||||
|
||||
def _clean_persist_folder(self):
|
||||
"""Clean persist folder."""
|
||||
for root, dirs, files in os.walk(self.persist_dir, topdown=False):
|
||||
for name in files:
|
||||
os.remove(os.path.join(root, name))
|
||||
for name in dirs:
|
||||
os.rmdir(os.path.join(root, name))
|
||||
os.rmdir(self.persist_dir)
|
||||
|
||||
|
||||
def _convert_chroma_filter_operator(operator: str) -> str:
|
||||
"""Convert operator to Chroma where operator.
|
||||
|
||||
Args:
|
||||
operator(str): operator.
|
||||
Returns:
|
||||
str: Chroma where operator.
|
||||
"""
|
||||
if operator == FilterOperator.EQ:
|
||||
return "$eq"
|
||||
elif operator == FilterOperator.NE:
|
||||
return "$ne"
|
||||
elif operator == FilterOperator.GT:
|
||||
return "$gt"
|
||||
elif operator == FilterOperator.LT:
|
||||
return "$lt"
|
||||
elif operator == FilterOperator.GTE:
|
||||
return "$gte"
|
||||
elif operator == FilterOperator.LTE:
|
||||
return "$lte"
|
||||
else:
|
||||
raise ValueError(f"Chroma Where operator {operator} not supported")
|
||||
|
@ -13,6 +13,7 @@ from dbgpt.core.awel.flow import (
|
||||
)
|
||||
from dbgpt.storage import vector_store
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
connector: Dict[str, Type] = {}
|
||||
@ -128,23 +129,29 @@ class VectorStoreConnector:
|
||||
max_threads,
|
||||
)
|
||||
|
||||
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
|
||||
def similar_search(
|
||||
self, doc: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search in vector database.
|
||||
|
||||
Args:
|
||||
- doc: query text
|
||||
- topk: topk
|
||||
- filters: metadata filters.
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
"""
|
||||
return self.client.similar_search(doc, topk)
|
||||
return self.client.similar_search(doc, topk, filters)
|
||||
|
||||
def similar_search_with_scores(
|
||||
self, doc: str, topk: int, score_threshold: float
|
||||
self,
|
||||
doc: str,
|
||||
topk: int,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Similar search with scores in vector database.
|
||||
"""Similar_search_with_score in vector database.
|
||||
|
||||
similar_search_with_score in vector database..
|
||||
Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
Args:
|
||||
@ -153,10 +160,13 @@ class VectorStoreConnector:
|
||||
score_threshold(float): score_threshold: Optional, a floating point value
|
||||
between 0 to 1 to filter the resulting set of retrieved docs,0 is
|
||||
dissimilar, 1 is most similar.
|
||||
filters: metadata filters.
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
- chunks: Return docs and relevance scores in the range [0, 1].
|
||||
"""
|
||||
return self.client.similar_search_with_scores(doc, topk, score_threshold)
|
||||
return self.client.similar_search_with_scores(
|
||||
doc, topk, score_threshold, filters
|
||||
)
|
||||
|
||||
@property
|
||||
def vector_store_config(self) -> VectorStoreConfig:
|
||||
|
56
dbgpt/storage/vector_store/filters.py
Normal file
56
dbgpt/storage/vector_store/filters.py
Normal file
@ -0,0 +1,56 @@
|
||||
"""Vector Store Meta data filters."""
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FilterOperator(str, Enum):
|
||||
"""Meta data filter operator."""
|
||||
|
||||
EQ = "=="
|
||||
GT = ">"
|
||||
LT = "<"
|
||||
NE = "!="
|
||||
GTE = ">="
|
||||
LTE = "<="
|
||||
IN = "in"
|
||||
NIN = "nin"
|
||||
EXISTS = "exists"
|
||||
|
||||
|
||||
class FilterCondition(str, Enum):
|
||||
"""Vector Store Meta data filter conditions."""
|
||||
|
||||
AND = "and"
|
||||
OR = "or"
|
||||
|
||||
|
||||
class MetadataFilter(BaseModel):
|
||||
"""Meta data filter."""
|
||||
|
||||
key: str = Field(
|
||||
...,
|
||||
description="The key of metadata to filter.",
|
||||
)
|
||||
operator: FilterOperator = Field(
|
||||
default=FilterOperator.EQ,
|
||||
description="The operator of metadata filter.",
|
||||
)
|
||||
value: Union[str, int, float, List[str], List[int], List[float]] = Field(
|
||||
...,
|
||||
description="The value of metadata to filter.",
|
||||
)
|
||||
|
||||
|
||||
class MetadataFilters(BaseModel):
|
||||
"""Meta data filters."""
|
||||
|
||||
condition: FilterCondition = Field(
|
||||
default=FilterCondition.AND,
|
||||
description="The condition of metadata filters.",
|
||||
)
|
||||
filters: List[MetadataFilter] = Field(
|
||||
...,
|
||||
description="The metadata filters.",
|
||||
)
|
@ -14,6 +14,7 @@ from dbgpt.storage.vector_store.base import (
|
||||
VectorStoreBase,
|
||||
VectorStoreConfig,
|
||||
)
|
||||
from dbgpt.storage.vector_store.filters import FilterOperator, MetadataFilters
|
||||
from dbgpt.util import string_utils
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
@ -206,6 +207,7 @@ class MilvusStore(VectorStoreBase):
|
||||
self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
|
||||
self.text_field = milvus_vector_config.get("text_field") or "content"
|
||||
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
|
||||
self.props_field = milvus_vector_config.get("props_field") or "props_field"
|
||||
|
||||
if (self.username is None) != (self.password is None):
|
||||
raise ValueError(
|
||||
@ -284,6 +286,7 @@ class MilvusStore(VectorStoreBase):
|
||||
vector_field = self.vector_field
|
||||
text_field = self.text_field
|
||||
metadata_field = self.metadata_field
|
||||
props_field = self.props_field
|
||||
# self.text_field = text_field
|
||||
collection_name = vector_name
|
||||
fields = []
|
||||
@ -300,6 +303,7 @@ class MilvusStore(VectorStoreBase):
|
||||
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
||||
|
||||
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
|
||||
fields.append(FieldSchema(props_field, DataType.JSON))
|
||||
schema = CollectionSchema(fields)
|
||||
# Create the collection
|
||||
collection = Collection(collection_name, schema)
|
||||
@ -346,6 +350,7 @@ class MilvusStore(VectorStoreBase):
|
||||
for d in metadatas:
|
||||
# for key, value in d.items():
|
||||
insert_dict.setdefault("metadata", []).append(json.dumps(d))
|
||||
insert_dict.setdefault("props_field", []).append(d)
|
||||
# Convert dict to list of lists for insertion
|
||||
insert_list = [insert_dict[x] for x in self.fields]
|
||||
# Insert into the collection.
|
||||
@ -368,7 +373,9 @@ class MilvusStore(VectorStoreBase):
|
||||
doc_ids = [str(doc_id) for doc_id in doc_ids]
|
||||
return doc_ids
|
||||
|
||||
def similar_search(self, text, topk) -> List[Chunk]:
|
||||
def similar_search(
|
||||
self, text, topk, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Perform a search on a query string and return results."""
|
||||
from pymilvus import Collection, DataType
|
||||
|
||||
@ -383,7 +390,9 @@ class MilvusStore(VectorStoreBase):
|
||||
self.primary_field = x.name
|
||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||
self.vector_field = x.name
|
||||
_, docs_and_scores = self._search(text, topk)
|
||||
# convert to milvus expr filter.
|
||||
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
|
||||
_, docs_and_scores = self._search(text, topk, expr=milvus_filter_expr)
|
||||
|
||||
return [
|
||||
Chunk(
|
||||
@ -393,7 +402,13 @@ class MilvusStore(VectorStoreBase):
|
||||
for doc, _, _ in docs_and_scores
|
||||
]
|
||||
|
||||
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
|
||||
def similar_search_with_scores(
|
||||
self,
|
||||
text: str,
|
||||
topk: int,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Perform a search on a query string and return results with score.
|
||||
|
||||
For more information about the search parameters, take a look at the pymilvus
|
||||
@ -401,15 +416,10 @@ class MilvusStore(VectorStoreBase):
|
||||
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector being searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
text (str): The query text.
|
||||
topk (int): The number of similar documents to return.
|
||||
score_threshold (float): Optional, a floating point value between 0 to 1.
|
||||
filters (Optional[MetadataFilters]): Optional, metadata filters.
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: Result doc and score.
|
||||
"""
|
||||
@ -427,7 +437,11 @@ class MilvusStore(VectorStoreBase):
|
||||
|
||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||
self.vector_field = x.name
|
||||
_, docs_and_scores = self._search(text, topk)
|
||||
# convert to milvus expr filter.
|
||||
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
|
||||
_, docs_and_scores = self._search(
|
||||
query=text, topk=topk, expr=milvus_filter_expr
|
||||
)
|
||||
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
|
||||
logger.warning(
|
||||
"similarity score need between" f" 0 and 1, got {docs_and_scores}"
|
||||
@ -462,6 +476,20 @@ class MilvusStore(VectorStoreBase):
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Search in vector database.
|
||||
|
||||
Args:
|
||||
query: query text.
|
||||
k: topk.
|
||||
param: search params.
|
||||
expr: search expr.
|
||||
partition_names: partition names.
|
||||
round_decimal: round decimal.
|
||||
timeout: timeout.
|
||||
**kwargs: kwargs.
|
||||
Returns:
|
||||
Tuple[Document, float, int]: Result doc and score.
|
||||
"""
|
||||
self.col.load()
|
||||
# use default index params.
|
||||
if param is None:
|
||||
@ -495,7 +523,9 @@ class MilvusStore(VectorStoreBase):
|
||||
result.id,
|
||||
)
|
||||
)
|
||||
|
||||
if len(ret) == 0:
|
||||
logger.warning("No relevant docs were retrieved.")
|
||||
return None, []
|
||||
return ret[0], ret
|
||||
|
||||
def vector_name_exists(self):
|
||||
@ -523,6 +553,40 @@ class MilvusStore(VectorStoreBase):
|
||||
logger.info(f"begin delete milvus ids: {ids}")
|
||||
delete_ids = ids.split(",")
|
||||
doc_ids = [int(doc_id) for doc_id in delete_ids]
|
||||
delet_expr = f"{self.primary_field} in {doc_ids}"
|
||||
self.col.delete(delet_expr)
|
||||
delete_expr = f"{self.primary_field} in {doc_ids}"
|
||||
self.col.delete(delete_expr)
|
||||
return True
|
||||
|
||||
def convert_metadata_filters(self, filters: MetadataFilters) -> str:
|
||||
"""Convert filter to milvus filters.
|
||||
|
||||
Args:
|
||||
- filters: metadata filters.
|
||||
Returns:
|
||||
- metadata_filters: metadata filters.
|
||||
"""
|
||||
metadata_filters = []
|
||||
for metadata_filter in filters.filters:
|
||||
if isinstance(metadata_filter.value, str):
|
||||
expr = (
|
||||
f"{self.props_field}['{metadata_filter.key}'] "
|
||||
f"{FilterOperator.EQ} '{metadata_filter.value}'"
|
||||
)
|
||||
metadata_filters.append(expr)
|
||||
elif isinstance(metadata_filter.value, List):
|
||||
expr = (
|
||||
f"{self.props_field}['{metadata_filter.key}'] "
|
||||
f"{FilterOperator.IN} {metadata_filter.value}"
|
||||
)
|
||||
metadata_filters.append(expr)
|
||||
else:
|
||||
expr = (
|
||||
f"{self.props_field}['{metadata_filter.key}'] "
|
||||
f"{FilterOperator.EQ} {str(metadata_filter.value)}"
|
||||
)
|
||||
metadata_filters.append(expr)
|
||||
if len(metadata_filters) > 1:
|
||||
metadata_filter_expr = f" {filters.condition} ".join(metadata_filters)
|
||||
else:
|
||||
metadata_filter_expr = metadata_filters[0]
|
||||
return metadata_filter_expr
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Postgres vector store."""
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt.core import Chunk
|
||||
@ -10,6 +10,7 @@ from dbgpt.storage.vector_store.base import (
|
||||
VectorStoreBase,
|
||||
VectorStoreConfig,
|
||||
)
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -70,9 +71,11 @@ class PGVectorStore(VectorStoreBase):
|
||||
connection_string=self.connection_string,
|
||||
)
|
||||
|
||||
def similar_search(self, text: str, topk: int, **kwargs: Any) -> List[Chunk]:
|
||||
def similar_search(
|
||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Perform similar search in PGVector."""
|
||||
return self.vector_store_client.similarity_search(text, topk)
|
||||
return self.vector_store_client.similarity_search(text, topk, filters)
|
||||
|
||||
def vector_name_exists(self) -> bool:
|
||||
"""Check if vector name exists."""
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Weaviate vector store."""
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt.core import Chunk
|
||||
@ -9,6 +9,7 @@ from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
|
||||
from .filters import MetadataFilters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -80,7 +81,9 @@ class WeaviateStore(VectorStoreBase):
|
||||
|
||||
self.vector_store_client = weaviate.Client(self.weaviate_url)
|
||||
|
||||
def similar_search(self, text: str, topk: int) -> List[Chunk]:
|
||||
def similar_search(
|
||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Perform similar search in Weaviate."""
|
||||
logger.info("Weaviate similar search")
|
||||
# nearText = {
|
||||
|
60
examples/rag/metadata_filter_example.py
Normal file
60
examples/rag/metadata_filter_example.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""Rag Metadata Properties filter example.
|
||||
pre-requirements:
|
||||
make sure you have set your embedding model path in your example code.
|
||||
|
||||
Examples:
|
||||
..code-block:: shell
|
||||
python examples/rag/metadata_filter_example.py
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
|
||||
|
||||
|
||||
def _create_vector_connector():
|
||||
"""Create vector connector."""
|
||||
return VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="example_metadata_filter_name",
|
||||
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||
),
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||
).create(),
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
|
||||
knowledge = KnowledgeFactory.from_file_path(file_path)
|
||||
vector_connector = _create_vector_connector()
|
||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")
|
||||
# get embedding assembler
|
||||
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=chunk_parameters,
|
||||
vector_store_connector=vector_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
# get embeddings retriever
|
||||
retriever = assembler.as_retriever(3)
|
||||
# create metadata filter
|
||||
metadata_filter = MetadataFilter(key="Header2", value="AWEL Design")
|
||||
filters = MetadataFilters(filters=[metadata_filter])
|
||||
chunks = await retriever.aretrieve_with_scores(
|
||||
"what is awel talk about", 0.0, filters
|
||||
)
|
||||
print(f"embedding rag example results:{chunks}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
Loading…
Reference in New Issue
Block a user