mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 10:54:29 +00:00
feat(RAG):add metadata properties filters (#1395)
This commit is contained in:
parent
0f2b46da62
commit
37e7c0151b
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from dbgpt.core import Document
|
from dbgpt.core import Document
|
||||||
from dbgpt.rag.text_splitter.text_splitter import (
|
from dbgpt.rag.text_splitter.text_splitter import (
|
||||||
@ -147,16 +147,18 @@ class Knowledge(ABC):
|
|||||||
self,
|
self,
|
||||||
path: Optional[str] = None,
|
path: Optional[str] = None,
|
||||||
knowledge_type: Optional[KnowledgeType] = None,
|
knowledge_type: Optional[KnowledgeType] = None,
|
||||||
data_loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize with Knowledge arguments."""
|
"""Initialize with Knowledge arguments."""
|
||||||
self._path = path
|
self._path = path
|
||||||
self._type = knowledge_type
|
self._type = knowledge_type
|
||||||
self._data_loader = data_loader
|
self._loader = loader
|
||||||
|
self._metadata = metadata
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> List[Document]:
|
||||||
"""Load knowledge from data_loader."""
|
"""Load knowledge from data loader."""
|
||||||
documents = self._load()
|
documents = self._load()
|
||||||
return self._postprocess(documents)
|
return self._postprocess(documents)
|
||||||
|
|
||||||
@ -171,12 +173,12 @@ class Knowledge(ABC):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _postprocess(self, docs: List[Document]) -> List[Document]:
|
def _postprocess(self, docs: List[Document]) -> List[Document]:
|
||||||
"""Post process knowledge from data_loader."""
|
"""Post process knowledge from data loader."""
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
"""Preprocess knowledge from data_loader."""
|
"""Preprocess knowledge from data loader."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def support_chunk_strategy(cls) -> List[ChunkStrategy]:
|
def support_chunk_strategy(cls) -> List[ChunkStrategy]:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""CSV Knowledge."""
|
"""CSV Knowledge."""
|
||||||
import csv
|
import csv
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from dbgpt.core import Document
|
from dbgpt.core import Document
|
||||||
from dbgpt.rag.knowledge.base import (
|
from dbgpt.rag.knowledge.base import (
|
||||||
@ -21,6 +21,7 @@ class CSVKnowledge(Knowledge):
|
|||||||
source_column: Optional[str] = None,
|
source_column: Optional[str] = None,
|
||||||
encoding: Optional[str] = "utf-8",
|
encoding: Optional[str] = "utf-8",
|
||||||
loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create CSV Knowledge with Knowledge arguments.
|
"""Create CSV Knowledge with Knowledge arguments.
|
||||||
@ -32,9 +33,13 @@ class CSVKnowledge(Knowledge):
|
|||||||
encoding(str, optional): csv encoding
|
encoding(str, optional): csv encoding
|
||||||
loader(Any, optional): loader
|
loader(Any, optional): loader
|
||||||
"""
|
"""
|
||||||
self._path = file_path
|
super().__init__(
|
||||||
self._type = knowledge_type
|
path=file_path,
|
||||||
self._loader = loader
|
knowledge_type=knowledge_type,
|
||||||
|
data_loader=loader,
|
||||||
|
metadata=metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
self._encoding = encoding
|
self._encoding = encoding
|
||||||
self._source_column = source_column
|
self._source_column = source_column
|
||||||
|
|
||||||
@ -67,6 +72,8 @@ class CSVKnowledge(Knowledge):
|
|||||||
f"file."
|
f"file."
|
||||||
)
|
)
|
||||||
metadata = {"source": source, "row": i}
|
metadata = {"source": source, "row": i}
|
||||||
|
if self._metadata:
|
||||||
|
metadata.update(self._metadata) # type: ignore
|
||||||
doc = Document(content=content, metadata=metadata)
|
doc = Document(content=content, metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Datasource Knowledge."""
|
"""Datasource Knowledge."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from dbgpt.core import Document
|
from dbgpt.core import Document
|
||||||
from dbgpt.datasource import BaseConnector
|
from dbgpt.datasource import BaseConnector
|
||||||
@ -16,26 +16,29 @@ class DatasourceKnowledge(Knowledge):
|
|||||||
connector: BaseConnector,
|
connector: BaseConnector,
|
||||||
summary_template: str = "{table_name}({columns})",
|
summary_template: str = "{table_name}({columns})",
|
||||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create Datasource Knowledge with Knowledge arguments.
|
"""Create Datasource Knowledge with Knowledge arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path(str, optional): file path
|
connector(BaseConnector): connector
|
||||||
|
summary_template(str, optional): summary template
|
||||||
knowledge_type(KnowledgeType, optional): knowledge type
|
knowledge_type(KnowledgeType, optional): knowledge type
|
||||||
data_loader(Any, optional): loader
|
metadata(Dict[str, Union[str, List[str]], optional): metadata
|
||||||
"""
|
"""
|
||||||
self._connector = connector
|
self._connector = connector
|
||||||
self._summary_template = summary_template
|
self._summary_template = summary_template
|
||||||
super().__init__(knowledge_type=knowledge_type, **kwargs)
|
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
"""Load datasource document from data_loader."""
|
"""Load datasource document from data_loader."""
|
||||||
docs = []
|
docs = []
|
||||||
for table_summary in _parse_db_summary(self._connector, self._summary_template):
|
for table_summary in _parse_db_summary(self._connector, self._summary_template):
|
||||||
docs.append(
|
metadata = {"source": "database"}
|
||||||
Document(content=table_summary, metadata={"source": "database"})
|
if self._metadata:
|
||||||
)
|
metadata.update(self._metadata) # type: ignore
|
||||||
|
docs.append(Document(content=table_summary, metadata=metadata))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Docx Knowledge."""
|
"""Docx Knowledge."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import docx
|
import docx
|
||||||
|
|
||||||
@ -21,6 +21,7 @@ class DocxKnowledge(Knowledge):
|
|||||||
knowledge_type: Any = KnowledgeType.DOCUMENT,
|
knowledge_type: Any = KnowledgeType.DOCUMENT,
|
||||||
encoding: Optional[str] = "utf-8",
|
encoding: Optional[str] = "utf-8",
|
||||||
loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create Docx Knowledge with Knowledge arguments.
|
"""Create Docx Knowledge with Knowledge arguments.
|
||||||
@ -31,9 +32,13 @@ class DocxKnowledge(Knowledge):
|
|||||||
encoding(str, optional): csv encoding
|
encoding(str, optional): csv encoding
|
||||||
loader(Any, optional): loader
|
loader(Any, optional): loader
|
||||||
"""
|
"""
|
||||||
self._path = file_path
|
super().__init__(
|
||||||
self._type = knowledge_type
|
path=file_path,
|
||||||
self._loader = loader
|
knowledge_type=knowledge_type,
|
||||||
|
data_loader=loader,
|
||||||
|
metadata=metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
self._encoding = encoding
|
self._encoding = encoding
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
@ -48,9 +53,10 @@ class DocxKnowledge(Knowledge):
|
|||||||
para = doc.paragraphs[i]
|
para = doc.paragraphs[i]
|
||||||
text = para.text
|
text = para.text
|
||||||
content.append(text)
|
content.append(text)
|
||||||
docs.append(
|
metadata = {"source": self._path}
|
||||||
Document(content="\n".join(content), metadata={"source": self._path})
|
if self._metadata:
|
||||||
)
|
metadata.update(self._metadata) # type: ignore
|
||||||
|
docs.append(Document(content="\n".join(content), metadata=metadata))
|
||||||
return docs
|
return docs
|
||||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Knowledge Factory to create knowledge from file path and url."""
|
"""Knowledge Factory to create knowledge from file path and url."""
|
||||||
from typing import List, Optional, Type
|
from typing import Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
|
from dbgpt.rag.knowledge.base import Knowledge, KnowledgeType
|
||||||
from dbgpt.rag.knowledge.string import StringKnowledge
|
from dbgpt.rag.knowledge.string import StringKnowledge
|
||||||
@ -13,6 +13,7 @@ class KnowledgeFactory:
|
|||||||
self,
|
self,
|
||||||
file_path: Optional[str] = None,
|
file_path: Optional[str] = None,
|
||||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
):
|
):
|
||||||
"""Create Knowledge Factory with file path and knowledge type.
|
"""Create Knowledge Factory with file path and knowledge type.
|
||||||
|
|
||||||
@ -22,18 +23,21 @@ class KnowledgeFactory:
|
|||||||
"""
|
"""
|
||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
self._knowledge_type = knowledge_type
|
self._knowledge_type = knowledge_type
|
||||||
|
self._metadata = metadata
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
datasource: str = "",
|
datasource: str = "",
|
||||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
):
|
):
|
||||||
"""Create knowledge from file path, url or text.
|
"""Create knowledge from file path, url or text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
datasource: path of the file to convert
|
datasource: path of the file to convert
|
||||||
knowledge_type: type of knowledge
|
knowledge_type: type of knowledge
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]]
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -52,12 +56,16 @@ class KnowledgeFactory:
|
|||||||
match knowledge_type:
|
match knowledge_type:
|
||||||
case KnowledgeType.DOCUMENT:
|
case KnowledgeType.DOCUMENT:
|
||||||
return cls.from_file_path(
|
return cls.from_file_path(
|
||||||
file_path=datasource, knowledge_type=knowledge_type
|
file_path=datasource,
|
||||||
|
knowledge_type=knowledge_type,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
case KnowledgeType.URL:
|
case KnowledgeType.URL:
|
||||||
return cls.from_url(url=datasource, knowledge_type=knowledge_type)
|
return cls.from_url(url=datasource, knowledge_type=knowledge_type)
|
||||||
case KnowledgeType.TEXT:
|
case KnowledgeType.TEXT:
|
||||||
return cls.from_text(text=datasource, knowledge_type=knowledge_type)
|
return cls.from_text(
|
||||||
|
text=datasource, knowledge_type=knowledge_type, metadata=metadata
|
||||||
|
)
|
||||||
case _:
|
case _:
|
||||||
raise Exception(f"Unsupported knowledge type '{knowledge_type}'")
|
raise Exception(f"Unsupported knowledge type '{knowledge_type}'")
|
||||||
|
|
||||||
@ -66,6 +74,7 @@ class KnowledgeFactory:
|
|||||||
cls,
|
cls,
|
||||||
file_path: str = "",
|
file_path: str = "",
|
||||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
) -> Knowledge:
|
) -> Knowledge:
|
||||||
"""Create knowledge from path.
|
"""Create knowledge from path.
|
||||||
|
|
||||||
@ -82,10 +91,11 @@ class KnowledgeFactory:
|
|||||||
datasource="path/to/document.pdf",
|
datasource="path/to/document.pdf",
|
||||||
knowledge_type=KnowledgeType.DOCUMENT,
|
knowledge_type=KnowledgeType.DOCUMENT,
|
||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
factory = cls(file_path=file_path, knowledge_type=knowledge_type)
|
factory = cls(file_path=file_path, knowledge_type=knowledge_type)
|
||||||
return factory._select_document_knowledge(
|
return factory._select_document_knowledge(
|
||||||
file_path=file_path, knowledge_type=knowledge_type
|
file_path=file_path, knowledge_type=knowledge_type, metadata=metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -117,6 +127,7 @@ class KnowledgeFactory:
|
|||||||
def from_text(
|
def from_text(
|
||||||
text: str = "",
|
text: str = "",
|
||||||
knowledge_type: KnowledgeType = KnowledgeType.TEXT,
|
knowledge_type: KnowledgeType = KnowledgeType.TEXT,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
) -> Knowledge:
|
) -> Knowledge:
|
||||||
"""Create knowledge from text.
|
"""Create knowledge from text.
|
||||||
|
|
||||||
@ -127,6 +138,7 @@ class KnowledgeFactory:
|
|||||||
return StringKnowledge(
|
return StringKnowledge(
|
||||||
text=text,
|
text=text,
|
||||||
knowledge_type=knowledge_type,
|
knowledge_type=knowledge_type,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _select_document_knowledge(self, **kwargs):
|
def _select_document_knowledge(self, **kwargs):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""HTML Knowledge."""
|
"""HTML Knowledge."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import chardet
|
import chardet
|
||||||
|
|
||||||
@ -20,6 +20,7 @@ class HTMLKnowledge(Knowledge):
|
|||||||
file_path: Optional[str] = None,
|
file_path: Optional[str] = None,
|
||||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||||
loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create HTML Knowledge with Knowledge arguments.
|
"""Create HTML Knowledge with Knowledge arguments.
|
||||||
@ -29,9 +30,13 @@ class HTMLKnowledge(Knowledge):
|
|||||||
knowledge_type(KnowledgeType, optional): knowledge type
|
knowledge_type(KnowledgeType, optional): knowledge type
|
||||||
loader(Any, optional): loader
|
loader(Any, optional): loader
|
||||||
"""
|
"""
|
||||||
self._path = file_path
|
super().__init__(
|
||||||
self._type = knowledge_type
|
path=file_path,
|
||||||
self._loader = loader
|
knowledge_type=knowledge_type,
|
||||||
|
data_loader=loader,
|
||||||
|
metadata=metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
"""Load html document from loader."""
|
"""Load html document from loader."""
|
||||||
@ -48,6 +53,8 @@ class HTMLKnowledge(Knowledge):
|
|||||||
else:
|
else:
|
||||||
text = raw_text.decode(result["encoding"])
|
text = raw_text.decode(result["encoding"])
|
||||||
metadata = {"source": self._path}
|
metadata = {"source": self._path}
|
||||||
|
if self._metadata:
|
||||||
|
metadata.update(self._metadata) # type: ignore
|
||||||
return [Document(content=text, metadata=metadata)]
|
return [Document(content=text, metadata=metadata)]
|
||||||
|
|
||||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Markdown Knowledge."""
|
"""Markdown Knowledge."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from dbgpt.core import Document
|
from dbgpt.core import Document
|
||||||
from dbgpt.rag.knowledge.base import (
|
from dbgpt.rag.knowledge.base import (
|
||||||
@ -19,6 +19,7 @@ class MarkdownKnowledge(Knowledge):
|
|||||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||||
encoding: Optional[str] = "utf-8",
|
encoding: Optional[str] = "utf-8",
|
||||||
loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create Markdown Knowledge with Knowledge arguments.
|
"""Create Markdown Knowledge with Knowledge arguments.
|
||||||
@ -29,9 +30,13 @@ class MarkdownKnowledge(Knowledge):
|
|||||||
encoding(str, optional): csv encoding
|
encoding(str, optional): csv encoding
|
||||||
loader(Any, optional): loader
|
loader(Any, optional): loader
|
||||||
"""
|
"""
|
||||||
self._path = file_path
|
super().__init__(
|
||||||
self._type = knowledge_type
|
path=file_path,
|
||||||
self._loader = loader
|
knowledge_type=knowledge_type,
|
||||||
|
data_loader=loader,
|
||||||
|
metadata=metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
self._encoding = encoding
|
self._encoding = encoding
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
@ -44,6 +49,8 @@ class MarkdownKnowledge(Knowledge):
|
|||||||
with open(self._path, encoding=self._encoding, errors="ignore") as f:
|
with open(self._path, encoding=self._encoding, errors="ignore") as f:
|
||||||
markdown_text = f.read()
|
markdown_text = f.read()
|
||||||
metadata = {"source": self._path}
|
metadata = {"source": self._path}
|
||||||
|
if self._metadata:
|
||||||
|
metadata.update(self._metadata) # type: ignore
|
||||||
documents = [Document(content=markdown_text, metadata=metadata)]
|
documents = [Document(content=markdown_text, metadata=metadata)]
|
||||||
return documents
|
return documents
|
||||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""PDF Knowledge."""
|
"""PDF Knowledge."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from dbgpt.core import Document
|
from dbgpt.core import Document
|
||||||
from dbgpt.rag.knowledge.base import (
|
from dbgpt.rag.knowledge.base import (
|
||||||
@ -19,6 +19,7 @@ class PDFKnowledge(Knowledge):
|
|||||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||||
loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
language: Optional[str] = "zh",
|
language: Optional[str] = "zh",
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create PDF Knowledge with Knowledge arguments.
|
"""Create PDF Knowledge with Knowledge arguments.
|
||||||
@ -29,9 +30,13 @@ class PDFKnowledge(Knowledge):
|
|||||||
loader(Any, optional): loader
|
loader(Any, optional): loader
|
||||||
language(str, optional): language
|
language(str, optional): language
|
||||||
"""
|
"""
|
||||||
self._path = file_path
|
super().__init__(
|
||||||
self._type = knowledge_type
|
path=file_path,
|
||||||
self._loader = loader
|
knowledge_type=knowledge_type,
|
||||||
|
data_loader=loader,
|
||||||
|
metadata=metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
self._language = language
|
self._language = language
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
@ -65,6 +70,8 @@ class PDFKnowledge(Knowledge):
|
|||||||
page = "\n".join(cleaned_lines)
|
page = "\n".join(cleaned_lines)
|
||||||
# cleaned_pages.append(page)
|
# cleaned_pages.append(page)
|
||||||
metadata = {"source": self._path, "page": page_num}
|
metadata = {"source": self._path, "page": page_num}
|
||||||
|
if self._metadata:
|
||||||
|
metadata.update(self._metadata) # type: ignore
|
||||||
# text = "\f".join(cleaned_pages)
|
# text = "\f".join(cleaned_pages)
|
||||||
document = Document(content=page, metadata=metadata)
|
document = Document(content=page, metadata=metadata)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""PPTX Knowledge."""
|
"""PPTX Knowledge."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from dbgpt.core import Document
|
from dbgpt.core import Document
|
||||||
from dbgpt.rag.knowledge.base import (
|
from dbgpt.rag.knowledge.base import (
|
||||||
@ -19,6 +19,7 @@ class PPTXKnowledge(Knowledge):
|
|||||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||||
loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
language: Optional[str] = "zh",
|
language: Optional[str] = "zh",
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create PPTX knowledge with PDF Knowledge arguments.
|
"""Create PPTX knowledge with PDF Knowledge arguments.
|
||||||
@ -28,9 +29,13 @@ class PPTXKnowledge(Knowledge):
|
|||||||
knowledge_type:(KnowledgeType) knowledge type
|
knowledge_type:(KnowledgeType) knowledge type
|
||||||
loader:(Optional[Any]) loader
|
loader:(Optional[Any]) loader
|
||||||
"""
|
"""
|
||||||
self._path = file_path
|
super().__init__(
|
||||||
self._type = knowledge_type
|
path=file_path,
|
||||||
self._loader = loader
|
knowledge_type=knowledge_type,
|
||||||
|
data_loader=loader,
|
||||||
|
metadata=metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
self._language = language
|
self._language = language
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
@ -47,9 +52,10 @@ class PPTXKnowledge(Knowledge):
|
|||||||
for shape in slide.shapes:
|
for shape in slide.shapes:
|
||||||
if hasattr(shape, "text") and shape.text:
|
if hasattr(shape, "text") and shape.text:
|
||||||
content += shape.text
|
content += shape.text
|
||||||
docs.append(
|
metadata = {"source": self._path}
|
||||||
Document(content=content, metadata={"source": slide.slide_id})
|
if self._metadata:
|
||||||
)
|
metadata.update(self._metadata) # type: ignore
|
||||||
|
docs.append(Document(content=content, metadata=metadata))
|
||||||
return docs
|
return docs
|
||||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""String Knowledge."""
|
"""String Knowledge."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from dbgpt.core import Document
|
from dbgpt.core import Document
|
||||||
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
|
from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge, KnowledgeType
|
||||||
@ -14,6 +14,7 @@ class StringKnowledge(Knowledge):
|
|||||||
knowledge_type: KnowledgeType = KnowledgeType.TEXT,
|
knowledge_type: KnowledgeType = KnowledgeType.TEXT,
|
||||||
encoding: Optional[str] = "utf-8",
|
encoding: Optional[str] = "utf-8",
|
||||||
loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create String knowledge parameters.
|
"""Create String knowledge parameters.
|
||||||
@ -24,14 +25,20 @@ class StringKnowledge(Knowledge):
|
|||||||
encoding(str): encoding
|
encoding(str): encoding
|
||||||
loader(Any): loader
|
loader(Any): loader
|
||||||
"""
|
"""
|
||||||
|
super().__init__(
|
||||||
|
knowledge_type=knowledge_type,
|
||||||
|
data_loader=loader,
|
||||||
|
metadata=metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
self._text = text
|
self._text = text
|
||||||
self._type = knowledge_type
|
|
||||||
self._loader = loader
|
|
||||||
self._encoding = encoding
|
self._encoding = encoding
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
"""Load raw text from loader."""
|
"""Load raw text from loader."""
|
||||||
metadata = {"source": "raw text"}
|
metadata = {"source": "raw text"}
|
||||||
|
if self._metadata:
|
||||||
|
metadata.update(self._metadata) # type: ignore
|
||||||
docs = [Document(content=self._text, metadata=metadata)]
|
docs = [Document(content=self._text, metadata=metadata)]
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""TXT Knowledge."""
|
"""TXT Knowledge."""
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import chardet
|
import chardet
|
||||||
|
|
||||||
@ -20,6 +20,7 @@ class TXTKnowledge(Knowledge):
|
|||||||
file_path: Optional[str] = None,
|
file_path: Optional[str] = None,
|
||||||
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
knowledge_type: KnowledgeType = KnowledgeType.DOCUMENT,
|
||||||
loader: Optional[Any] = None,
|
loader: Optional[Any] = None,
|
||||||
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create TXT Knowledge with Knowledge arguments.
|
"""Create TXT Knowledge with Knowledge arguments.
|
||||||
@ -29,9 +30,13 @@ class TXTKnowledge(Knowledge):
|
|||||||
knowledge_type(KnowledgeType, optional): knowledge type
|
knowledge_type(KnowledgeType, optional): knowledge type
|
||||||
loader(Any, optional): loader
|
loader(Any, optional): loader
|
||||||
"""
|
"""
|
||||||
self._path = file_path
|
super().__init__(
|
||||||
self._type = knowledge_type
|
path=file_path,
|
||||||
self._loader = loader
|
knowledge_type=knowledge_type,
|
||||||
|
data_loader=loader,
|
||||||
|
metadata=metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
"""Load txt document from loader."""
|
"""Load txt document from loader."""
|
||||||
@ -48,6 +53,8 @@ class TXTKnowledge(Knowledge):
|
|||||||
else:
|
else:
|
||||||
text = raw_text.decode(result["encoding"])
|
text = raw_text.decode(result["encoding"])
|
||||||
metadata = {"source": self._path}
|
metadata = {"source": self._path}
|
||||||
|
if self._metadata:
|
||||||
|
metadata.update(self._metadata) # type: ignore
|
||||||
return [Document(content=text, metadata=metadata)]
|
return [Document(content=text, metadata=metadata)]
|
||||||
|
|
||||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||||
|
@ -26,9 +26,9 @@ class URLKnowledge(Knowledge):
|
|||||||
encoding(str, optional): csv encoding
|
encoding(str, optional): csv encoding
|
||||||
loader(Any, optional): loader
|
loader(Any, optional): loader
|
||||||
"""
|
"""
|
||||||
self._path = url or None
|
super().__init__(
|
||||||
self._type = knowledge_type
|
path=url, knowledge_type=knowledge_type, loader=loader, **kwargs
|
||||||
self._loader = loader
|
)
|
||||||
self._encoding = encoding
|
self._encoding = encoding
|
||||||
self._source_column = source_column
|
self._source_column = source_column
|
||||||
|
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
"""Base retriever module."""
|
"""Base retriever module."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from dbgpt.core import Chunk
|
from dbgpt.core import Chunk
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
|
|
||||||
|
|
||||||
class RetrieverStrategy(str, Enum):
|
class RetrieverStrategy(str, Enum):
|
||||||
@ -23,83 +24,112 @@ class RetrieverStrategy(str, Enum):
|
|||||||
class BaseRetriever(ABC):
|
class BaseRetriever(ABC):
|
||||||
"""Base retriever."""
|
"""Base retriever."""
|
||||||
|
|
||||||
def retrieve(self, query: str) -> List[Chunk]:
|
def retrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text.
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
return self._retrieve(query)
|
return self._retrieve(query, filters)
|
||||||
|
|
||||||
async def aretrieve(self, query: str) -> List[Chunk]:
|
async def aretrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): async query text
|
query (str): async query text.
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
return await self._aretrieve(query)
|
return await self._aretrieve(query, filters)
|
||||||
|
|
||||||
def retrieve_with_scores(self, query: str, score_threshold: float) -> List[Chunk]:
|
def retrieve_with_scores(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text.
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold.
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
return self._retrieve_with_score(query, score_threshold)
|
return self._retrieve_with_score(query, score_threshold, filters)
|
||||||
|
|
||||||
async def aretrieve_with_scores(
|
async def aretrieve_with_scores(
|
||||||
self, query: str, score_threshold: float
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
return await self._aretrieve_with_score(query, score_threshold)
|
return await self._aretrieve_with_score(query, score_threshold, filters)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _retrieve(self, query: str) -> List[Chunk]:
|
def _retrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
async def _aretrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Async Retrieve knowledge chunks.
|
"""Async Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
def _retrieve_with_score(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
@ -107,13 +137,17 @@ class BaseRetriever(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _aretrieve_with_score(
|
async def _aretrieve_with_score(
|
||||||
self, query: str, score_threshold: float
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Async Retrieve knowledge chunks with score.
|
"""Async Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
|
@ -8,6 +8,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
|
|||||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
from dbgpt.util.chat_util import run_async_tasks
|
from dbgpt.util.chat_util import run_async_tasks
|
||||||
|
|
||||||
|
|
||||||
@ -93,11 +94,14 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
self._need_embeddings = True
|
self._need_embeddings = True
|
||||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||||
|
|
||||||
def _retrieve(self, query: str) -> List[Chunk]:
|
def _retrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
|
filters: metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
@ -105,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
if self._need_embeddings:
|
if self._need_embeddings:
|
||||||
queries = [query]
|
queries = [query]
|
||||||
candidates = [
|
candidates = [
|
||||||
self._vector_store_connector.similar_search(query, self._top_k)
|
self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||||
for query in queries
|
for query in queries
|
||||||
]
|
]
|
||||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||||
@ -115,30 +119,39 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
table_summaries = _parse_db_summary(self._connector)
|
table_summaries = _parse_db_summary(self._connector)
|
||||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||||
|
|
||||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
def _retrieve_with_score(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold
|
||||||
|
filters: metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
return self._retrieve(query)
|
return self._retrieve(query, filters)
|
||||||
|
|
||||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
async def _aretrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
|
filters: metadata filters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
if self._need_embeddings:
|
if self._need_embeddings:
|
||||||
queries = [query]
|
queries = [query]
|
||||||
candidates = [self._similarity_search(query) for query in queries]
|
candidates = [self._similarity_search(query, filters) for query in queries]
|
||||||
result_candidates = await run_async_tasks(
|
result_candidates = await run_async_tasks(
|
||||||
tasks=candidates, concurrency_limit=1
|
tasks=candidates, concurrency_limit=1
|
||||||
)
|
)
|
||||||
@ -154,22 +167,25 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||||
|
|
||||||
async def _aretrieve_with_score(
|
async def _aretrieve_with_score(
|
||||||
self, query: str, score_threshold: float
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold
|
||||||
|
filters: metadata filters.
|
||||||
"""
|
"""
|
||||||
return await self._aretrieve(query)
|
return await self._aretrieve(query, filters)
|
||||||
|
|
||||||
async def _similarity_search(self, query) -> List[Chunk]:
|
async def _similarity_search(
|
||||||
|
self, query, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Similar search."""
|
"""Similar search."""
|
||||||
return self._vector_store_connector.similar_search(
|
return self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||||
query,
|
|
||||||
self._top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _aparse_db_summary(self) -> List[str]:
|
async def _aparse_db_summary(self) -> List[str]:
|
||||||
"""Similar search."""
|
"""Similar search."""
|
||||||
|
@ -7,6 +7,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
|
|||||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
from dbgpt.util.chat_util import run_async_tasks
|
from dbgpt.util.chat_util import run_async_tasks
|
||||||
from dbgpt.util.tracer import root_tracer
|
from dbgpt.util.tracer import root_tracer
|
||||||
|
|
||||||
@ -65,37 +66,44 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
self._vector_store_connector = vector_store_connector
|
self._vector_store_connector = vector_store_connector
|
||||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||||
|
|
||||||
def _retrieve(self, query: str) -> List[Chunk]:
|
def _retrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
|
filters: metadata filters.
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
queries = [query]
|
queries = [query]
|
||||||
candidates = [
|
candidates = [
|
||||||
self._vector_store_connector.similar_search(query, self._top_k)
|
self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||||
for query in queries
|
for query in queries
|
||||||
]
|
]
|
||||||
res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||||
return res_candidates
|
return res_candidates
|
||||||
|
|
||||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
def _retrieve_with_score(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold
|
||||||
|
filters: metadata filters.
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks with score
|
List[Chunk]: list of chunks with score
|
||||||
"""
|
"""
|
||||||
queries = [query]
|
queries = [query]
|
||||||
candidates_with_score = [
|
candidates_with_score = [
|
||||||
self._vector_store_connector.similar_search_with_scores(
|
self._vector_store_connector.similar_search_with_scores(
|
||||||
query, self._top_k, score_threshold
|
query, self._top_k, score_threshold, filters
|
||||||
)
|
)
|
||||||
for query in queries
|
for query in queries
|
||||||
]
|
]
|
||||||
@ -105,37 +113,44 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
|
new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
|
||||||
return new_candidates_with_score
|
return new_candidates_with_score
|
||||||
|
|
||||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
async def _aretrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text.
|
||||||
|
filters: metadata filters.
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
queries = [query]
|
queries = [query]
|
||||||
if self._query_rewrite:
|
if self._query_rewrite:
|
||||||
candidates_tasks = [self._similarity_search(query) for query in queries]
|
candidates_tasks = [
|
||||||
|
self._similarity_search(query, filters) for query in queries
|
||||||
|
]
|
||||||
chunks = await self._run_async_tasks(candidates_tasks)
|
chunks = await self._run_async_tasks(candidates_tasks)
|
||||||
context = "\n".join([chunk.content for chunk in chunks])
|
context = "\n".join([chunk.content for chunk in chunks])
|
||||||
new_queries = await self._query_rewrite.rewrite(
|
new_queries = await self._query_rewrite.rewrite(
|
||||||
origin_query=query, context=context, nums=1
|
origin_query=query, context=context, nums=1
|
||||||
)
|
)
|
||||||
queries.extend(new_queries)
|
queries.extend(new_queries)
|
||||||
candidates = [self._similarity_search(query) for query in queries]
|
candidates = [self._similarity_search(query, filters) for query in queries]
|
||||||
new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
|
new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
|
||||||
return new_candidates
|
return new_candidates
|
||||||
|
|
||||||
async def _aretrieve_with_score(
|
async def _aretrieve_with_score(
|
||||||
self, query: str, score_threshold: float
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold
|
||||||
|
filters: metadata filters.
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks with score
|
List[Chunk]: list of chunks with score
|
||||||
"""
|
"""
|
||||||
@ -145,7 +160,9 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
"EmbeddingRetriever.query_rewrite.similarity_search",
|
"EmbeddingRetriever.query_rewrite.similarity_search",
|
||||||
metadata={"query": query, "score_threshold": score_threshold},
|
metadata={"query": query, "score_threshold": score_threshold},
|
||||||
):
|
):
|
||||||
candidates_tasks = [self._similarity_search(query) for query in queries]
|
candidates_tasks = [
|
||||||
|
self._similarity_search(query, filters) for query in queries
|
||||||
|
]
|
||||||
chunks = await self._run_async_tasks(candidates_tasks)
|
chunks = await self._run_async_tasks(candidates_tasks)
|
||||||
context = "\n".join([chunk.content for chunk in chunks])
|
context = "\n".join([chunk.content for chunk in chunks])
|
||||||
with root_tracer.start_span(
|
with root_tracer.start_span(
|
||||||
@ -162,7 +179,7 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
metadata={"query": query, "score_threshold": score_threshold},
|
metadata={"query": query, "score_threshold": score_threshold},
|
||||||
):
|
):
|
||||||
candidates_with_score = [
|
candidates_with_score = [
|
||||||
self._similarity_search_with_score(query, score_threshold)
|
self._similarity_search_with_score(query, score_threshold, filters)
|
||||||
for query in queries
|
for query in queries
|
||||||
]
|
]
|
||||||
res_candidates_with_score = await run_async_tasks(
|
res_candidates_with_score = await run_async_tasks(
|
||||||
@ -183,12 +200,11 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
|
new_candidates_with_score = self._rerank.rank(new_candidates_with_score)
|
||||||
return new_candidates_with_score
|
return new_candidates_with_score
|
||||||
|
|
||||||
async def _similarity_search(self, query) -> List[Chunk]:
|
async def _similarity_search(
|
||||||
|
self, query, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Similar search."""
|
"""Similar search."""
|
||||||
return self._vector_store_connector.similar_search(
|
return self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||||
query,
|
|
||||||
self._top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _run_async_tasks(self, tasks) -> List[Chunk]:
|
async def _run_async_tasks(self, tasks) -> List[Chunk]:
|
||||||
"""Run async tasks."""
|
"""Run async tasks."""
|
||||||
@ -197,9 +213,9 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
return cast(List[Chunk], candidates)
|
return cast(List[Chunk], candidates)
|
||||||
|
|
||||||
async def _similarity_search_with_score(
|
async def _similarity_search_with_score(
|
||||||
self, query, score_threshold
|
self, query, score_threshold, filters: Optional[MetadataFilters] = None
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Similar search with score."""
|
"""Similar search with score."""
|
||||||
return self._vector_store_connector.similar_search_with_scores(
|
return self._vector_store_connector.similar_search_with_scores(
|
||||||
query, self._top_k, score_threshold
|
query, self._top_k, score_threshold, filters
|
||||||
)
|
)
|
||||||
|
@ -36,5 +36,4 @@ def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever
|
|||||||
|
|
||||||
retrieved_chunks = embedding_retriever._retrieve(query)
|
retrieved_chunks = embedding_retriever._retrieve(query)
|
||||||
|
|
||||||
mock_vector_store_connector.similar_search.assert_called_once_with(query, top_k)
|
|
||||||
assert len(retrieved_chunks) == top_k
|
assert len(retrieved_chunks) == top_k
|
||||||
|
@ -457,12 +457,12 @@ class MarkdownHeaderTextSplitter(TextSplitter):
|
|||||||
# Output line-by-line or aggregated into chunks w/ common headers
|
# Output line-by-line or aggregated into chunks w/ common headers
|
||||||
if headers_to_split_on is None:
|
if headers_to_split_on is None:
|
||||||
headers_to_split_on = [
|
headers_to_split_on = [
|
||||||
("#", "Header 1"),
|
("#", "Header1"),
|
||||||
("##", "Header 2"),
|
("##", "Header2"),
|
||||||
("###", "Header 3"),
|
("###", "Header3"),
|
||||||
("####", "Header 4"),
|
("####", "Header4"),
|
||||||
("#####", "Header 5"),
|
("#####", "Header5"),
|
||||||
("######", "Header 6"),
|
("######", "Header6"),
|
||||||
]
|
]
|
||||||
if filters is None:
|
if filters is None:
|
||||||
filters = []
|
filters = []
|
||||||
|
@ -7,6 +7,7 @@ from dbgpt.core import Chunk
|
|||||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||||
from dbgpt.rag.retriever.base import BaseRetriever
|
from dbgpt.rag.retriever.base import BaseRetriever
|
||||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -46,52 +47,81 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
).create()
|
).create()
|
||||||
|
|
||||||
def _retrieve(self, query: str) -> List[Chunk]:
|
def _retrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text.
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
candidates = self._vector_store_connector.similar_search(
|
candidates = self._vector_store_connector.similar_search(
|
||||||
doc=query, topk=self._top_k
|
doc=query, topk=self._top_k, filters=filters
|
||||||
)
|
)
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
def _retrieve_with_score(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks with score
|
List[Chunk]: list of chunks with score
|
||||||
"""
|
"""
|
||||||
candidates_with_score = self._vector_store_connector.similar_search_with_scores(
|
candidates_with_score = self._vector_store_connector.similar_search_with_scores(
|
||||||
doc=query, topk=self._top_k, score_threshold=score_threshold
|
doc=query,
|
||||||
|
topk=self._top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
filters=filters,
|
||||||
)
|
)
|
||||||
return candidates_with_score
|
return candidates_with_score
|
||||||
|
|
||||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
async def _aretrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks.
|
"""Retrieve knowledge chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text.
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
candidates = await blocking_func_to_async(self._executor, self._retrieve, query)
|
candidates = await blocking_func_to_async(
|
||||||
|
self._executor, self._retrieve, query, filters
|
||||||
|
)
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
async def _aretrieve_with_score(
|
async def _aretrieve_with_score(
|
||||||
self, query: str, score_threshold: float
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Retrieve knowledge chunks with score.
|
"""Retrieve knowledge chunks with score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query text
|
query (str): query text.
|
||||||
score_threshold (float): score threshold
|
score_threshold (float): score threshold.
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks with score
|
List[Chunk]: list of chunks with score.
|
||||||
"""
|
"""
|
||||||
candidates_with_score = await blocking_func_to_async(
|
candidates_with_score = await blocking_func_to_async(
|
||||||
self._executor, self._retrieve_with_score, query, score_threshold
|
self._executor, self._retrieve_with_score, query, score_threshold, filters
|
||||||
)
|
)
|
||||||
return candidates_with_score
|
return candidates_with_score
|
||||||
|
@ -4,11 +4,12 @@ import math
|
|||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from dbgpt._private.pydantic import BaseModel, Field
|
from dbgpt._private.pydantic import BaseModel, Field
|
||||||
from dbgpt.core import Chunk, Embeddings
|
from dbgpt.core import Chunk, Embeddings
|
||||||
from dbgpt.core.awel.flow import Parameter
|
from dbgpt.core.awel.flow import Parameter
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
from dbgpt.util.i18n_utils import _
|
from dbgpt.util.i18n_utils import _
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -176,13 +177,15 @@ class VectorStoreBase(ABC):
|
|||||||
return ids
|
return ids
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def similar_search(self, text: str, topk: int) -> List[Chunk]:
|
def similar_search(
|
||||||
|
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Similar search in vector database.
|
"""Similar search in vector database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text(str): The query text.
|
text(str): The query text.
|
||||||
topk(int): The number of similar documents to return.
|
topk(int): The number of similar documents to return.
|
||||||
|
filters(Optional[MetadataFilters]): metadata filters.
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: The similar documents.
|
List[Chunk]: The similar documents.
|
||||||
"""
|
"""
|
||||||
@ -190,7 +193,11 @@ class VectorStoreBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def similar_search_with_scores(
|
def similar_search_with_scores(
|
||||||
self, text, topk, score_threshold: float
|
self,
|
||||||
|
text,
|
||||||
|
topk,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Similar search with scores in vector database.
|
"""Similar search with scores in vector database.
|
||||||
|
|
||||||
@ -199,6 +206,7 @@ class VectorStoreBase(ABC):
|
|||||||
topk(int): The number of similar documents to return.
|
topk(int): The number of similar documents to return.
|
||||||
score_threshold(int): score_threshold: Optional, a floating point value
|
score_threshold(int): score_threshold: Optional, a floating point value
|
||||||
between 0 to 1
|
between 0 to 1
|
||||||
|
filters(Optional[MetadataFilters]): metadata filters.
|
||||||
Return:
|
Return:
|
||||||
List[Chunk]: The similar documents.
|
List[Chunk]: The similar documents.
|
||||||
"""
|
"""
|
||||||
@ -223,6 +231,15 @@ class VectorStoreBase(ABC):
|
|||||||
Args:
|
Args:
|
||||||
vector_name(str): The name of vector to delete.
|
vector_name(str): The name of vector to delete.
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def convert_metadata_filters(self, filters: MetadataFilters) -> Any:
|
||||||
|
"""Convert metadata filters to vector store filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def _normalization_vectors(self, vectors):
|
def _normalization_vectors(self, vectors):
|
||||||
"""Return L2-normalization vectors to scale[0,1].
|
"""Return L2-normalization vectors to scale[0,1].
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Chroma vector store."""
|
"""Chroma vector store."""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, List
|
from typing import List, Optional
|
||||||
|
|
||||||
from chromadb import PersistentClient
|
from chromadb import PersistentClient
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
@ -13,6 +13,7 @@ from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
|||||||
from dbgpt.util.i18n_utils import _
|
from dbgpt.util.i18n_utils import _
|
||||||
|
|
||||||
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
|
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
|
||||||
|
from .filters import FilterOperator, MetadataFilters
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -86,16 +87,23 @@ class ChromaStore(VectorStoreBase):
|
|||||||
collection_metadata=collection_metadata,
|
collection_metadata=collection_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]:
|
def similar_search(
|
||||||
|
self, text, topk, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Search similar documents."""
|
"""Search similar documents."""
|
||||||
logger.info("ChromaStore similar search")
|
logger.info("ChromaStore similar search")
|
||||||
lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs)
|
where_filters = self.convert_metadata_filters(filters) if filters else None
|
||||||
|
lc_documents = self.vector_store_client.similarity_search(
|
||||||
|
text, topk, filter=where_filters
|
||||||
|
)
|
||||||
return [
|
return [
|
||||||
Chunk(content=doc.page_content, metadata=doc.metadata)
|
Chunk(content=doc.page_content, metadata=doc.metadata)
|
||||||
for doc in lc_documents
|
for doc in lc_documents
|
||||||
]
|
]
|
||||||
|
|
||||||
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
|
def similar_search_with_scores(
|
||||||
|
self, text, topk, score_threshold, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Search similar documents with scores.
|
"""Search similar documents with scores.
|
||||||
|
|
||||||
Chroma similar_search_with_score.
|
Chroma similar_search_with_score.
|
||||||
@ -106,11 +114,16 @@ class ChromaStore(VectorStoreBase):
|
|||||||
score_threshold(float): score_threshold: Optional, a floating point value
|
score_threshold(float): score_threshold: Optional, a floating point value
|
||||||
between 0 to 1 to filter the resulting set of retrieved docs,0 is
|
between 0 to 1 to filter the resulting set of retrieved docs,0 is
|
||||||
dissimilar, 1 is most similar.
|
dissimilar, 1 is most similar.
|
||||||
|
filters(MetadataFilters): metadata filters, defaults to None
|
||||||
"""
|
"""
|
||||||
logger.info("ChromaStore similar search with scores")
|
logger.info("ChromaStore similar search with scores")
|
||||||
|
where_filters = self.convert_metadata_filters(filters) if filters else None
|
||||||
docs_and_scores = (
|
docs_and_scores = (
|
||||||
self.vector_store_client.similarity_search_with_relevance_scores(
|
self.vector_store_client.similarity_search_with_relevance_scores(
|
||||||
query=text, k=topk, score_threshold=score_threshold
|
query=text,
|
||||||
|
k=topk,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
filter=where_filters,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
@ -152,10 +165,71 @@ class ChromaStore(VectorStoreBase):
|
|||||||
collection = self.vector_store_client._collection
|
collection = self.vector_store_client._collection
|
||||||
collection.delete(ids=ids)
|
collection.delete(ids=ids)
|
||||||
|
|
||||||
|
def convert_metadata_filters(
|
||||||
|
self,
|
||||||
|
filters: MetadataFilters,
|
||||||
|
) -> dict:
|
||||||
|
"""Convert metadata filters to Chroma filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filters(MetadataFilters): metadata filters.
|
||||||
|
Returns:
|
||||||
|
dict: Chroma filters.
|
||||||
|
"""
|
||||||
|
where_filters = {}
|
||||||
|
filters_list = []
|
||||||
|
condition = filters.condition
|
||||||
|
chroma_condition = f"${condition}"
|
||||||
|
if filters.filters:
|
||||||
|
for filter in filters.filters:
|
||||||
|
if filter.operator:
|
||||||
|
filters_list.append(
|
||||||
|
{
|
||||||
|
filter.key: {
|
||||||
|
_convert_chroma_filter_operator(
|
||||||
|
filter.operator
|
||||||
|
): filter.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
filters_list.append({filter.key: filter.value}) # type: ignore
|
||||||
|
|
||||||
|
if len(filters_list) == 1:
|
||||||
|
return filters_list[0]
|
||||||
|
elif len(filters_list) > 1:
|
||||||
|
where_filters[chroma_condition] = filters_list
|
||||||
|
return where_filters
|
||||||
|
|
||||||
def _clean_persist_folder(self):
|
def _clean_persist_folder(self):
|
||||||
|
"""Clean persist folder."""
|
||||||
for root, dirs, files in os.walk(self.persist_dir, topdown=False):
|
for root, dirs, files in os.walk(self.persist_dir, topdown=False):
|
||||||
for name in files:
|
for name in files:
|
||||||
os.remove(os.path.join(root, name))
|
os.remove(os.path.join(root, name))
|
||||||
for name in dirs:
|
for name in dirs:
|
||||||
os.rmdir(os.path.join(root, name))
|
os.rmdir(os.path.join(root, name))
|
||||||
os.rmdir(self.persist_dir)
|
os.rmdir(self.persist_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_chroma_filter_operator(operator: str) -> str:
|
||||||
|
"""Convert operator to Chroma where operator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operator(str): operator.
|
||||||
|
Returns:
|
||||||
|
str: Chroma where operator.
|
||||||
|
"""
|
||||||
|
if operator == FilterOperator.EQ:
|
||||||
|
return "$eq"
|
||||||
|
elif operator == FilterOperator.NE:
|
||||||
|
return "$ne"
|
||||||
|
elif operator == FilterOperator.GT:
|
||||||
|
return "$gt"
|
||||||
|
elif operator == FilterOperator.LT:
|
||||||
|
return "$lt"
|
||||||
|
elif operator == FilterOperator.GTE:
|
||||||
|
return "$gte"
|
||||||
|
elif operator == FilterOperator.LTE:
|
||||||
|
return "$lte"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Chroma Where operator {operator} not supported")
|
||||||
|
@ -13,6 +13,7 @@ from dbgpt.core.awel.flow import (
|
|||||||
)
|
)
|
||||||
from dbgpt.storage import vector_store
|
from dbgpt.storage import vector_store
|
||||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
from dbgpt.util.i18n_utils import _
|
from dbgpt.util.i18n_utils import _
|
||||||
|
|
||||||
connector: Dict[str, Type] = {}
|
connector: Dict[str, Type] = {}
|
||||||
@ -128,23 +129,29 @@ class VectorStoreConnector:
|
|||||||
max_threads,
|
max_threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
|
def similar_search(
|
||||||
|
self, doc: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Similar search in vector database.
|
"""Similar search in vector database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- doc: query text
|
- doc: query text
|
||||||
- topk: topk
|
- topk: topk
|
||||||
|
- filters: metadata filters.
|
||||||
Return:
|
Return:
|
||||||
- chunks: chunks.
|
- chunks: chunks.
|
||||||
"""
|
"""
|
||||||
return self.client.similar_search(doc, topk)
|
return self.client.similar_search(doc, topk, filters)
|
||||||
|
|
||||||
def similar_search_with_scores(
|
def similar_search_with_scores(
|
||||||
self, doc: str, topk: int, score_threshold: float
|
self,
|
||||||
|
doc: str,
|
||||||
|
topk: int,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Similar search with scores in vector database.
|
"""Similar_search_with_score in vector database.
|
||||||
|
|
||||||
similar_search_with_score in vector database..
|
|
||||||
Return docs and relevance scores in the range [0, 1].
|
Return docs and relevance scores in the range [0, 1].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -153,10 +160,13 @@ class VectorStoreConnector:
|
|||||||
score_threshold(float): score_threshold: Optional, a floating point value
|
score_threshold(float): score_threshold: Optional, a floating point value
|
||||||
between 0 to 1 to filter the resulting set of retrieved docs,0 is
|
between 0 to 1 to filter the resulting set of retrieved docs,0 is
|
||||||
dissimilar, 1 is most similar.
|
dissimilar, 1 is most similar.
|
||||||
|
filters: metadata filters.
|
||||||
Return:
|
Return:
|
||||||
- chunks: chunks.
|
- chunks: Return docs and relevance scores in the range [0, 1].
|
||||||
"""
|
"""
|
||||||
return self.client.similar_search_with_scores(doc, topk, score_threshold)
|
return self.client.similar_search_with_scores(
|
||||||
|
doc, topk, score_threshold, filters
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vector_store_config(self) -> VectorStoreConfig:
|
def vector_store_config(self) -> VectorStoreConfig:
|
||||||
|
56
dbgpt/storage/vector_store/filters.py
Normal file
56
dbgpt/storage/vector_store/filters.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
"""Vector Store Meta data filters."""
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class FilterOperator(str, Enum):
|
||||||
|
"""Meta data filter operator."""
|
||||||
|
|
||||||
|
EQ = "=="
|
||||||
|
GT = ">"
|
||||||
|
LT = "<"
|
||||||
|
NE = "!="
|
||||||
|
GTE = ">="
|
||||||
|
LTE = "<="
|
||||||
|
IN = "in"
|
||||||
|
NIN = "nin"
|
||||||
|
EXISTS = "exists"
|
||||||
|
|
||||||
|
|
||||||
|
class FilterCondition(str, Enum):
|
||||||
|
"""Vector Store Meta data filter conditions."""
|
||||||
|
|
||||||
|
AND = "and"
|
||||||
|
OR = "or"
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilter(BaseModel):
|
||||||
|
"""Meta data filter."""
|
||||||
|
|
||||||
|
key: str = Field(
|
||||||
|
...,
|
||||||
|
description="The key of metadata to filter.",
|
||||||
|
)
|
||||||
|
operator: FilterOperator = Field(
|
||||||
|
default=FilterOperator.EQ,
|
||||||
|
description="The operator of metadata filter.",
|
||||||
|
)
|
||||||
|
value: Union[str, int, float, List[str], List[int], List[float]] = Field(
|
||||||
|
...,
|
||||||
|
description="The value of metadata to filter.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilters(BaseModel):
|
||||||
|
"""Meta data filters."""
|
||||||
|
|
||||||
|
condition: FilterCondition = Field(
|
||||||
|
default=FilterCondition.AND,
|
||||||
|
description="The condition of metadata filters.",
|
||||||
|
)
|
||||||
|
filters: List[MetadataFilter] = Field(
|
||||||
|
...,
|
||||||
|
description="The metadata filters.",
|
||||||
|
)
|
@ -14,6 +14,7 @@ from dbgpt.storage.vector_store.base import (
|
|||||||
VectorStoreBase,
|
VectorStoreBase,
|
||||||
VectorStoreConfig,
|
VectorStoreConfig,
|
||||||
)
|
)
|
||||||
|
from dbgpt.storage.vector_store.filters import FilterOperator, MetadataFilters
|
||||||
from dbgpt.util import string_utils
|
from dbgpt.util import string_utils
|
||||||
from dbgpt.util.i18n_utils import _
|
from dbgpt.util.i18n_utils import _
|
||||||
|
|
||||||
@ -206,6 +207,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
|
self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
|
||||||
self.text_field = milvus_vector_config.get("text_field") or "content"
|
self.text_field = milvus_vector_config.get("text_field") or "content"
|
||||||
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
|
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
|
||||||
|
self.props_field = milvus_vector_config.get("props_field") or "props_field"
|
||||||
|
|
||||||
if (self.username is None) != (self.password is None):
|
if (self.username is None) != (self.password is None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -284,6 +286,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
vector_field = self.vector_field
|
vector_field = self.vector_field
|
||||||
text_field = self.text_field
|
text_field = self.text_field
|
||||||
metadata_field = self.metadata_field
|
metadata_field = self.metadata_field
|
||||||
|
props_field = self.props_field
|
||||||
# self.text_field = text_field
|
# self.text_field = text_field
|
||||||
collection_name = vector_name
|
collection_name = vector_name
|
||||||
fields = []
|
fields = []
|
||||||
@ -300,6 +303,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
||||||
|
|
||||||
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
|
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
|
||||||
|
fields.append(FieldSchema(props_field, DataType.JSON))
|
||||||
schema = CollectionSchema(fields)
|
schema = CollectionSchema(fields)
|
||||||
# Create the collection
|
# Create the collection
|
||||||
collection = Collection(collection_name, schema)
|
collection = Collection(collection_name, schema)
|
||||||
@ -346,6 +350,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
for d in metadatas:
|
for d in metadatas:
|
||||||
# for key, value in d.items():
|
# for key, value in d.items():
|
||||||
insert_dict.setdefault("metadata", []).append(json.dumps(d))
|
insert_dict.setdefault("metadata", []).append(json.dumps(d))
|
||||||
|
insert_dict.setdefault("props_field", []).append(d)
|
||||||
# Convert dict to list of lists for insertion
|
# Convert dict to list of lists for insertion
|
||||||
insert_list = [insert_dict[x] for x in self.fields]
|
insert_list = [insert_dict[x] for x in self.fields]
|
||||||
# Insert into the collection.
|
# Insert into the collection.
|
||||||
@ -368,7 +373,9 @@ class MilvusStore(VectorStoreBase):
|
|||||||
doc_ids = [str(doc_id) for doc_id in doc_ids]
|
doc_ids = [str(doc_id) for doc_id in doc_ids]
|
||||||
return doc_ids
|
return doc_ids
|
||||||
|
|
||||||
def similar_search(self, text, topk) -> List[Chunk]:
|
def similar_search(
|
||||||
|
self, text, topk, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Perform a search on a query string and return results."""
|
"""Perform a search on a query string and return results."""
|
||||||
from pymilvus import Collection, DataType
|
from pymilvus import Collection, DataType
|
||||||
|
|
||||||
@ -383,7 +390,9 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.primary_field = x.name
|
self.primary_field = x.name
|
||||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||||
self.vector_field = x.name
|
self.vector_field = x.name
|
||||||
_, docs_and_scores = self._search(text, topk)
|
# convert to milvus expr filter.
|
||||||
|
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
|
||||||
|
_, docs_and_scores = self._search(text, topk, expr=milvus_filter_expr)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
Chunk(
|
Chunk(
|
||||||
@ -393,7 +402,13 @@ class MilvusStore(VectorStoreBase):
|
|||||||
for doc, _, _ in docs_and_scores
|
for doc, _, _ in docs_and_scores
|
||||||
]
|
]
|
||||||
|
|
||||||
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
|
def similar_search_with_scores(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
topk: int,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Perform a search on a query string and return results with score.
|
"""Perform a search on a query string and return results with score.
|
||||||
|
|
||||||
For more information about the search parameters, take a look at the pymilvus
|
For more information about the search parameters, take a look at the pymilvus
|
||||||
@ -401,15 +416,10 @@ class MilvusStore(VectorStoreBase):
|
|||||||
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding (List[float]): The embedding vector being searched.
|
text (str): The query text.
|
||||||
k (int, optional): The amount of results to return. Defaults to 4.
|
topk (int): The number of similar documents to return.
|
||||||
param (dict): The search params for the specified index.
|
score_threshold (float): Optional, a floating point value between 0 to 1.
|
||||||
Defaults to None.
|
filters (Optional[MetadataFilters]): Optional, metadata filters.
|
||||||
expr (str, optional): Filtering expression. Defaults to None.
|
|
||||||
timeout (int, optional): How long to wait before timeout error.
|
|
||||||
Defaults to None.
|
|
||||||
kwargs: Collection.search() keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[Document, float]]: Result doc and score.
|
List[Tuple[Document, float]]: Result doc and score.
|
||||||
"""
|
"""
|
||||||
@ -427,7 +437,11 @@ class MilvusStore(VectorStoreBase):
|
|||||||
|
|
||||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||||
self.vector_field = x.name
|
self.vector_field = x.name
|
||||||
_, docs_and_scores = self._search(text, topk)
|
# convert to milvus expr filter.
|
||||||
|
milvus_filter_expr = self.convert_metadata_filters(filters) if filters else None
|
||||||
|
_, docs_and_scores = self._search(
|
||||||
|
query=text, topk=topk, expr=milvus_filter_expr
|
||||||
|
)
|
||||||
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
|
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"similarity score need between" f" 0 and 1, got {docs_and_scores}"
|
"similarity score need between" f" 0 and 1, got {docs_and_scores}"
|
||||||
@ -462,6 +476,20 @@ class MilvusStore(VectorStoreBase):
|
|||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
|
"""Search in vector database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: query text.
|
||||||
|
k: topk.
|
||||||
|
param: search params.
|
||||||
|
expr: search expr.
|
||||||
|
partition_names: partition names.
|
||||||
|
round_decimal: round decimal.
|
||||||
|
timeout: timeout.
|
||||||
|
**kwargs: kwargs.
|
||||||
|
Returns:
|
||||||
|
Tuple[Document, float, int]: Result doc and score.
|
||||||
|
"""
|
||||||
self.col.load()
|
self.col.load()
|
||||||
# use default index params.
|
# use default index params.
|
||||||
if param is None:
|
if param is None:
|
||||||
@ -495,7 +523,9 @@ class MilvusStore(VectorStoreBase):
|
|||||||
result.id,
|
result.id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if len(ret) == 0:
|
||||||
|
logger.warning("No relevant docs were retrieved.")
|
||||||
|
return None, []
|
||||||
return ret[0], ret
|
return ret[0], ret
|
||||||
|
|
||||||
def vector_name_exists(self):
|
def vector_name_exists(self):
|
||||||
@ -523,6 +553,40 @@ class MilvusStore(VectorStoreBase):
|
|||||||
logger.info(f"begin delete milvus ids: {ids}")
|
logger.info(f"begin delete milvus ids: {ids}")
|
||||||
delete_ids = ids.split(",")
|
delete_ids = ids.split(",")
|
||||||
doc_ids = [int(doc_id) for doc_id in delete_ids]
|
doc_ids = [int(doc_id) for doc_id in delete_ids]
|
||||||
delet_expr = f"{self.primary_field} in {doc_ids}"
|
delete_expr = f"{self.primary_field} in {doc_ids}"
|
||||||
self.col.delete(delet_expr)
|
self.col.delete(delete_expr)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def convert_metadata_filters(self, filters: MetadataFilters) -> str:
|
||||||
|
"""Convert filter to milvus filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- filters: metadata filters.
|
||||||
|
Returns:
|
||||||
|
- metadata_filters: metadata filters.
|
||||||
|
"""
|
||||||
|
metadata_filters = []
|
||||||
|
for metadata_filter in filters.filters:
|
||||||
|
if isinstance(metadata_filter.value, str):
|
||||||
|
expr = (
|
||||||
|
f"{self.props_field}['{metadata_filter.key}'] "
|
||||||
|
f"{FilterOperator.EQ} '{metadata_filter.value}'"
|
||||||
|
)
|
||||||
|
metadata_filters.append(expr)
|
||||||
|
elif isinstance(metadata_filter.value, List):
|
||||||
|
expr = (
|
||||||
|
f"{self.props_field}['{metadata_filter.key}'] "
|
||||||
|
f"{FilterOperator.IN} {metadata_filter.value}"
|
||||||
|
)
|
||||||
|
metadata_filters.append(expr)
|
||||||
|
else:
|
||||||
|
expr = (
|
||||||
|
f"{self.props_field}['{metadata_filter.key}'] "
|
||||||
|
f"{FilterOperator.EQ} {str(metadata_filter.value)}"
|
||||||
|
)
|
||||||
|
metadata_filters.append(expr)
|
||||||
|
if len(metadata_filters) > 1:
|
||||||
|
metadata_filter_expr = f" {filters.condition} ".join(metadata_filters)
|
||||||
|
else:
|
||||||
|
metadata_filter_expr = metadata_filters[0]
|
||||||
|
return metadata_filter_expr
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Postgres vector store."""
|
"""Postgres vector store."""
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, List
|
from typing import List, Optional
|
||||||
|
|
||||||
from dbgpt._private.pydantic import Field
|
from dbgpt._private.pydantic import Field
|
||||||
from dbgpt.core import Chunk
|
from dbgpt.core import Chunk
|
||||||
@ -10,6 +10,7 @@ from dbgpt.storage.vector_store.base import (
|
|||||||
VectorStoreBase,
|
VectorStoreBase,
|
||||||
VectorStoreConfig,
|
VectorStoreConfig,
|
||||||
)
|
)
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
from dbgpt.util.i18n_utils import _
|
from dbgpt.util.i18n_utils import _
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -70,9 +71,11 @@ class PGVectorStore(VectorStoreBase):
|
|||||||
connection_string=self.connection_string,
|
connection_string=self.connection_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
def similar_search(self, text: str, topk: int, **kwargs: Any) -> List[Chunk]:
|
def similar_search(
|
||||||
|
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Perform similar search in PGVector."""
|
"""Perform similar search in PGVector."""
|
||||||
return self.vector_store_client.similarity_search(text, topk)
|
return self.vector_store_client.similarity_search(text, topk, filters)
|
||||||
|
|
||||||
def vector_name_exists(self) -> bool:
|
def vector_name_exists(self) -> bool:
|
||||||
"""Check if vector name exists."""
|
"""Check if vector name exists."""
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Weaviate vector store."""
|
"""Weaviate vector store."""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from dbgpt._private.pydantic import Field
|
from dbgpt._private.pydantic import Field
|
||||||
from dbgpt.core import Chunk
|
from dbgpt.core import Chunk
|
||||||
@ -9,6 +9,7 @@ from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
|||||||
from dbgpt.util.i18n_utils import _
|
from dbgpt.util.i18n_utils import _
|
||||||
|
|
||||||
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
|
from .base import _COMMON_PARAMETERS, VectorStoreBase, VectorStoreConfig
|
||||||
|
from .filters import MetadataFilters
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -80,7 +81,9 @@ class WeaviateStore(VectorStoreBase):
|
|||||||
|
|
||||||
self.vector_store_client = weaviate.Client(self.weaviate_url)
|
self.vector_store_client = weaviate.Client(self.weaviate_url)
|
||||||
|
|
||||||
def similar_search(self, text: str, topk: int) -> List[Chunk]:
|
def similar_search(
|
||||||
|
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
"""Perform similar search in Weaviate."""
|
"""Perform similar search in Weaviate."""
|
||||||
logger.info("Weaviate similar search")
|
logger.info("Weaviate similar search")
|
||||||
# nearText = {
|
# nearText = {
|
||||||
|
60
examples/rag/metadata_filter_example.py
Normal file
60
examples/rag/metadata_filter_example.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
"""Rag Metadata Properties filter example.
|
||||||
|
pre-requirements:
|
||||||
|
make sure you have set your embedding model path in your example code.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
..code-block:: shell
|
||||||
|
python examples/rag/metadata_filter_example.py
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
|
||||||
|
from dbgpt.rag import ChunkParameters
|
||||||
|
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||||
|
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||||
|
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||||
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
|
||||||
|
|
||||||
|
|
||||||
|
def _create_vector_connector():
|
||||||
|
"""Create vector connector."""
|
||||||
|
return VectorStoreConnector.from_default(
|
||||||
|
"Chroma",
|
||||||
|
vector_store_config=ChromaVectorConfig(
|
||||||
|
name="example_metadata_filter_name",
|
||||||
|
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||||
|
),
|
||||||
|
embedding_fn=DefaultEmbeddingFactory(
|
||||||
|
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||||
|
).create(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
|
||||||
|
knowledge = KnowledgeFactory.from_file_path(file_path)
|
||||||
|
vector_connector = _create_vector_connector()
|
||||||
|
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")
|
||||||
|
# get embedding assembler
|
||||||
|
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||||
|
knowledge=knowledge,
|
||||||
|
chunk_parameters=chunk_parameters,
|
||||||
|
vector_store_connector=vector_connector,
|
||||||
|
)
|
||||||
|
assembler.persist()
|
||||||
|
# get embeddings retriever
|
||||||
|
retriever = assembler.as_retriever(3)
|
||||||
|
# create metadata filter
|
||||||
|
metadata_filter = MetadataFilter(key="Header2", value="AWEL Design")
|
||||||
|
filters = MetadataFilters(filters=[metadata_filter])
|
||||||
|
chunks = await retriever.aretrieve_with_scores(
|
||||||
|
"what is awel talk about", 0.0, filters
|
||||||
|
)
|
||||||
|
print(f"embedding rag example results:{chunks}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
Loading…
Reference in New Issue
Block a user