mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 18:33:40 +00:00
langchain[patch], community[patch]: Fixes in the Ontotext GraphDB Graph and QA Chain (#17239)
- **Description:** Fixes error handling in the Ontotext GraphDB Graph and QA Chain for invalid SPARQL queries: queries for which `prepareQuery` does not raise an exception, but which are in fact invalid and are rejected by the server with an HTTP 400 response.
- **Issue:** N/A
- **Dependencies:** N/A
- **Twitter handle:** @OntotextGraphDB
This commit is contained in:
parent
b88329e9a5
commit
9bb5157a3d
@ -69,7 +69,7 @@
|
|||||||
"pip install openai==1.6.1\n",
|
"pip install openai==1.6.1\n",
|
||||||
"pip install rdflib==7.0.0\n",
|
"pip install rdflib==7.0.0\n",
|
||||||
"pip install langchain-openai==0.0.2\n",
|
"pip install langchain-openai==0.0.2\n",
|
||||||
"pip install langchain\n",
|
"pip install langchain>=0.1.5\n",
|
||||||
"```\n",
|
"```\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Run Jupyter with\n",
|
"Run Jupyter with\n",
|
||||||
|
@ -204,11 +204,7 @@ class OntotextGraphDBGraph:
|
|||||||
"""
|
"""
|
||||||
Query the graph.
|
Query the graph.
|
||||||
"""
|
"""
|
||||||
from rdflib.exceptions import ParserError
|
|
||||||
from rdflib.query import ResultRow
|
from rdflib.query import ResultRow
|
||||||
|
|
||||||
try:
|
res = self.graph.query(query)
|
||||||
res = self.graph.query(query)
|
|
||||||
except ParserError as e:
|
|
||||||
raise ValueError(f"Generated SPARQL statement is invalid\n{e}")
|
|
||||||
return [r for r in res if isinstance(r, ResultRow)]
|
return [r for r in res if isinstance(r, ResultRow)]
|
||||||
|
@ -10,7 +10,7 @@ cd libs/community/tests/integration_tests/graphs/docker-compose-ontotext-graphdb
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def test_query() -> None:
|
def test_query_method_with_valid_query() -> None:
|
||||||
graph = OntotextGraphDBGraph(
|
graph = OntotextGraphDBGraph(
|
||||||
query_endpoint="http://localhost:7200/repositories/langchain",
|
query_endpoint="http://localhost:7200/repositories/langchain",
|
||||||
query_ontology="CONSTRUCT {?s ?p ?o}"
|
query_ontology="CONSTRUCT {?s ?p ?o}"
|
||||||
@ -31,6 +31,36 @@ def test_query() -> None:
|
|||||||
assert str(query_results[0][0]) == "yellow"
|
assert str(query_results[0][0]) == "yellow"
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_method_with_invalid_query() -> None:
|
||||||
|
graph = OntotextGraphDBGraph(
|
||||||
|
query_endpoint="http://localhost:7200/repositories/langchain",
|
||||||
|
query_ontology="CONSTRUCT {?s ?p ?o}"
|
||||||
|
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as e:
|
||||||
|
graph.query(
|
||||||
|
"PREFIX : <https://swapi.co/vocabulary/> "
|
||||||
|
"PREFIX owl: <http://www.w3.org/2002/07/owl#> "
|
||||||
|
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
|
||||||
|
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#> "
|
||||||
|
"SELECT ?character (MAX(?lifespan) AS ?maxLifespan) "
|
||||||
|
"WHERE {"
|
||||||
|
" ?species a :Species ;"
|
||||||
|
" :character ?character ;"
|
||||||
|
" :averageLifespan ?lifespan ."
|
||||||
|
" FILTER(xsd:integer(?lifespan))"
|
||||||
|
"} "
|
||||||
|
"ORDER BY DESC(?maxLifespan) "
|
||||||
|
"LIMIT 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
str(e.value)
|
||||||
|
== "You did something wrong formulating either the URI or your SPARQL query"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_schema_with_query() -> None:
|
def test_get_schema_with_query() -> None:
|
||||||
graph = OntotextGraphDBGraph(
|
graph = OntotextGraphDBGraph(
|
||||||
query_endpoint="http://localhost:7200/repositories/langchain",
|
query_endpoint="http://localhost:7200/repositories/langchain",
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
"""Question answering over a graph."""
|
"""Question answering over a graph."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import rdflib
|
||||||
|
|
||||||
from langchain_community.graphs import OntotextGraphDBGraph
|
from langchain_community.graphs import OntotextGraphDBGraph
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
from langchain_core.callbacks.manager import CallbackManager
|
||||||
@ -97,10 +100,10 @@ class OntotextGraphDBQAChain(Chain):
|
|||||||
self.sparql_generation_chain.output_key
|
self.sparql_generation_chain.output_key
|
||||||
]
|
]
|
||||||
|
|
||||||
generated_sparql = self._get_valid_sparql_query(
|
generated_sparql = self._get_prepared_sparql_query(
|
||||||
_run_manager, callbacks, generated_sparql, ontology_schema
|
_run_manager, callbacks, generated_sparql, ontology_schema
|
||||||
)
|
)
|
||||||
query_results = self.graph.query(generated_sparql)
|
query_results = self._execute_query(generated_sparql)
|
||||||
|
|
||||||
qa_chain_result = self.qa_chain.invoke(
|
qa_chain_result = self.qa_chain.invoke(
|
||||||
{"prompt": prompt, "context": query_results}, callbacks=callbacks
|
{"prompt": prompt, "context": query_results}, callbacks=callbacks
|
||||||
@ -108,7 +111,7 @@ class OntotextGraphDBQAChain(Chain):
|
|||||||
result = qa_chain_result[self.qa_chain.output_key]
|
result = qa_chain_result[self.qa_chain.output_key]
|
||||||
return {self.output_key: result}
|
return {self.output_key: result}
|
||||||
|
|
||||||
def _get_valid_sparql_query(
|
def _get_prepared_sparql_query(
|
||||||
self,
|
self,
|
||||||
_run_manager: CallbackManagerForChainRun,
|
_run_manager: CallbackManagerForChainRun,
|
||||||
callbacks: CallbackManager,
|
callbacks: CallbackManager,
|
||||||
@ -153,10 +156,10 @@ class OntotextGraphDBQAChain(Chain):
|
|||||||
from rdflib.plugins.sparql import prepareQuery
|
from rdflib.plugins.sparql import prepareQuery
|
||||||
|
|
||||||
prepareQuery(generated_sparql)
|
prepareQuery(generated_sparql)
|
||||||
self._log_valid_sparql_query(_run_manager, generated_sparql)
|
self._log_prepared_sparql_query(_run_manager, generated_sparql)
|
||||||
return generated_sparql
|
return generated_sparql
|
||||||
|
|
||||||
def _log_valid_sparql_query(
|
def _log_prepared_sparql_query(
|
||||||
self, _run_manager: CallbackManagerForChainRun, generated_query: str
|
self, _run_manager: CallbackManagerForChainRun, generated_query: str
|
||||||
) -> None:
|
) -> None:
|
||||||
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
|
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
|
||||||
@ -180,3 +183,9 @@ class OntotextGraphDBQAChain(Chain):
|
|||||||
_run_manager.on_text(
|
_run_manager.on_text(
|
||||||
error_message, color="red", end="\n\n", verbose=self.verbose
|
error_message, color="red", end="\n\n", verbose=self.verbose
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]:
|
||||||
|
try:
|
||||||
|
return self.graph.query(query)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Failed to execute the generated SPARQL query.")
|
||||||
|
@ -165,6 +165,65 @@ def test_valid_sparql_after_first_retry(max_fix_retries: int) -> None:
|
|||||||
assert result == {chain.output_key: answer, chain.input_key: question}
|
assert result == {chain.output_key: answer, chain.input_key: question}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||||
|
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
|
||||||
|
def test_invalid_sparql_server_response_400(max_fix_retries: int) -> None:
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
question = "Who is the oldest character?"
|
||||||
|
generated_invalid_sparql = (
|
||||||
|
"PREFIX : <https://swapi.co/vocabulary/> "
|
||||||
|
"PREFIX owl: <http://www.w3.org/2002/07/owl#> "
|
||||||
|
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
|
||||||
|
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#> "
|
||||||
|
"SELECT ?character (MAX(?lifespan) AS ?maxLifespan) "
|
||||||
|
"WHERE {"
|
||||||
|
" ?species a :Species ;"
|
||||||
|
" :character ?character ;"
|
||||||
|
" :averageLifespan ?lifespan ."
|
||||||
|
" FILTER(xsd:integer(?lifespan))"
|
||||||
|
"} "
|
||||||
|
"ORDER BY DESC(?maxLifespan) "
|
||||||
|
"LIMIT 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = OntotextGraphDBGraph(
|
||||||
|
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||||
|
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||||
|
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||||
|
)
|
||||||
|
chain = OntotextGraphDBQAChain.from_llm(
|
||||||
|
Mock(ChatOpenAI),
|
||||||
|
graph=graph,
|
||||||
|
max_fix_retries=max_fix_retries,
|
||||||
|
)
|
||||||
|
chain.sparql_generation_chain = Mock(LLMChain)
|
||||||
|
chain.sparql_fix_chain = Mock(LLMChain)
|
||||||
|
chain.qa_chain = Mock(LLMChain)
|
||||||
|
|
||||||
|
chain.sparql_generation_chain.output_key = "text"
|
||||||
|
chain.sparql_generation_chain.invoke = MagicMock(
|
||||||
|
return_value={
|
||||||
|
"text": generated_invalid_sparql,
|
||||||
|
"prompt": question,
|
||||||
|
"schema": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
chain.sparql_fix_chain.output_key = "text"
|
||||||
|
chain.sparql_fix_chain.invoke = MagicMock()
|
||||||
|
chain.qa_chain.output_key = "text"
|
||||||
|
chain.qa_chain.invoke = MagicMock()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as e:
|
||||||
|
chain.invoke({chain.input_key: question})
|
||||||
|
|
||||||
|
assert str(e.value) == "Failed to execute the generated SPARQL query."
|
||||||
|
|
||||||
|
assert chain.sparql_generation_chain.invoke.call_count == 1
|
||||||
|
assert chain.sparql_fix_chain.invoke.call_count == 0
|
||||||
|
assert chain.qa_chain.invoke.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||||
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
|
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
|
||||||
def test_invalid_sparql_after_all_retries(max_fix_retries: int) -> None:
|
def test_invalid_sparql_after_all_retries(max_fix_retries: int) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user