langchain[patch]: return formatted SPARQL query on demand (#11263)

- **Description:** Added the `return_sparql_query` feature to the
`GraphSparqlQAChain` class, allowing users to get the formatted SPARQL
query along with the chain's result.
  - **Issue:** NA
  - **Dependencies:** None

Note: I've ensured that the PR passes linting and testing by running
make format, make lint, and make test locally.

I have added a test for the integration (which relies on network access)
and I have added an example to the notebook showing its use.
This commit is contained in:
Reid Falconer
2024-02-23 02:03:26 +01:00
committed by GitHub
parent b15fccbb99
commit 0534ba5a7d
3 changed files with 137 additions and 11 deletions

View File

@@ -41,15 +41,25 @@ class GraphSparqlQAChain(Chain):
sparql_generation_update_chain: LLMChain
sparql_intent_chain: LLMChain
qa_chain: LLMChain
return_sparql_query: bool = False
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
sparql_query_key: str = "sparql_query" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@@ -135,4 +145,8 @@ class GraphSparqlQAChain(Chain):
res = "Successfully inserted triples into the graph."
else:
raise ValueError("Unsupported SPARQL query type.")
return {self.output_key: res}
chain_result: Dict[str, Any] = {self.output_key: res}
if self.return_sparql_query:
chain_result[self.sparql_query_key] = generated_sparql
return chain_result

View File

@@ -78,3 +78,29 @@ def test_sparql_insert() -> None:
os.remove(_local_copy)
except OSError:
pass
def test_sparql_select_return_query() -> None:
"""
Test for generating and executing simple SPARQL SELECT query
and returning the generated SPARQL query.
"""
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
graph = RdfGraph(
source_file=berners_lee_card,
standard="rdf",
)
chain = GraphSparqlQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_sparql_query=True
)
output = chain("What is Tim Berners-Lee's work homepage?")
# Verify the expected answer
expected_output = (
" The work homepage of Tim Berners-Lee is "
"http://www.w3.org/People/Berners-Lee/."
)
assert output["result"] == expected_output
assert "sparql_query" in output