Add dashvector self query retriever (#9684)

## Description
Add `DashVector` self-query retriever support: a `DashvectorTranslator` that lets `SelfQueryRetriever` work with the `DashVector` vector store.

## How to use
```python
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.vectorstores.dashvector import DashVector

# `docs`, `embeddings`, `llm`, `document_content_description`, and
# `metadata_field_info` are assumed to be defined (see the sketch below).
vectorstore = DashVector.from_documents(docs, embeddings)
retriever = SelfQueryRetriever.from_llm(
    llm, vectorstore, document_content_description, metadata_field_info, verbose=True
)
```
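
For completeness, here is one way the placeholders above (`docs`, `embeddings`, `llm`, `metadata_field_info`, `document_content_description`) might be set up, mirroring the other self-query examples; the movie documents and field names are purely illustrative and not part of this PR:

```python
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document

# Illustrative documents with metadata the self-query retriever can filter on.
docs = [
    Document(
        page_content="A bunch of scientists bring back dinosaurs and mayhem breaks loose",
        metadata={"year": 1993, "genre": "science fiction"},
    ),
]
metadata_field_info = [
    AttributeInfo(name="genre", description="The genre of the movie", type="string"),
    AttributeInfo(name="year", description="The year the movie was released", type="integer"),
]
document_content_description = "Brief summary of a movie"

embeddings = OpenAIEmbeddings()
llm = ChatOpenAI(temperature=0)

# Natural-language queries are then translated into DashVector filters, e.g.:
# retriever.get_relevant_documents("What are some movies about dinosaurs from before 2000?")
```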

---------

Co-authored-by: smallrain.xuxy <smallrain.xuxy@alibaba-inc.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Author: Xiaoyu Xee
Date: 2023-09-04 11:51:04 +08:00
Committed by: GitHub
Parent: 056e59672b
Commit: 9bcfd58580
6 changed files with 577 additions and 2 deletions


@@ -9,6 +9,7 @@ from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
@@ -19,6 +20,7 @@ from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.vectorstores import (
    Chroma,
    DashVector,
    DeepLake,
    ElasticsearchStore,
    MyScale,
@@ -35,6 +37,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
    Pinecone: PineconeTranslator,
    Chroma: ChromaTranslator,
    DashVector: DashvectorTranslator,
    Weaviate: WeaviateTranslator,
    Qdrant: QdrantTranslator,
    MyScale: MyScaleTranslator,
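
Registering `DashVector` in `BUILTIN_TRANSLATORS` is what lets `SelfQueryRetriever.from_llm` pick the right translator automatically. A minimal sketch of what that lookup amounts to, assuming the `vectorstore` from the usage example above and using the private helper shown in this hunk:

```python
from langchain.retrievers.self_query.base import _get_builtin_translator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator

# `vectorstore` is the DashVector instance from the "How to use" example.
translator = _get_builtin_translator(vectorstore)
assert isinstance(translator, DashvectorTranslator)
```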


@@ -0,0 +1,64 @@
"""Logic for converting internal query language to a valid DashVector query."""
from typing import Tuple, Union
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class DashvectorTranslator(Visitor):
"""Logic for converting internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR]
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.LIKE,
]
map_dict = {
Operator.AND: " AND ",
Operator.OR: " OR ",
Comparator.EQ: " = ",
Comparator.GT: " > ",
Comparator.GTE: " >= ",
Comparator.LT: " < ",
Comparator.LTE: " <= ",
Comparator.LIKE: " LIKE ",
}
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return self.map_dict[func]
def visit_operation(self, operation: Operation) -> str:
args = [arg.accept(self) for arg in operation.arguments]
return self._format_func(operation.operator).join(args)
def visit_comparison(self, comparison: Comparison) -> str:
value = comparison.value
if isinstance(value, str):
if comparison.comparator == Comparator.LIKE:
value = f"'%{value}%'"
else:
value = f"'{value}'"
return (
f"{comparison.attribute}{self._format_func(comparison.comparator)}{value}"
)
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
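
To make the translation concrete, a small sketch of what the new translator emits for a structured query; the attribute names and values here are illustrative:

```python
from langchain.chains.query_constructor.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)
from langchain.retrievers.self_query.dashvector import DashvectorTranslator

translator = DashvectorTranslator()
structured_query = StructuredQuery(
    query="dinosaur movies",
    filter=Operation(
        operator=Operator.AND,
        arguments=[
            Comparison(
                comparator=Comparator.EQ, attribute="genre", value="science fiction"
            ),
            Comparison(comparator=Comparator.LT, attribute="year", value=2000),
        ],
    ),
    limit=None,
)

query, kwargs = translator.visit_structured_query(structured_query)
# query  == "dinosaur movies"
# kwargs == {"filter": "genre = 'science fiction' AND year < 2000"}
```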


@@ -1799,6 +1799,26 @@ files = [
{file = "cssselect-1.2.0.tar.gz", hash = "sha256:666b19839cfaddb9ce9d36bfe4c969132c647b92fc9088c4e23f786b30f1b3dc"},
]
[[package]]
name = "dashvector"
version = "1.0.1"
description = "DashVector Client Python Sdk Library"
category = "main"
optional = true
python-versions = ">=3.7.0"
files = [
{file = "dashvector-1.0.1-py3-none-any.whl", hash = "sha256:e2fc362c65979d55cf605fb90deca4a292c69e1c2101df22430c80db744591ad"},
]
[package.dependencies]
aiohttp = ">=3.1.0"
grpcio = [
{version = ">=1.22.0", markers = "python_version < \"3.11\""},
{version = ">=1.49.1", markers = "python_version >= \"3.11\""},
]
numpy = "*"
protobuf = ">=3.8.0,<4.0.0"
[[package]]
name = "dataclasses-json"
version = "0.5.9"
@@ -10897,7 +10917,7 @@ clarifai = ["clarifai"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["amazon-textract-caller", "assemblyai", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "jq", "pdfminer-six", "pgvector", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "mwparserfromhell", "mwxml", "pandas", "telethon", "psychicapi", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "streamlit", "pyspark", "openai", "sympy", "rapidfuzz", "openai", "rank-bm25", "geopandas", "jinja2", "gitpython", "newspaper3k", "feedparser", "xata", "xmltodict", "faiss-cpu", "openapi-schema-pydantic", "markdownify", "sqlite-vss"]
extended-testing = ["amazon-textract-caller", "assemblyai", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "jq", "pdfminer-six", "pgvector", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "mwparserfromhell", "mwxml", "pandas", "telethon", "psychicapi", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "streamlit", "pyspark", "openai", "sympy", "rapidfuzz", "openai", "rank-bm25", "geopandas", "jinja2", "gitpython", "newspaper3k", "feedparser", "xata", "xmltodict", "faiss-cpu", "openapi-schema-pydantic", "markdownify", "dashvector", "sqlite-vss"]
javascript = ["esprima"]
llms = ["clarifai", "cohere", "openai", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
openai = ["openai", "tiktoken"]
@@ -10907,4 +10927,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "47e048f7708139d5e5040c6d56ef4cb66153c3052a9237d6ea42eeb2565ad470"
content-hash = "b63078268a80c07577b432114302f4f86d47be25b83a245affb0dbc999fb2c1f"


@@ -127,6 +127,7 @@ xata = {version = "^1.0.0a7", optional = true}
xmltodict = {version = "^0.13.0", optional = true}
markdownify = {version = "^0.11.6", optional = true}
assemblyai = {version = "^0.17.0", optional = true}
dashvector = {version = "^1.0.1", optional = true}
sqlite-vss = {version = "^0.1.2", optional = true}
@@ -342,6 +343,7 @@ extended_testing = [
"faiss-cpu",
"openapi-schema-pydantic",
"markdownify",
"dashvector",
"sqlite-vss",
]
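
`dashvector` is registered only as an optional / extended-testing dependency, so user code has to install it separately (e.g. `pip install dashvector`). A minimal sketch of the usual optional-import guard, shown here only to illustrate why the extra exists; it is not code from this PR:

```python
# Illustrative guard: fail with a helpful message if the optional SDK is missing.
try:
    import dashvector  # noqa: F401
except ImportError:
    raise ImportError(
        "Could not import dashvector python package. "
        "Please install it with `pip install dashvector`."
    )
```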


@@ -0,0 +1,52 @@
from typing import Any, Tuple

import pytest

from langchain.chains.query_constructor.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
)
from langchain.retrievers.self_query.dashvector import DashvectorTranslator

DEFAULT_TRANSLATOR = DashvectorTranslator()


@pytest.mark.parametrize(
    "triplet",
    [
        (Comparator.EQ, 2, "foo = 2"),
        (Comparator.LT, 2, "foo < 2"),
        (Comparator.LTE, 2, "foo <= 2"),
        (Comparator.GT, 2, "foo > 2"),
        (Comparator.GTE, 2, "foo >= 2"),
        (Comparator.LIKE, "bar", "foo LIKE '%bar%'"),
    ],
)
def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None:
    comparator, value, expected = triplet
    actual = DEFAULT_TRANSLATOR.visit_comparison(
        Comparison(comparator=comparator, attribute="foo", value=value)
    )
    assert expected == actual


@pytest.mark.parametrize(
    "triplet",
    [
        (Operator.AND, "foo < 2 AND bar = 'baz'"),
        (Operator.OR, "foo < 2 OR bar = 'baz'"),
    ],
)
def test_visit_operation(triplet: Tuple[Operator, str]) -> None:
    operator, expected = triplet
    op = Operation(
        operator=operator,
        arguments=[
            Comparison(comparator=Comparator.LT, attribute="foo", value=2),
            Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
        ],
    )
    actual = DEFAULT_TRANSLATOR.visit_operation(op)
    assert expected == actual
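
The tests above cover `visit_comparison` and `visit_operation`; a sketch of a companion test for `visit_structured_query` could look like this (not part of this PR, and it assumes `StructuredQuery` is added to the imports):

```python
from langchain.chains.query_constructor.ir import StructuredQuery


def test_visit_structured_query() -> None:
    query = "What are movies about dinosaurs?"

    # Without a filter, only the raw query is passed through.
    structured_query = StructuredQuery(query=query, filter=None, limit=None)
    assert DEFAULT_TRANSLATOR.visit_structured_query(structured_query) == (query, {})

    # With a filter, the comparison is rendered as a DashVector filter string.
    comp = Comparison(comparator=Comparator.LT, attribute="year", value=2000)
    structured_query = StructuredQuery(query=query, filter=comp, limit=None)
    assert DEFAULT_TRANSLATOR.visit_structured_query(structured_query) == (
        query,
        {"filter": "year < 2000"},
    )
```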