mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 18:33:40 +00:00
langchain[patch], community[patch]: Fixes in the Ontotext GraphDB Graph and QA Chain (#17239)
- **Description:** Fixes error handling in the Ontotext GraphDB Graph and QA Chain for invalid SPARQL queries. For some invalid queries `prepareQuery` does not throw an exception, yet the server rejects them with HTTP 400; these failures are now reported as errors.
- **Issue:** N/A
- **Dependencies:** N/A
- **Twitter handle:** @OntotextGraphDB
This commit is contained in:
parent
b88329e9a5
commit
9bb5157a3d
@ -69,7 +69,7 @@
|
||||
"pip install openai==1.6.1\n",
|
||||
"pip install rdflib==7.0.0\n",
|
||||
"pip install langchain-openai==0.0.2\n",
|
||||
"pip install langchain\n",
|
||||
"pip install \"langchain>=0.1.5\"\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Run Jupyter with\n",
|
||||
|
@ -204,11 +204,7 @@ class OntotextGraphDBGraph:
|
||||
"""
|
||||
Query the graph.
|
||||
"""
|
||||
from rdflib.exceptions import ParserError
|
||||
from rdflib.query import ResultRow
|
||||
|
||||
try:
|
||||
res = self.graph.query(query)
|
||||
except ParserError as e:
|
||||
raise ValueError(f"Generated SPARQL statement is invalid\n{e}")
|
||||
return [r for r in res if isinstance(r, ResultRow)]
|
||||
|
@ -10,7 +10,7 @@ cd libs/community/tests/integration_tests/graphs/docker-compose-ontotext-graphdb
|
||||
"""
|
||||
|
||||
|
||||
def test_query() -> None:
|
||||
def test_query_method_with_valid_query() -> None:
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/langchain",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o}"
|
||||
@ -31,6 +31,36 @@ def test_query() -> None:
|
||||
assert str(query_results[0][0]) == "yellow"
|
||||
|
||||
|
||||
def test_query_method_with_invalid_query() -> None:
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/langchain",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o}"
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
graph.query(
|
||||
"PREFIX : <https://swapi.co/vocabulary/> "
|
||||
"PREFIX owl: <http://www.w3.org/2002/07/owl#> "
|
||||
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
|
||||
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#> "
|
||||
"SELECT ?character (MAX(?lifespan) AS ?maxLifespan) "
|
||||
"WHERE {"
|
||||
" ?species a :Species ;"
|
||||
" :character ?character ;"
|
||||
" :averageLifespan ?lifespan ."
|
||||
" FILTER(xsd:integer(?lifespan))"
|
||||
"} "
|
||||
"ORDER BY DESC(?maxLifespan) "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
assert (
|
||||
str(e.value)
|
||||
== "You did something wrong formulating either the URI or your SPARQL query"
|
||||
)
|
||||
|
||||
|
||||
def test_get_schema_with_query() -> None:
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/langchain",
|
||||
|
@ -1,7 +1,10 @@
|
||||
"""Question answering over a graph."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import rdflib
|
||||
|
||||
from langchain_community.graphs import OntotextGraphDBGraph
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
@ -97,10 +100,10 @@ class OntotextGraphDBQAChain(Chain):
|
||||
self.sparql_generation_chain.output_key
|
||||
]
|
||||
|
||||
generated_sparql = self._get_valid_sparql_query(
|
||||
generated_sparql = self._get_prepared_sparql_query(
|
||||
_run_manager, callbacks, generated_sparql, ontology_schema
|
||||
)
|
||||
query_results = self.graph.query(generated_sparql)
|
||||
query_results = self._execute_query(generated_sparql)
|
||||
|
||||
qa_chain_result = self.qa_chain.invoke(
|
||||
{"prompt": prompt, "context": query_results}, callbacks=callbacks
|
||||
@ -108,7 +111,7 @@ class OntotextGraphDBQAChain(Chain):
|
||||
result = qa_chain_result[self.qa_chain.output_key]
|
||||
return {self.output_key: result}
|
||||
|
||||
def _get_valid_sparql_query(
|
||||
def _get_prepared_sparql_query(
|
||||
self,
|
||||
_run_manager: CallbackManagerForChainRun,
|
||||
callbacks: CallbackManager,
|
||||
@ -153,10 +156,10 @@ class OntotextGraphDBQAChain(Chain):
|
||||
from rdflib.plugins.sparql import prepareQuery
|
||||
|
||||
prepareQuery(generated_sparql)
|
||||
self._log_valid_sparql_query(_run_manager, generated_sparql)
|
||||
self._log_prepared_sparql_query(_run_manager, generated_sparql)
|
||||
return generated_sparql
|
||||
|
||||
def _log_valid_sparql_query(
|
||||
def _log_prepared_sparql_query(
|
||||
self, _run_manager: CallbackManagerForChainRun, generated_query: str
|
||||
) -> None:
|
||||
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
|
||||
@ -180,3 +183,9 @@ class OntotextGraphDBQAChain(Chain):
|
||||
_run_manager.on_text(
|
||||
error_message, color="red", end="\n\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]:
|
||||
try:
|
||||
return self.graph.query(query)
|
||||
except Exception:
|
||||
raise ValueError("Failed to execute the generated SPARQL query.")
|
||||
|
@ -165,6 +165,65 @@ def test_valid_sparql_after_first_retry(max_fix_retries: int) -> None:
|
||||
assert result == {chain.output_key: answer, chain.input_key: question}
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
|
||||
def test_invalid_sparql_server_response_400(max_fix_retries: int) -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
question = "Who is the oldest character?"
|
||||
generated_invalid_sparql = (
|
||||
"PREFIX : <https://swapi.co/vocabulary/> "
|
||||
"PREFIX owl: <http://www.w3.org/2002/07/owl#> "
|
||||
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
|
||||
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#> "
|
||||
"SELECT ?character (MAX(?lifespan) AS ?maxLifespan) "
|
||||
"WHERE {"
|
||||
" ?species a :Species ;"
|
||||
" :character ?character ;"
|
||||
" :averageLifespan ?lifespan ."
|
||||
" FILTER(xsd:integer(?lifespan))"
|
||||
"} "
|
||||
"ORDER BY DESC(?maxLifespan) "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
chain = OntotextGraphDBQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
max_fix_retries=max_fix_retries,
|
||||
)
|
||||
chain.sparql_generation_chain = Mock(LLMChain)
|
||||
chain.sparql_fix_chain = Mock(LLMChain)
|
||||
chain.qa_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_generation_chain.output_key = "text"
|
||||
chain.sparql_generation_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": generated_invalid_sparql,
|
||||
"prompt": question,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.sparql_fix_chain.output_key = "text"
|
||||
chain.sparql_fix_chain.invoke = MagicMock()
|
||||
chain.qa_chain.output_key = "text"
|
||||
chain.qa_chain.invoke = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
chain.invoke({chain.input_key: question})
|
||||
|
||||
assert str(e.value) == "Failed to execute the generated SPARQL query."
|
||||
|
||||
assert chain.sparql_generation_chain.invoke.call_count == 1
|
||||
assert chain.sparql_fix_chain.invoke.call_count == 0
|
||||
assert chain.qa_chain.invoke.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
|
||||
def test_invalid_sparql_after_all_retries(max_fix_retries: int) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user