mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
langchain[patch]: support more comparators in Milvus self-querying retriever (#16076)
- **Description:** Support IN and LIKE comparators in Milvus self-querying retriever, based on [Boolean Expression Rules](https://milvus.io/docs/boolean.md) - **Issue:** No - **Dependencies:** No - **Twitter handle:** No Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
parent
9c2f1f07a0
commit
8597484195
@ -16,28 +16,36 @@ COMPARATOR_TO_BER = {
|
||||
Comparator.GTE: ">=",
|
||||
Comparator.LT: "<",
|
||||
Comparator.LTE: "<=",
|
||||
Comparator.IN: "in",
|
||||
Comparator.LIKE: "like",
|
||||
}
|
||||
|
||||
UNARY_OPERATORS = [Operator.NOT]
|
||||
|
||||
|
||||
def process_value(value: Union[int, float, str]) -> str:
|
||||
def process_value(value: Union[int, float, str], comparator: Comparator) -> str:
|
||||
"""Convert a value to a string and add double quotes if it is a string.
|
||||
|
||||
It required for comparators involving strings.
|
||||
|
||||
Args:
|
||||
value: The value to convert.
|
||||
comparator: The comparator.
|
||||
|
||||
Returns:
|
||||
The converted value as a string.
|
||||
"""
|
||||
#
|
||||
if isinstance(value, str):
|
||||
# If the value is already a string, add double quotes
|
||||
return f'"{value}"'
|
||||
if comparator is Comparator.LIKE:
|
||||
# If the comparator is LIKE, add a percent sign after it for prefix matching
|
||||
# and add double quotes
|
||||
return f'"{value}%"'
|
||||
else:
|
||||
# If the value is already a string, add double quotes
|
||||
return f'"{value}"'
|
||||
else:
|
||||
# If the valueis not a string, convert it to a string without double quotes
|
||||
# If the value is not a string, convert it to a string without double quotes
|
||||
return str(value)
|
||||
|
||||
|
||||
@ -54,6 +62,8 @@ class MilvusTranslator(Visitor):
|
||||
Comparator.GTE,
|
||||
Comparator.LT,
|
||||
Comparator.LTE,
|
||||
Comparator.IN,
|
||||
Comparator.LIKE,
|
||||
]
|
||||
|
||||
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||
@ -78,7 +88,7 @@ class MilvusTranslator(Visitor):
|
||||
|
||||
def visit_comparison(self, comparison: Comparison) -> str:
|
||||
comparator = self._format_func(comparison.comparator)
|
||||
processed_value = process_value(comparison.value)
|
||||
processed_value = process_value(comparison.value, comparison.comparator)
|
||||
attribute = comparison.attribute
|
||||
|
||||
return "( " + attribute + " " + comparator + " " + processed_value + " )"
|
||||
|
@ -1,4 +1,6 @@
|
||||
from typing import Dict, Tuple
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
@ -12,11 +14,22 @@ from langchain.retrievers.self_query.milvus import MilvusTranslator
|
||||
DEFAULT_TRANSLATOR = MilvusTranslator()
|
||||
|
||||
|
||||
def test_visit_comparison() -> None:
|
||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=4)
|
||||
expected = "( foo < 4 )"
|
||||
@pytest.mark.parametrize(
|
||||
"triplet",
|
||||
[
|
||||
(Comparator.EQ, 2, "( foo == 2 )"),
|
||||
(Comparator.GT, 2, "( foo > 2 )"),
|
||||
(Comparator.GTE, 2, "( foo >= 2 )"),
|
||||
(Comparator.LT, 2, "( foo < 2 )"),
|
||||
(Comparator.LTE, 2, "( foo <= 2 )"),
|
||||
(Comparator.IN, ["bar", "abc"], "( foo in ['bar', 'abc'] )"),
|
||||
(Comparator.LIKE, "bar", '( foo like "bar%" )'),
|
||||
],
|
||||
)
|
||||
def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None:
|
||||
comparator, value, expected = triplet
|
||||
comp = Comparison(comparator=comparator, attribute="foo", value=value)
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
|
||||
assert expected == actual
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user