mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 19:09:57 +00:00
community[minor], langchian[minor]: Add Neptune Rdf graph and chain (#16650)
**Description**: This PR adds a chain for Amazon Neptune graph database RDF format. It complements the existing Neptune Cypher chain. The PR also includes a Neptune RDF graph class to connect to, introspect, and query a Neptune RDF graph database from the chain. A sample notebook is provided under docs that demonstrates the overall effect: invoking the chain to make natural language queries against Neptune using an LLM. **Issue**: This is a new feature **Dependencies**: The RDF graph class depends on the AWS boto3 library if using IAM authentication to connect to the Neptune database. --------- Co-authored-by: Piyush Jain <piyushjain@duck.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -41,6 +41,7 @@ from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain
|
||||
from langchain.chains.graph_qa.kuzu import KuzuQAChain
|
||||
from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain
|
||||
from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain
|
||||
from langchain.chains.graph_qa.neptune_sparql import NeptuneSparqlQAChain
|
||||
from langchain.chains.graph_qa.ontotext_graphdb import OntotextGraphDBQAChain
|
||||
from langchain.chains.graph_qa.sparql import GraphSparqlQAChain
|
||||
from langchain.chains.history_aware_retriever import create_history_aware_retriever
|
||||
@@ -116,6 +117,7 @@ __all__ = [
|
||||
"NatBotChain",
|
||||
"NebulaGraphQAChain",
|
||||
"NeptuneOpenCypherQAChain",
|
||||
"NeptuneSparqlQAChain",
|
||||
"OpenAIModerationChain",
|
||||
"OpenAPIEndpointChain",
|
||||
"QAGenerationChain",
|
||||
|
196
libs/langchain/langchain/chains/graph_qa/neptune_sparql.py
Normal file
196
libs/langchain/langchain/chains/graph_qa/neptune_sparql.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Question answering over an RDF or OWL graph using SPARQL.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_community.graphs import NeptuneRdfGraph
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import SPARQL_QA_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
SPARQL_GENERATION_TEMPLATE = """
|
||||
Task: Generate a SPARQL SELECT statement for querying a graph database.
|
||||
For instance, to find all email addresses of John Doe, the following
|
||||
query in backticks would be suitable:
|
||||
```
|
||||
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
|
||||
SELECT ?email
|
||||
WHERE {{
|
||||
?person foaf:name "John Doe" .
|
||||
?person foaf:mbox ?email .
|
||||
}}
|
||||
```
|
||||
Instructions:
|
||||
Use only the node types and properties provided in the schema.
|
||||
Do not use any node types and properties that are not explicitly provided.
|
||||
Include all necessary prefixes.
|
||||
|
||||
Examples:
|
||||
|
||||
Schema:
|
||||
{schema}
|
||||
Note: Be as concise as possible.
|
||||
Do not include any explanations or apologies in your responses.
|
||||
Do not respond to any questions that ask for anything else than
|
||||
for you to construct a SPARQL query.
|
||||
Do not include any text except the SPARQL query generated.
|
||||
|
||||
The question is:
|
||||
{prompt}"""
|
||||
|
||||
SPARQL_GENERATION_PROMPT = PromptTemplate(
|
||||
input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
|
||||
)
|
||||
|
||||
|
||||
def extract_sparql(query: str) -> str:
|
||||
query = query.strip()
|
||||
querytoks = query.split("```")
|
||||
if len(querytoks) == 3:
|
||||
query = querytoks[1]
|
||||
|
||||
if query.startswith("sparql"):
|
||||
query = query[6:]
|
||||
elif query.startswith("<sparql>") and query.endswith("</sparql>"):
|
||||
query = query[8:-9]
|
||||
return query
|
||||
|
||||
|
||||
class NeptuneSparqlQAChain(Chain):
|
||||
"""Chain for question-answering against a Neptune graph
|
||||
by generating SPARQL statements.
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain = NeptuneSparqlQAChain.from_llm(
|
||||
llm=llm,
|
||||
graph=graph
|
||||
)
|
||||
response = chain.invoke(query)
|
||||
"""
|
||||
|
||||
graph: NeptuneRdfGraph = Field(exclude=True)
|
||||
sparql_generation_chain: LLMChain
|
||||
qa_chain: LLMChain
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
top_k: int = 10
|
||||
return_intermediate_steps: bool = False
|
||||
"""Whether or not to return the intermediate steps along with the final answer."""
|
||||
return_direct: bool = False
|
||||
"""Whether or not to return the result of querying the graph directly."""
|
||||
extra_instructions: Optional[str] = None
|
||||
"""Extra instructions by the appended to the query generation prompt."""
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
_output_keys = [self.output_key]
|
||||
return _output_keys
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
*,
|
||||
qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
|
||||
sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
|
||||
examples: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> NeptuneSparqlQAChain:
|
||||
"""Initialize from LLM."""
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
template_to_use = SPARQL_GENERATION_TEMPLATE
|
||||
if examples:
|
||||
template_to_use = template_to_use.replace(
|
||||
"Examples:", "Examples: " + examples
|
||||
)
|
||||
sparql_prompt = PromptTemplate(
|
||||
input_variables=["schema", "prompt"], template=template_to_use
|
||||
)
|
||||
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)
|
||||
|
||||
return cls(
|
||||
qa_chain=qa_chain,
|
||||
sparql_generation_chain=sparql_generation_chain,
|
||||
examples=examples,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Generate SPARQL query, use it to retrieve a response from the gdb and answer
|
||||
the question.
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
prompt = inputs[self.input_key]
|
||||
|
||||
intermediate_steps: List = []
|
||||
|
||||
generated_sparql = self.sparql_generation_chain.run(
|
||||
{"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
|
||||
)
|
||||
|
||||
# Extract SPARQL
|
||||
generated_sparql = extract_sparql(generated_sparql)
|
||||
|
||||
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
generated_sparql, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
intermediate_steps.append({"query": generated_sparql})
|
||||
|
||||
context = self.graph.query(generated_sparql)
|
||||
|
||||
if self.return_direct:
|
||||
final_result = context
|
||||
else:
|
||||
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
str(context), color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
intermediate_steps.append({"context": context})
|
||||
|
||||
result = self.qa_chain(
|
||||
{"prompt": prompt, "context": context},
|
||||
callbacks=callbacks,
|
||||
)
|
||||
final_result = result[self.qa_chain.output_key]
|
||||
|
||||
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
||||
if self.return_intermediate_steps:
|
||||
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
|
||||
|
||||
return chain_result
|
@@ -33,6 +33,7 @@ EXPECTED_ALL = [
|
||||
"NatBotChain",
|
||||
"NebulaGraphQAChain",
|
||||
"NeptuneOpenCypherQAChain",
|
||||
"NeptuneSparqlQAChain",
|
||||
"OpenAIModerationChain",
|
||||
"OpenAPIEndpointChain",
|
||||
"QAGenerationChain",
|
||||
|
Reference in New Issue
Block a user