mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
langchain[patch]: support more comparators in Milvus self-querying retriever (#16076)
- **Description:** Support IN and LIKE comparators in Milvus self-querying retriever, based on [Boolean Expression Rules](https://milvus.io/docs/boolean.md) - **Issue:** No - **Dependencies:** No - **Twitter handle:** No Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
parent
9c2f1f07a0
commit
8597484195
@ -16,24 +16,32 @@ COMPARATOR_TO_BER = {
|
|||||||
Comparator.GTE: ">=",
|
Comparator.GTE: ">=",
|
||||||
Comparator.LT: "<",
|
Comparator.LT: "<",
|
||||||
Comparator.LTE: "<=",
|
Comparator.LTE: "<=",
|
||||||
|
Comparator.IN: "in",
|
||||||
|
Comparator.LIKE: "like",
|
||||||
}
|
}
|
||||||
|
|
||||||
UNARY_OPERATORS = [Operator.NOT]
|
UNARY_OPERATORS = [Operator.NOT]
|
||||||
|
|
||||||
|
|
||||||
def process_value(value: Union[int, float, str]) -> str:
|
def process_value(value: Union[int, float, str], comparator: Comparator) -> str:
|
||||||
"""Convert a value to a string and add double quotes if it is a string.
|
"""Convert a value to a string and add double quotes if it is a string.
|
||||||
|
|
||||||
It required for comparators involving strings.
|
It required for comparators involving strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value: The value to convert.
|
value: The value to convert.
|
||||||
|
comparator: The comparator.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The converted value as a string.
|
The converted value as a string.
|
||||||
"""
|
"""
|
||||||
#
|
#
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
|
if comparator is Comparator.LIKE:
|
||||||
|
# If the comparator is LIKE, add a percent sign after it for prefix matching
|
||||||
|
# and add double quotes
|
||||||
|
return f'"{value}%"'
|
||||||
|
else:
|
||||||
# If the value is already a string, add double quotes
|
# If the value is already a string, add double quotes
|
||||||
return f'"{value}"'
|
return f'"{value}"'
|
||||||
else:
|
else:
|
||||||
@ -54,6 +62,8 @@ class MilvusTranslator(Visitor):
|
|||||||
Comparator.GTE,
|
Comparator.GTE,
|
||||||
Comparator.LT,
|
Comparator.LT,
|
||||||
Comparator.LTE,
|
Comparator.LTE,
|
||||||
|
Comparator.IN,
|
||||||
|
Comparator.LIKE,
|
||||||
]
|
]
|
||||||
|
|
||||||
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||||
@ -78,7 +88,7 @@ class MilvusTranslator(Visitor):
|
|||||||
|
|
||||||
def visit_comparison(self, comparison: Comparison) -> str:
|
def visit_comparison(self, comparison: Comparison) -> str:
|
||||||
comparator = self._format_func(comparison.comparator)
|
comparator = self._format_func(comparison.comparator)
|
||||||
processed_value = process_value(comparison.value)
|
processed_value = process_value(comparison.value, comparison.comparator)
|
||||||
attribute = comparison.attribute
|
attribute = comparison.attribute
|
||||||
|
|
||||||
return "( " + attribute + " " + comparator + " " + processed_value + " )"
|
return "( " + attribute + " " + comparator + " " + processed_value + " )"
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
from typing import Dict, Tuple
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.chains.query_constructor.ir import (
|
from langchain.chains.query_constructor.ir import (
|
||||||
Comparator,
|
Comparator,
|
||||||
@ -12,11 +14,22 @@ from langchain.retrievers.self_query.milvus import MilvusTranslator
|
|||||||
DEFAULT_TRANSLATOR = MilvusTranslator()
|
DEFAULT_TRANSLATOR = MilvusTranslator()
|
||||||
|
|
||||||
|
|
||||||
def test_visit_comparison() -> None:
|
@pytest.mark.parametrize(
|
||||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=4)
|
"triplet",
|
||||||
expected = "( foo < 4 )"
|
[
|
||||||
|
(Comparator.EQ, 2, "( foo == 2 )"),
|
||||||
|
(Comparator.GT, 2, "( foo > 2 )"),
|
||||||
|
(Comparator.GTE, 2, "( foo >= 2 )"),
|
||||||
|
(Comparator.LT, 2, "( foo < 2 )"),
|
||||||
|
(Comparator.LTE, 2, "( foo <= 2 )"),
|
||||||
|
(Comparator.IN, ["bar", "abc"], "( foo in ['bar', 'abc'] )"),
|
||||||
|
(Comparator.LIKE, "bar", '( foo like "bar%" )'),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None:
|
||||||
|
comparator, value, expected = triplet
|
||||||
|
comp = Comparison(comparator=comparator, attribute="foo", value=value)
|
||||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||||
|
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user