mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
community[minor]: Self query retriever for HANA Cloud Vector Engine (#24494)
Description: - This PR adds a self query retriever implementation for SAP HANA Cloud Vector Engine. The retriever supports all operators except for contains. - Issue: N/A - Dependencies: no new dependencies added **Add tests and docs:** Added integration tests to: libs/community/tests/unit_tests/query_constructors/test_hanavector.py **Documentation for self query retriever:** /docs/integrations/retrievers/self_query/hanavector_self_query.ipynb --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
4f3b4fc7fe
commit
b65ac8d39c
@ -0,0 +1,246 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# SAP HANA Cloud Vector Engine\n",
|
||||||
|
"\n",
|
||||||
|
"For more information on how to setup the SAP HANA vetor store, take a look at the [documentation](/docs/integrations/vectorstores/sap_hanavector.ipynb).\n",
|
||||||
|
"\n",
|
||||||
|
"We use the same setup here:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"# Use OPENAI_API_KEY env variable\n",
|
||||||
|
"# os.environ[\"OPENAI_API_KEY\"] = \"Your OpenAI API key\"\n",
|
||||||
|
"from hdbcli import dbapi\n",
|
||||||
|
"\n",
|
||||||
|
"# Use connection settings from the environment\n",
|
||||||
|
"connection = dbapi.connect(\n",
|
||||||
|
" address=os.environ.get(\"HANA_DB_ADDRESS\"),\n",
|
||||||
|
" port=os.environ.get(\"HANA_DB_PORT\"),\n",
|
||||||
|
" user=os.environ.get(\"HANA_DB_USER\"),\n",
|
||||||
|
" password=os.environ.get(\"HANA_DB_PASSWORD\"),\n",
|
||||||
|
" autocommit=True,\n",
|
||||||
|
" sslValidateCertificate=False,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"To be able to self query with good performance we create additional metadata fields\n",
|
||||||
|
"for our vectorstore table in HANA:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Create custom table with attribute\n",
|
||||||
|
"cur = connection.cursor()\n",
|
||||||
|
"cur.execute(\"DROP TABLE LANGCHAIN_DEMO_SELF_QUERY\", ignoreErrors=True)\n",
|
||||||
|
"cur.execute(\n",
|
||||||
|
" (\n",
|
||||||
|
" \"\"\"CREATE TABLE \"LANGCHAIN_DEMO_SELF_QUERY\" (\n",
|
||||||
|
" \"name\" NVARCHAR(100), \"is_active\" BOOLEAN, \"id\" INTEGER, \"height\" DOUBLE,\n",
|
||||||
|
" \"VEC_TEXT\" NCLOB, \n",
|
||||||
|
" \"VEC_META\" NCLOB, \n",
|
||||||
|
" \"VEC_VECTOR\" REAL_VECTOR\n",
|
||||||
|
" )\"\"\"\n",
|
||||||
|
" )\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Let's add some documents."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain_community.vectorstores.hanavector import HanaDB\n",
|
||||||
|
"from langchain_core.documents import Document\n",
|
||||||
|
"from langchain_openai import OpenAIEmbeddings\n",
|
||||||
|
"\n",
|
||||||
|
"embeddings = OpenAIEmbeddings()\n",
|
||||||
|
"\n",
|
||||||
|
"# Prepare some test documents\n",
|
||||||
|
"docs = [\n",
|
||||||
|
" Document(\n",
|
||||||
|
" page_content=\"First\",\n",
|
||||||
|
" metadata={\"name\": \"adam\", \"is_active\": True, \"id\": 1, \"height\": 10.0},\n",
|
||||||
|
" ),\n",
|
||||||
|
" Document(\n",
|
||||||
|
" page_content=\"Second\",\n",
|
||||||
|
" metadata={\"name\": \"bob\", \"is_active\": False, \"id\": 2, \"height\": 5.7},\n",
|
||||||
|
" ),\n",
|
||||||
|
" Document(\n",
|
||||||
|
" page_content=\"Third\",\n",
|
||||||
|
" metadata={\"name\": \"jane\", \"is_active\": True, \"id\": 3, \"height\": 2.4},\n",
|
||||||
|
" ),\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"db = HanaDB(\n",
|
||||||
|
" connection=connection,\n",
|
||||||
|
" embedding=embeddings,\n",
|
||||||
|
" table_name=\"LANGCHAIN_DEMO_SELF_QUERY\",\n",
|
||||||
|
" specific_metadata_columns=[\"name\", \"is_active\", \"id\", \"height\"],\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Delete already existing documents from the table\n",
|
||||||
|
"db.delete(filter={})\n",
|
||||||
|
"db.add_documents(docs)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Self querying\n",
|
||||||
|
"\n",
|
||||||
|
"Now for the main act: here is how to construct a SelfQueryRetriever for HANA vectorstore:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains.query_constructor.base import AttributeInfo\n",
|
||||||
|
"from langchain.retrievers.self_query.base import SelfQueryRetriever\n",
|
||||||
|
"from langchain_community.query_constructors.hanavector import HanaTranslator\n",
|
||||||
|
"from langchain_openai import ChatOpenAI\n",
|
||||||
|
"\n",
|
||||||
|
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\")\n",
|
||||||
|
"\n",
|
||||||
|
"metadata_field_info = [\n",
|
||||||
|
" AttributeInfo(\n",
|
||||||
|
" name=\"name\",\n",
|
||||||
|
" description=\"The name of the person\",\n",
|
||||||
|
" type=\"string\",\n",
|
||||||
|
" ),\n",
|
||||||
|
" AttributeInfo(\n",
|
||||||
|
" name=\"is_active\",\n",
|
||||||
|
" description=\"Whether the person is active\",\n",
|
||||||
|
" type=\"boolean\",\n",
|
||||||
|
" ),\n",
|
||||||
|
" AttributeInfo(\n",
|
||||||
|
" name=\"id\",\n",
|
||||||
|
" description=\"The ID of the person\",\n",
|
||||||
|
" type=\"integer\",\n",
|
||||||
|
" ),\n",
|
||||||
|
" AttributeInfo(\n",
|
||||||
|
" name=\"height\",\n",
|
||||||
|
" description=\"The height of the person\",\n",
|
||||||
|
" type=\"float\",\n",
|
||||||
|
" ),\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"document_content_description = \"A collection of persons\"\n",
|
||||||
|
"\n",
|
||||||
|
"hana_translator = HanaTranslator()\n",
|
||||||
|
"\n",
|
||||||
|
"retriever = SelfQueryRetriever.from_llm(\n",
|
||||||
|
" llm,\n",
|
||||||
|
" db,\n",
|
||||||
|
" document_content_description,\n",
|
||||||
|
" metadata_field_info,\n",
|
||||||
|
" structured_query_translator=hana_translator,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Let's use this retriever to prepare a (self) query for a person:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"query_prompt = \"Which person is not active?\"\n",
|
||||||
|
"\n",
|
||||||
|
"docs = retriever.invoke(input=query_prompt)\n",
|
||||||
|
"for doc in docs:\n",
|
||||||
|
" print(\"-\" * 80)\n",
|
||||||
|
" print(doc.page_content, \" \", doc.metadata)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"We can also take a look at how the query is being constructed:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains.query_constructor.base import (\n",
|
||||||
|
" StructuredQueryOutputParser,\n",
|
||||||
|
" get_query_constructor_prompt,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"prompt = get_query_constructor_prompt(\n",
|
||||||
|
" document_content_description,\n",
|
||||||
|
" metadata_field_info,\n",
|
||||||
|
")\n",
|
||||||
|
"output_parser = StructuredQueryOutputParser.from_components()\n",
|
||||||
|
"query_constructor = prompt | llm | output_parser\n",
|
||||||
|
"\n",
|
||||||
|
"sq = query_constructor.invoke(input=query_prompt)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Structured query: \", sq)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Translated for hana vector store: \", hana_translator.visit_structured_query(sq))"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
# HANA Translator/query constructor
|
||||||
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
|
from langchain_core.structured_query import (
|
||||||
|
Comparator,
|
||||||
|
Comparison,
|
||||||
|
Operation,
|
||||||
|
Operator,
|
||||||
|
StructuredQuery,
|
||||||
|
Visitor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HanaTranslator(Visitor):
|
||||||
|
"""
|
||||||
|
Translate internal query language elements to valid filters params for
|
||||||
|
HANA vectorstore.
|
||||||
|
"""
|
||||||
|
|
||||||
|
allowed_operators = [Operator.AND, Operator.OR]
|
||||||
|
"""Subset of allowed logical operators."""
|
||||||
|
allowed_comparators = [
|
||||||
|
Comparator.EQ,
|
||||||
|
Comparator.NE,
|
||||||
|
Comparator.GT,
|
||||||
|
Comparator.LT,
|
||||||
|
Comparator.GTE,
|
||||||
|
Comparator.LTE,
|
||||||
|
Comparator.IN,
|
||||||
|
Comparator.NIN,
|
||||||
|
# Comparator.CONTAIN,
|
||||||
|
Comparator.LIKE,
|
||||||
|
]
|
||||||
|
|
||||||
|
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||||
|
self._validate_func(func)
|
||||||
|
return f"${func.value}"
|
||||||
|
|
||||||
|
def visit_operation(self, operation: Operation) -> Dict:
|
||||||
|
args = [arg.accept(self) for arg in operation.arguments]
|
||||||
|
return {self._format_func(operation.operator): args}
|
||||||
|
|
||||||
|
def visit_comparison(self, comparison: Comparison) -> Dict:
|
||||||
|
return {
|
||||||
|
comparison.attribute: {
|
||||||
|
self._format_func(comparison.comparator): comparison.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_structured_query(
|
||||||
|
self, structured_query: StructuredQuery
|
||||||
|
) -> Tuple[str, dict]:
|
||||||
|
if structured_query.filter is None:
|
||||||
|
kwargs = {}
|
||||||
|
else:
|
||||||
|
kwargs = {"filter": structured_query.filter.accept(self)}
|
||||||
|
return structured_query.query, kwargs
|
@ -191,7 +191,8 @@ class HanaDB(VectorStore):
|
|||||||
if column_length is not None and column_length > 0:
|
if column_length is not None and column_length > 0:
|
||||||
if rows[0][1] != column_length:
|
if rows[0][1] != column_length:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
f"Column {column_name} has the wrong length: {rows[0][1]}"
|
f"Column {column_name} has the wrong length: {rows[0][1]} "
|
||||||
|
f"expected: {column_length}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise AttributeError(f"Column {column_name} does not exist")
|
raise AttributeError(f"Column {column_name} does not exist")
|
||||||
@ -529,10 +530,18 @@ class HanaDB(VectorStore):
|
|||||||
if special_op in COMPARISONS_TO_SQL:
|
if special_op in COMPARISONS_TO_SQL:
|
||||||
operator = COMPARISONS_TO_SQL[special_op]
|
operator = COMPARISONS_TO_SQL[special_op]
|
||||||
if isinstance(special_val, bool):
|
if isinstance(special_val, bool):
|
||||||
query_tuple.append("true" if filter_value else "false")
|
query_tuple.append("true" if special_val else "false")
|
||||||
elif isinstance(special_val, float):
|
elif isinstance(special_val, float):
|
||||||
sql_param = "CAST(? as float)"
|
sql_param = "CAST(? as float)"
|
||||||
query_tuple.append(special_val)
|
query_tuple.append(special_val)
|
||||||
|
elif (
|
||||||
|
isinstance(special_val, dict)
|
||||||
|
and "type" in special_val
|
||||||
|
and special_val["type"] == "date"
|
||||||
|
):
|
||||||
|
# Date type
|
||||||
|
sql_param = "CAST(? as DATE)"
|
||||||
|
query_tuple.append(special_val["date"])
|
||||||
else:
|
else:
|
||||||
query_tuple.append(special_val)
|
query_tuple.append(special_val)
|
||||||
# "$between"
|
# "$between"
|
||||||
|
@ -0,0 +1,84 @@
|
|||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import pytest as pytest
|
||||||
|
from langchain_core.structured_query import (
|
||||||
|
Comparator,
|
||||||
|
Comparison,
|
||||||
|
Operation,
|
||||||
|
Operator,
|
||||||
|
StructuredQuery,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain_community.query_constructors.hanavector import HanaTranslator
|
||||||
|
|
||||||
|
DEFAULT_TRANSLATOR = HanaTranslator()
|
||||||
|
|
||||||
|
|
||||||
|
def test_visit_comparison() -> None:
|
||||||
|
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1)
|
||||||
|
expected = {"foo": {"$lt": 1}}
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_visit_operation() -> None:
|
||||||
|
op = Operation(
|
||||||
|
operator=Operator.AND,
|
||||||
|
arguments=[
|
||||||
|
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||||
|
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||||
|
Comparison(comparator=Comparator.GT, attribute="abc", value=2.0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
expected = {
|
||||||
|
"$and": [{"foo": {"$lt": 2}}, {"bar": {"$eq": "baz"}}, {"abc": {"$gt": 2.0}}]
|
||||||
|
}
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_visit_structured_query() -> None:
|
||||||
|
query = "What is the capital of France?"
|
||||||
|
structured_query = StructuredQuery(
|
||||||
|
query=query,
|
||||||
|
filter=None,
|
||||||
|
)
|
||||||
|
expected: Tuple[str, Dict] = (query, {})
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1)
|
||||||
|
structured_query = StructuredQuery(
|
||||||
|
query=query,
|
||||||
|
filter=comp,
|
||||||
|
)
|
||||||
|
expected = (query, {"filter": {"foo": {"$lt": 1}}})
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
op = Operation(
|
||||||
|
operator=Operator.AND,
|
||||||
|
arguments=[
|
||||||
|
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||||
|
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||||
|
Comparison(comparator=Comparator.GT, attribute="abc", value=2.0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
structured_query = StructuredQuery(
|
||||||
|
query=query,
|
||||||
|
filter=op,
|
||||||
|
)
|
||||||
|
expected = (
|
||||||
|
query,
|
||||||
|
{
|
||||||
|
"filter": {
|
||||||
|
"$and": [
|
||||||
|
{"foo": {"$lt": 2}},
|
||||||
|
{"bar": {"$eq": "baz"}},
|
||||||
|
{"abc": {"$gt": 2.0}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||||
|
assert expected == actual
|
@ -177,6 +177,16 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
|||||||
if isinstance(vectorstore, PGVector):
|
if isinstance(vectorstore, PGVector):
|
||||||
return NewPGVectorTranslator()
|
return NewPGVectorTranslator()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Added in langchain-community==0.2.11
|
||||||
|
from langchain_community.query_constructors.hanavector import HanaTranslator
|
||||||
|
from langchain_community.vectorstores import HanaDB
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if isinstance(vectorstore, HanaDB):
|
||||||
|
return HanaTranslator()
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Self query retriever with Vector Store type {vectorstore.__class__}"
|
f"Self query retriever with Vector Store type {vectorstore.__class__}"
|
||||||
f" not supported."
|
f" not supported."
|
||||||
|
Loading…
Reference in New Issue
Block a user