From db98c44f8f4ccee2d39b048bb3fa70d657142081 Mon Sep 17 00:00:00 2001 From: felixocker Date: Wed, 5 Jul 2023 19:00:16 +0200 Subject: [PATCH] Support for SPARQL (#7165) # [SPARQL](https://www.w3.org/TR/rdf-sparql-query/) for [LangChain](https://github.com/hwchase17/langchain) ## Description LangChain support for knowledge graphs relying on W3C standards using RDFlib: SPARQL/ RDF(S)/ OWL with special focus on RDF \ * Works with local files, files from the web, and SPARQL endpoints * Supports both SELECT and UPDATE queries * Includes both a Jupyter notebook with an example and integration tests ## Contribution compared to related PRs and discussions * [Wikibase agent](https://github.com/hwchase17/langchain/pull/2690) - uses SPARQL, but specifically for wikibase querying * [Cypher qa](https://github.com/hwchase17/langchain/pull/5078) - graph DB question answering for Neo4J via Cypher * [PR 6050](https://github.com/hwchase17/langchain/pull/6050) - tries something similar, but does not cover UPDATE queries and supports only RDF * Discussions on [w3c mailing list](mailto:semantic-web@w3.org) related to the combination of LLMs (specifically ChatGPT) and knowledge graphs ## Dependencies * [RDFlib](https://github.com/RDFLib/rdflib) ## Tag maintainer Graph database related to memory -> @hwchase17 --- .../chains/additional/graph_sparql_qa.ipynb | 300 ++++++++++++++++++ langchain/chains/__init__.py | 2 + langchain/chains/graph_qa/prompts.py | 87 +++++ langchain/chains/graph_qa/sparql.py | 127 ++++++++ langchain/graphs/__init__.py | 10 +- langchain/graphs/rdf_graph.py | 279 ++++++++++++++++ poetry.lock | 49 +-- pyproject.toml | 2 + .../chains/test_graph_database_sparql.py | 79 +++++ 9 files changed, 915 insertions(+), 20 deletions(-) create mode 100644 docs/extras/modules/chains/additional/graph_sparql_qa.ipynb create mode 100644 langchain/chains/graph_qa/sparql.py create mode 100644 langchain/graphs/rdf_graph.py create mode 100644 
tests/integration_tests/chains/test_graph_database_sparql.py diff --git a/docs/extras/modules/chains/additional/graph_sparql_qa.ipynb b/docs/extras/modules/chains/additional/graph_sparql_qa.ipynb new file mode 100644 index 00000000000..78990b25323 --- /dev/null +++ b/docs/extras/modules/chains/additional/graph_sparql_qa.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c94240f5", + "metadata": {}, + "source": [ + "# GraphSparqlQAChain\n", + "\n", + "Graph databases are an excellent choice for applications based on network-like models. To standardize the syntax and semantics of such graphs, the W3C recommends Semantic Web Technologies, cp. [Semantic Web](https://www.w3.org/standards/semanticweb/). [SPARQL](https://www.w3.org/TR/sparql11-query/) serves as a query language analogously to SQL or Cypher for these graphs. This notebook demonstrates the application of LLMs as a natural language interface to a graph database by generating SPARQL.\\\n", + "Disclaimer: To date, SPARQL query generation via LLMs is still a bit unstable. Be especially careful with UPDATE queries, which alter the graph." + ] + }, + { + "cell_type": "markdown", + "id": "dbc0ee68", + "metadata": {}, + "source": [ + "There are several sources you can run queries against, including files on the web, files you have available locally, SPARQL endpoints, e.g., [Wikidata](https://www.wikidata.org/wiki/Wikidata:Main_Page), and [triple stores](https://www.w3.org/wiki/LargeTripleStores)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "62812aad", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.chains import GraphSparqlQAChain\n", + "from langchain.graphs import RdfGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0928915d", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "graph = RdfGraph(\n", + " source_file=\"http://www.w3.org/People/Berners-Lee/card\",\n", + " standard=\"rdf\",\n", + " local_copy=\"test.ttl\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Note that providing a `local_file` is necessary for storing changes locally if the source is read-only." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "id": "58c1a8ea", + "metadata": {}, + "source": [ + "## Refresh graph schema information\n", + "If the schema of the database changes, you can refresh the schema information needed to generate SPARQL queries." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4e3de44f", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "graph.load_schema()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1fe76ccd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In the following, each IRI is followed by the local name and optionally its description in parentheses. 
\n", + "The RDF graph supports the following node types:\n", + " (PersonalProfileDocument, None), (RSAPublicKey, None), (Male, None), (Person, None), (Work, None)\n", + "The RDF graph supports the following relationships:\n", + " (seeAlso, None), (title, None), (mbox_sha1sum, None), (maker, None), (oidcIssuer, None), (publicHomePage, None), (openid, None), (storage, None), (name, None), (country, None), (type, None), (profileHighlightColor, None), (preferencesFile, None), (label, None), (modulus, None), (participant, None), (street2, None), (locality, None), (nick, None), (homepage, None), (license, None), (givenname, None), (street-address, None), (postal-code, None), (street, None), (lat, None), (primaryTopic, None), (fn, None), (location, None), (developer, None), (city, None), (region, None), (member, None), (long, None), (address, None), (family_name, None), (account, None), (workplaceHomepage, None), (title, None), (publicTypeIndex, None), (office, None), (homePage, None), (mbox, None), (preferredURI, None), (profileBackgroundColor, None), (owns, None), (based_near, None), (hasAddress, None), (img, None), (assistant, None), (title, None), (key, None), (inbox, None), (editableProfile, None), (postalCode, None), (weblog, None), (exponent, None), (avatar, None)\n", + "\n" + ] + } + ], + "source": [ + "graph.get_schema" + ] + }, + { + "cell_type": "markdown", + "id": "68a3c677", + "metadata": {}, + "source": [ + "## Querying the graph\n", + "\n", + "Now, you can use the graph SPARQL QA chain to ask questions about the graph." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7476ce98", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "chain = GraphSparqlQAChain.from_llm(\n", + " ChatOpenAI(temperature=0), graph=graph, verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ef8ee27b", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new GraphSparqlQAChain chain...\u001B[0m\n", + "Identified intent:\n", + "\u001B[32;1m\u001B[1;3mSELECT\u001B[0m\n", + "Generated SPARQL:\n", + "\u001B[32;1m\u001B[1;3mPREFIX foaf: \n", + "SELECT ?homepage\n", + "WHERE {\n", + " ?person foaf:name \"Tim Berners-Lee\" .\n", + " ?person foaf:workplaceHomepage ?homepage .\n", + "}\u001B[0m\n", + "Full Context:\n", + "\u001B[32;1m\u001B[1;3m[]\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/.\"" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"What is Tim Berners-Lee's work homepage?\")" + ] + }, + { + "cell_type": "markdown", + "id": "af4b3294", + "metadata": {}, + "source": [ + "## Updating the graph\n", + "\n", + "Analogously, you can update the graph, i.e., insert triples, using natural language." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "fdf38841", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new GraphSparqlQAChain chain...\u001B[0m\n", + "Identified intent:\n", + "\u001B[32;1m\u001B[1;3mUPDATE\u001B[0m\n", + "Generated SPARQL:\n", + "\u001B[32;1m\u001B[1;3mPREFIX foaf: \n", + "INSERT {\n", + " ?person foaf:workplaceHomepage .\n", + "}\n", + "WHERE {\n", + " ?person foaf:name \"Timothy Berners-Lee\" .\n", + "}\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Successfully inserted triples into the graph.'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"Save that the person with the name 'Timothy Berners-Lee' has a work homepage at 'http://www.w3.org/foo/bar/'\")" + ] + }, + { + "cell_type": "markdown", + "id": "5e0f7fc1", + "metadata": {}, + "source": [ + "Let's verify the results:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f874171b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(rdflib.term.URIRef('https://www.w3.org/'),),\n", + " (rdflib.term.URIRef('http://www.w3.org/foo/bar/'),)]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = (\n", + " \"\"\"PREFIX foaf: \\n\"\"\"\n", + " \"\"\"SELECT ?hp\\n\"\"\"\n", + " \"\"\"WHERE {\\n\"\"\"\n", + " \"\"\" ?person foaf:name \"Timothy Berners-Lee\" . 
\\n\"\"\"\n", + " \"\"\" ?person foaf:workplaceHomepage ?hp .\\n\"\"\"\n", + " \"\"\"}\"\"\"\n", + ")\n", + "graph.query(query)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lc", + "language": "python", + "name": "lc" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 90e0b9225ad..086aab47488 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -19,6 +19,7 @@ from langchain.chains.graph_qa.cypher import GraphCypherQAChain from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain from langchain.chains.graph_qa.kuzu import KuzuQAChain from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain +from langchain.chains.graph_qa.sparql import GraphSparqlQAChain from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.base import LLMBashChain @@ -69,6 +70,7 @@ __all__ = [ "FlareChain", "GraphCypherQAChain", "GraphQAChain", + "GraphSparqlQAChain", "HypotheticalDocumentEmbedder", "KuzuQAChain", "HugeGraphQAChain", diff --git a/langchain/chains/graph_qa/prompts.py b/langchain/chains/graph_qa/prompts.py index c3b49b214d5..ca68983b083 100644 --- a/langchain/chains/graph_qa/prompts.py +++ b/langchain/chains/graph_qa/prompts.py @@ -109,3 +109,90 @@ Helpful Answer:""" CYPHER_QA_PROMPT = PromptTemplate( input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE ) + +SPARQL_INTENT_TEMPLATE = """Task: Identify the intent of a prompt and return the appropriate SPARQL query type. +You are an assistant that distinguishes different types of prompts and returns the corresponding SPARQL query types. 
+ +Consider only the following query types: +* SELECT: this query type corresponds to questions +* UPDATE: this query type corresponds to all requests for deleting, inserting, or changing triples +Note: Be as concise as possible. +Do not include any explanations or apologies in your responses. +Do not respond to any questions that ask for anything else than for you to identify a SPARQL query type. +Do not include any unnecessary whitespaces or any text except the query type, i.e., either return 'SELECT' or 'UPDATE'. + +The prompt is: +{prompt} +Helpful Answer:""" +SPARQL_INTENT_PROMPT = PromptTemplate( + input_variables=["prompt"], template=SPARQL_INTENT_TEMPLATE +) + +SPARQL_GENERATION_SELECT_TEMPLATE = """Task: Generate a SPARQL SELECT statement for querying a graph database. +For instance, to find all email addresses of John Doe, the following query in backticks would be suitable: +``` +PREFIX foaf: <http://xmlns.com/foaf/0.1/> +SELECT ?email +WHERE {{ + ?person foaf:name "John Doe" . + ?person foaf:mbox ?email . +}} +``` +Instructions: +Use only the node types and properties provided in the schema. +Do not use any node types and properties that are not explicitly provided. +Include all necessary prefixes. +Schema: +{schema} +Note: Be as concise as possible. +Do not include any explanations or apologies in your responses. +Do not respond to any questions that ask for anything else than for you to construct a SPARQL query. +Do not include any text except the SPARQL query generated. + +The question is: +{prompt}""" +SPARQL_GENERATION_SELECT_PROMPT = PromptTemplate( + input_variables=["schema", "prompt"], template=SPARQL_GENERATION_SELECT_TEMPLATE +) + +SPARQL_GENERATION_UPDATE_TEMPLATE = """Task: Generate a SPARQL UPDATE statement for updating a graph database. +For instance, to add 'jane.doe@foo.bar' as a new email address for Jane Doe, the following query in backticks would be suitable: +``` +PREFIX foaf: <http://xmlns.com/foaf/0.1/> +INSERT {{ + ?person foaf:mbox <mailto:jane.doe@foo.bar> . +}} +WHERE {{ + ?person foaf:name "Jane Doe" . 
+}} +``` +Instructions: +Make the query as short as possible and avoid adding unnecessary triples. +Use only the node types and properties provided in the schema. +Do not use any node types and properties that are not explicitly provided. +Include all necessary prefixes. +Schema: +{schema} +Note: Be as concise as possible. +Do not include any explanations or apologies in your responses. +Do not respond to any questions that ask for anything else than for you to construct a SPARQL query. +Return only the generated SPARQL query, nothing else. + +The information to be inserted is: +{prompt}""" +SPARQL_GENERATION_UPDATE_PROMPT = PromptTemplate( + input_variables=["schema", "prompt"], template=SPARQL_GENERATION_UPDATE_TEMPLATE +) + +SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query. +You are an assistant that creates well-written and human understandable answers. +The information part contains the information provided, which you can use to construct an answer. +The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it. +Make your response sound like the information is coming from an AI assistant, but don't add any information. +Information: +{context} + +Question: {prompt} +Helpful Answer:""" +SPARQL_QA_PROMPT = PromptTemplate( + input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE +) diff --git a/langchain/chains/graph_qa/sparql.py b/langchain/chains/graph_qa/sparql.py new file mode 100644 index 00000000000..5c1389bef26 --- /dev/null +++ b/langchain/chains/graph_qa/sparql.py @@ -0,0 +1,127 @@ +""" +Question answering over an RDF or OWL graph using SPARQL. 
+""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from pydantic import Field + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.base import Chain +from langchain.chains.graph_qa.prompts import ( + SPARQL_GENERATION_SELECT_PROMPT, + SPARQL_GENERATION_UPDATE_PROMPT, + SPARQL_INTENT_PROMPT, + SPARQL_QA_PROMPT, +) +from langchain.chains.llm import LLMChain +from langchain.graphs.rdf_graph import RdfGraph +from langchain.prompts.base import BasePromptTemplate + + +class GraphSparqlQAChain(Chain): + """ + Chain for question-answering against an RDF or OWL graph by generating + SPARQL statements. + """ + + graph: RdfGraph = Field(exclude=True) + sparql_generation_select_chain: LLMChain + sparql_generation_update_chain: LLMChain + sparql_intent_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT, + sparql_select_prompt: BasePromptTemplate = SPARQL_GENERATION_SELECT_PROMPT, + sparql_update_prompt: BasePromptTemplate = SPARQL_GENERATION_UPDATE_PROMPT, + sparql_intent_prompt: BasePromptTemplate = SPARQL_INTENT_PROMPT, + **kwargs: Any, + ) -> GraphSparqlQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + sparql_generation_select_chain = LLMChain(llm=llm, prompt=sparql_select_prompt) + sparql_generation_update_chain = LLMChain(llm=llm, prompt=sparql_update_prompt) + sparql_intent_chain = LLMChain(llm=llm, prompt=sparql_intent_prompt) + + return cls( + qa_chain=qa_chain, + 
sparql_generation_select_chain=sparql_generation_select_chain, + sparql_generation_update_chain=sparql_generation_update_chain, + sparql_intent_chain=sparql_intent_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """ + Generate SPARQL query, use it to retrieve a response from the gdb and answer + the question. + """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + prompt = inputs[self.input_key] + + _intent = self.sparql_intent_chain.run({"prompt": prompt}, callbacks=callbacks) + intent = _intent.strip() + + if intent == "SELECT": + sparql_generation_chain = self.sparql_generation_select_chain + elif intent == "UPDATE": + sparql_generation_chain = self.sparql_generation_update_chain + else: + raise ValueError( + "I am sorry, but this prompt seems to fit none of the currently " + "supported SPARQL query types, i.e., SELECT and UPDATE." + ) + + _run_manager.on_text("Identified intent:", end="\n", verbose=self.verbose) + _run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose) + + generated_sparql = sparql_generation_chain.run( + {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_sparql, color="green", end="\n", verbose=self.verbose + ) + + if intent == "SELECT": + context = self.graph.query(generated_sparql) + + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + result = self.qa_chain( + {"prompt": prompt, "context": context}, + callbacks=callbacks, + ) + res = result[self.qa_chain.output_key] + elif intent == "UPDATE": + self.graph.update(generated_sparql) + res = "Successfully inserted triples into the graph." 
+ else: + raise ValueError("Unsupported SPARQL query type.") + return {self.output_key: res} diff --git a/langchain/graphs/__init__.py b/langchain/graphs/__init__.py index 437fe4a9d63..4996623b0aa 100644 --- a/langchain/graphs/__init__.py +++ b/langchain/graphs/__init__.py @@ -4,5 +4,13 @@ from langchain.graphs.kuzu_graph import KuzuGraph from langchain.graphs.nebula_graph import NebulaGraph from langchain.graphs.neo4j_graph import Neo4jGraph from langchain.graphs.networkx_graph import NetworkxEntityGraph +from langchain.graphs.rdf_graph import RdfGraph -__all__ = ["NetworkxEntityGraph", "Neo4jGraph", "NebulaGraph", "KuzuGraph", "HugeGraph"] +__all__ = [ + "NetworkxEntityGraph", + "Neo4jGraph", + "NebulaGraph", + "KuzuGraph", + "HugeGraph", + "RdfGraph", +] diff --git a/langchain/graphs/rdf_graph.py b/langchain/graphs/rdf_graph.py new file mode 100644 index 00000000000..4efeb4d0b80 --- /dev/null +++ b/langchain/graphs/rdf_graph.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + List, + Optional, +) + +if TYPE_CHECKING: + import rdflib + +prefixes = { + "owl": """PREFIX owl: \n""", + "rdf": """PREFIX rdf: \n""", + "rdfs": """PREFIX rdfs: \n""", + "xsd": """PREFIX xsd: \n""", +} + +cls_query_rdf = prefixes["rdfs"] + ( + """SELECT DISTINCT ?cls ?com\n""" + """WHERE { \n""" + """ ?instance a ?cls . \n""" + """ OPTIONAL { ?cls rdfs:comment ?com } \n""" + """}""" +) + +cls_query_rdfs = prefixes["rdfs"] + ( + """SELECT DISTINCT ?cls ?com\n""" + """WHERE { \n""" + """ ?instance a/rdfs:subClassOf* ?cls . \n""" + """ OPTIONAL { ?cls rdfs:comment ?com } \n""" + """}""" +) + +cls_query_owl = cls_query_rdfs + +rel_query_rdf = prefixes["rdfs"] + ( + """SELECT DISTINCT ?rel ?com\n""" + """WHERE { \n""" + """ ?subj ?rel ?obj . 
\n""" + """ OPTIONAL { ?cls rdfs:comment ?com } \n""" + """}""" +) + +rel_query_rdfs = ( + prefixes["rdf"] + + prefixes["rdfs"] + + ( + """SELECT DISTINCT ?rel ?com\n""" + """WHERE { \n""" + """ ?rel a/rdfs:subPropertyOf* rdf:Property . \n""" + """ OPTIONAL { ?cls rdfs:comment ?com } \n""" + """}""" + ) +) + +op_query_owl = ( + prefixes["rdfs"] + + prefixes["owl"] + + ( + """SELECT DISTINCT ?op ?com\n""" + """WHERE { \n""" + """ ?op a/rdfs:subPropertyOf* owl:ObjectProperty . \n""" + """ OPTIONAL { ?cls rdfs:comment ?com } \n""" + """}""" + ) +) + +dp_query_owl = ( + prefixes["rdfs"] + + prefixes["owl"] + + ( + """SELECT DISTINCT ?dp ?com\n""" + """WHERE { \n""" + """ ?dp a/rdfs:subPropertyOf* owl:DatatypeProperty . \n""" + """ OPTIONAL { ?cls rdfs:comment ?com } \n""" + """}""" + ) +) + + +class RdfGraph: + """ + RDFlib wrapper for graph operations. + Modes: + * local: Local file - can be queried and changed + * online: Online file - can only be queried, changes can be stored locally + * store: Triple store - can be queried and changed if update_endpoint available + Together with a source file, the serialization should be specified. 
+ """ + + def __init__( + self, + source_file: Optional[str] = None, + serialization: Optional[str] = "ttl", + query_endpoint: Optional[str] = None, + update_endpoint: Optional[str] = None, + standard: Optional[str] = "rdf", + local_copy: Optional[str] = None, + ) -> None: + """ + Set up the RDFlib graph + + :param source_file: either a path for a local file or a URL + :param serialization: serialization of the input + :param query_endpoint: SPARQL endpoint for queries, read access + :param update_endpoint: SPARQL endpoint for UPDATE queries, write access + :param standard: RDF, RDFS, or OWL + :param local_copy: new local copy for storing changes + """ + self.source_file = source_file + self.serialization = serialization + self.query_endpoint = query_endpoint + self.update_endpoint = update_endpoint + self.standard = standard + self.local_copy = local_copy + + try: + import rdflib + from rdflib.graph import DATASET_DEFAULT_GRAPH_ID as default + from rdflib.plugins.stores import sparqlstore + except ImportError: + raise ValueError( + "Could not import rdflib python package. " + "Please install it with `pip install rdflib`." + ) + if self.standard not in (supported_standards := ("rdf", "rdfs", "owl")): + raise ValueError( + f"Invalid standard. Supported standards are: {supported_standards}." + ) + + if ( + not source_file + and not query_endpoint + or source_file + and (query_endpoint or update_endpoint) + ): + raise ValueError( + "Could not unambiguously initialize the graph wrapper. " + "Specify either a file (local or online) via the source_file " + "or a triple store via the endpoints." 
+ ) + + if source_file: + if source_file.startswith("http"): + self.mode = "online" + else: + self.mode = "local" + if self.local_copy is None: + self.local_copy = self.source_file + self.graph = rdflib.Graph() + self.graph.parse(source_file, format=self.serialization) + + if query_endpoint: + self.mode = "store" + if not update_endpoint: + self._store = sparqlstore.SPARQLStore() + self._store.open(query_endpoint) + else: + self._store = sparqlstore.SPARQLUpdateStore() + self._store.open((query_endpoint, update_endpoint)) + self.graph = rdflib.Graph(self._store, identifier=default) + + # Verify that the graph was loaded + if not len(self.graph): + raise AssertionError("The graph is empty.") + + # Set schema + self.schema = "" + self.load_schema() + + @property + def get_schema(self) -> str: + """ + Returns the schema of the graph database. + """ + return self.schema + + def query( + self, + query: str, + ) -> List[rdflib.query.ResultRow]: + """ + Query the graph. + """ + from rdflib.exceptions import ParserError + from rdflib.query import ResultRow + + try: + res = self.graph.query(query) + except ParserError as e: + raise ValueError("Generated SPARQL statement is invalid\n" f"{e}") + return [r for r in res if isinstance(r, ResultRow)] + + def update( + self, + query: str, + ) -> None: + """ + Update the graph. 
+ """ + from rdflib.exceptions import ParserError + + try: + self.graph.update(query) + except ParserError as e: + raise ValueError("Generated SPARQL statement is invalid\n" f"{e}") + if self.local_copy: + self.graph.serialize( + destination=self.local_copy, format=self.local_copy.split(".")[-1] + ) + else: + raise ValueError("No target file specified for saving the updated file.") + + @staticmethod + def _get_local_name(iri: str) -> str: + if "#" in iri: + local_name = iri.split("#")[-1] + elif "/" in iri: + local_name = iri.split("/")[-1] + else: + raise ValueError(f"Unexpected IRI '{iri}', contains neither '#' nor '/'.") + return local_name + + def _res_to_str(self, res: rdflib.query.ResultRow, var: str) -> str: + return ( + "<" + + res[var] + + "> (" + + self._get_local_name(res[var]) + + ", " + + str(res["com"]) + + ")" + ) + + def load_schema(self) -> None: + """ + Load the graph schema information. + """ + + def _rdf_s_schema( + classes: List[rdflib.query.ResultRow], + relationships: List[rdflib.query.ResultRow], + ) -> str: + return ( + f"In the following, each IRI is followed by the local name and " + f"optionally its description in parentheses. \n" + f"The RDF graph supports the following node types:\n" + f'{", ".join([self._res_to_str(r, "cls") for r in classes])}\n' + f"The RDF graph supports the following relationships:\n" + f'{", ".join([self._res_to_str(r, "rel") for r in relationships])}\n' + ) + + if self.standard == "rdf": + clss = self.query(cls_query_rdf) + rels = self.query(rel_query_rdf) + self.schema = _rdf_s_schema(clss, rels) + elif self.standard == "rdfs": + clss = self.query(cls_query_rdfs) + rels = self.query(rel_query_rdfs) + self.schema = _rdf_s_schema(clss, rels) + elif self.standard == "owl": + clss = self.query(cls_query_owl) + ops = self.query(cls_query_owl) + dps = self.query(cls_query_owl) + self.schema = ( + f"In the following, each IRI is followed by the local name and " + f"optionally its description in parentheses. 
\n" + f"The OWL graph supports the following node types:\n" + f'{", ".join([self._res_to_str(r, "cls") for r in clss])}\n' + f"The OWL graph supports the following object properties, " + f"i.e., relationships between objects:\n" + f'{", ".join([self._res_to_str(r, "op") for r in ops])}\n' + f"The OWL graph supports the following data properties, " + f"i.e., relationships between objects and literals:\n" + f'{", ".join([self._res_to_str(r, "dp") for r in dps])}\n' + ) + else: + raise ValueError(f"Mode '{self.standard}' is currently not supported.") diff --git a/poetry.lock b/poetry.lock index e17e5e91cbf..22e95cfb7ba 100644 --- a/poetry.lock +++ b/poetry.lock @@ -642,12 +642,17 @@ category = "main" optional = true python-versions = ">=3.7" files = [ + {file = "awadb-0.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:871e2b10c79348d44522b8430af1ed1ad2632322c74abc20d8a3154de242da96"}, {file = "awadb-0.3.3-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:daebc108103c8cace41dfb3235fcfdda28ea48e6cd6548b6072f7ad49b64274b"}, {file = "awadb-0.3.3-cp311-cp311-macosx_10_13_universal2.whl", hash = "sha256:2bb3ca2f943448060b1bba4395dd99e2218d7f2149507a8fdfa7a3fd4cfe97ec"}, + {file = "awadb-0.3.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:83e92963cde54a4382b0c299939865ce12e145853637642bc8e6eb22bf689386"}, + {file = "awadb-0.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f249b04f38840146a5c17ffcd0f4da1bb00a39b8882c96e042acf58045faca2"}, {file = "awadb-0.3.3-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7b99662af9f7b58e217661a70c295e40605900552bec6d8e9553d90dbf19c5c1"}, {file = "awadb-0.3.3-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:94be44e587f28fa26b2cade0b6f4c04689f50cb0c07183db5ee50e48fe2e9ae3"}, {file = "awadb-0.3.3-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:314929dc3a8d25c0f234a2b86c920543050f4eb298a6f68bd2c97c9fe3fb6224"}, + {file = "awadb-0.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:35d8f580973e137864d2e00edbc7369cd01cf72b673d60fe902c7b3f983c76e9"}, {file = "awadb-0.3.3-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8bfccff1c7373899153427d93d96a97ae5371e8a6f09ff4dcbd28fb9f3f63ff4"}, + {file = "awadb-0.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c791bdd0646ec620c8b0fa026915780ebf78c16169cd9da81f54409553ec0114"}, {file = "awadb-0.3.3-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:810021a90b873f668d8ab63e2c2747b2b2835bf0ae25f4223b6c94f06faffea4"}, ] @@ -3653,23 +3658,6 @@ cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (>=1.0.0,<2.0.0)"] -[[package]] -name = "hugegraph-python" -version = "1.0.0.12" -description = "Python client for HugeGraph" -optional = true -python-versions = "*" -files = [ - {file = "hugegraph-python-1.0.0.12.tar.gz", hash = "sha256:06b2dded70c4f4570083f8b6e3a9edfebcf5ac4f07300727afad72389917ab85"}, - {file = "hugegraph_python-1.0.0.12-py3-none-any.whl", hash = "sha256:69fe20edbe1a392d16afc74df5c94b3b96bc02c848e9ab5b5f18c112a9bc3ebe"}, -] - -[package.dependencies] -decorator = "5.1.1" -Requests = "2.31.0" -setuptools = "67.6.1" -urllib3 = "2.0.3" - [[package]] name = "huggingface-hub" version = "0.15.1" @@ -4306,6 +4294,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -8924,6 +8913,28 @@ files = [ [package.extras] test = ["pytest (>=3.0)", "pytest-asyncio"] +[[package]] +name = "rdflib" +version = "6.3.2" +description = "RDFLib is a Python library for working with RDF, a simple yet powerful language for representing information." 
+category = "main" +optional = true +python-versions = ">=3.7,<4.0" +files = [ + {file = "rdflib-6.3.2-py3-none-any.whl", hash = "sha256:36b4e74a32aa1e4fa7b8719876fb192f19ecd45ff932ea5ebbd2e417a0247e63"}, + {file = "rdflib-6.3.2.tar.gz", hash = "sha256:72af591ff704f4caacea7ecc0c5a9056b8553e0489dd4f35a9bc52dbd41522e0"}, +] + +[package.dependencies] +isodate = ">=0.6.0,<0.7.0" +pyparsing = ">=2.1.0,<4" + +[package.extras] +berkeleydb = ["berkeleydb (>=18.1.0,<19.0.0)"] +html = ["html5lib (>=1.0,<2.0)"] +lxml = ["lxml (>=4.3.0,<5.0.0)"] +networkx = ["networkx (>=2.0.0,<3.0.0)"] + [[package]] name = "redis" version = "4.5.5" @@ -12371,7 +12382,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "octoai-sdk", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "spacy", "steamship", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] +all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", 
"azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "octoai-sdk", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "spacy", "steamship", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"] clarifai = ["clarifai"] cohere = ["cohere"] @@ -12387,4 +12398,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "1ebf19951081f82f4af129d9e2ea12c40fdc519bf49dffa0cfa85eb28938710c" +content-hash = "e2450d84b1a0747c45b015e1071ed37265269c325362779cd8bd3b9caa94a9c9" diff --git a/pyproject.toml b/pyproject.toml index 764fb7b91cf..3560d6b8d11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,6 +114,7 @@ openllm = {version = ">=0.1.19", optional = true} streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"} psychicapi = {version = "^0.8.0", optional = true} cassio = {version = "^0.0.7", optional = true} +rdflib = {version = "^6.3.2", optional = true} [tool.poetry.group.docs.dependencies] autodoc_pydantic = "^1.8.0" @@ -310,6 +311,7 @@ all = [ "awadb", "esprima", "octoai-sdk", + "rdflib", ] # An extra used to be able to add 
extended testing. diff --git a/tests/integration_tests/chains/test_graph_database_sparql.py b/tests/integration_tests/chains/test_graph_database_sparql.py new file mode 100644 index 00000000000..5ac33ba4918 --- /dev/null +++ b/tests/integration_tests/chains/test_graph_database_sparql.py @@ -0,0 +1,79 @@ +"""Test RDF/ SPARQL Graph Database Chain.""" +import os + +from langchain.chains.graph_qa.sparql import GraphSparqlQAChain +from langchain.graphs import RdfGraph +from langchain.llms.openai import OpenAI + + +def test_connect_file_rdf() -> None: + """ + Test loading online resource. + """ + berners_lee_card = "http://www.w3.org/People/Berners-Lee/card" + + graph = RdfGraph( + source_file=berners_lee_card, + standard="rdf", + ) + + query = """SELECT ?s ?p ?o\n""" """WHERE { ?s ?p ?o }""" + + output = graph.query(query) + assert len(output) == 86 + + +def test_sparql_select() -> None: + """ + Test for generating and executing simple SPARQL SELECT query. + """ + berners_lee_card = "http://www.w3.org/People/Berners-Lee/card" + + graph = RdfGraph( + source_file=berners_lee_card, + standard="rdf", + ) + + chain = GraphSparqlQAChain.from_llm(OpenAI(temperature=0), graph=graph) + output = chain.run("What is Tim Berners-Lee's work homepage?") + expected_output = ( + " The work homepage of Tim Berners-Lee is " + "http://www.w3.org/People/Berners-Lee/." + ) + assert output == expected_output + + +def test_sparql_insert() -> None: + """ + Test for generating and executing simple SPARQL INSERT query. 
+ """ + berners_lee_card = "http://www.w3.org/People/Berners-Lee/card" + _local_copy = "test.ttl" + + graph = RdfGraph( + source_file=berners_lee_card, + standard="rdf", + local_copy=_local_copy, + ) + + chain = GraphSparqlQAChain.from_llm(OpenAI(temperature=0), graph=graph) + chain.run( + "Save that the person with the name 'Timothy Berners-Lee' " + "has a work homepage at 'http://www.w3.org/foo/bar/'" + ) + query = ( + """PREFIX foaf: \n""" + """SELECT ?hp\n""" + """WHERE {\n""" + """ ?person foaf:name "Timothy Berners-Lee" . \n""" + """ ?person foaf:workplaceHomepage ?hp .\n""" + """}""" + ) + output = graph.query(query) + assert len(output) == 2 + + # clean up + try: + os.remove(_local_copy) + except OSError: + pass