Add support for Falkordb (ex-RedisGraph) (#9821)

Replace this entire comment with:
  - Description: Add support for Falkordb (ex-RedisGraph)
  - Tag maintainer: @hwchase17
  - Twitter handle: @g_korland
This commit is contained in:
Guy Korland
2023-08-30 00:22:33 +03:00
committed by GitHub
parent fbd792ac7c
commit 7cbe872af8
6 changed files with 400 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ from langchain.chains.flare.base import FlareChain
from langchain.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain.chains.graph_qa.base import GraphQAChain
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.graph_qa.falkordb import FalkorDBQAChain
from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain
from langchain.chains.graph_qa.kuzu import KuzuQAChain
from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain
@@ -85,6 +86,7 @@ __all__ = [
"ConstitutionalChain",
"ConversationChain",
"ConversationalRetrievalChain",
"FalkorDBQAChain",
"FlareChain",
"GraphCypherQAChain",
"GraphQAChain",

View File

@@ -0,0 +1,141 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs import FalkorDBGraph
from langchain.pydantic_v1 import Field
from langchain.schema import BasePromptTemplate
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
"""
Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
class FalkorDBQAChain(Chain):
"""Chain for question-answering against a graph by generating Cypher statements."""
graph: FalkorDBGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
"""Number of results to return from the query"""
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@property
def _chain_type(self) -> str:
return "graph_cypher_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
**kwargs: Any,
) -> FalkorDBQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_cypher})
# Retrieve and limit the number of results
context = self.graph.query(generated_cypher)[: self.top_k]
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

View File

@@ -1,6 +1,7 @@
"""**Graphs** provide a natural language interface to graph databases."""
from langchain.graphs.arangodb_graph import ArangoGraph
from langchain.graphs.falkordb_graph import FalkorDBGraph
from langchain.graphs.hugegraph import HugeGraph
from langchain.graphs.kuzu_graph import KuzuGraph
from langchain.graphs.memgraph_graph import MemgraphGraph
@@ -20,4 +21,5 @@ __all__ = [
"HugeGraph",
"RdfGraph",
"ArangoGraph",
"FalkorDBGraph",
]

View File

@@ -0,0 +1,67 @@
from typing import Any, Dict, List
node_properties_query = """
MATCH (n)
UNWIND labels(n) as l
UNWIND keys(n) as p
RETURN {label:l, properties: collect(distinct p)} AS output
"""
rel_properties_query = """
MATCH ()-[r]->()
UNWIND keys(r) as p
RETURN {type:type(r), properties: collect(distinct p)} AS output
"""
rel_query = """
MATCH (n)-[r]->(m)
WITH labels(n)[0] AS src, labels(m)[0] AS dst, type(r) AS type
RETURN DISTINCT "(:" + src + ")-[:" + type + "]->(:" + dst + ")" AS output
"""
class FalkorDBGraph:
"""FalkorDB wrapper for graph operations."""
def __init__(
self, database: str, host: str = "localhost", port: int = 6379
) -> None:
"""Create a new FalkorDB graph wrapper instance."""
try:
import redis
from redis.commands.graph import Graph
except ImportError:
raise ImportError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
self._driver = redis.Redis(host=host, port=port)
self._graph = Graph(self._driver, database)
try:
self.refresh_schema()
except Exception as e:
raise ValueError(f"Could not refresh schema. Error: {e}")
@property
def get_schema(self) -> str:
"""Returns the schema of the FalkorDB database"""
return self.schema
def refresh_schema(self) -> None:
"""Refreshes the schema of the FalkorDB database"""
self.schema = (
f"Node properties: {node_properties_query}\n"
f"Relationships properties: {rel_properties_query}\n"
f"Relationships: {rel_query}\n"
)
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
"""Query FalkorDB database."""
try:
data = self._graph.query(query, params)
return data.result_set
except Exception as e:
raise ValueError("Generated Cypher Statement is not valid\n" f"{e}")

View File

@@ -0,0 +1,34 @@
import unittest
from typing import Any
from unittest.mock import MagicMock, patch
from langchain.graphs import FalkorDBGraph
class TestFalkorDB(unittest.TestCase):
def setUp(self) -> None:
self.host = "localhost"
self.graph = "test_falkordb"
self.port = 6379
@patch("redis.Redis")
def test_init(self, mock_client: Any) -> None:
mock_client.return_value = MagicMock()
FalkorDBGraph(database=self.graph, host=self.host, port=self.port)
@patch("redis.Redis")
def test_execute(self, mock_client: Any) -> None:
mock_client.return_value = MagicMock()
graph = FalkorDBGraph(database=self.graph, host=self.host, port=self.port)
query = "RETURN 1"
result = graph.query(query)
self.assertIsInstance(result, MagicMock)
@patch("redis.Redis")
def test_refresh_schema(self, mock_client: Any) -> None:
mock_client.return_value = MagicMock()
graph = FalkorDBGraph(database=self.graph, host=self.host, port=self.port)
graph.refresh_schema()
self.assertNotEqual(graph.get_schema, "")