mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 15:04:13 +00:00
Add support for Falkordb (ex-RedisGraph) (#9821)
Replace this entire comment with: - Description: Add support for Falkordb (ex-RedisGraph) - Tag maintainer: @hwchase17 - Twitter handle: @g_korland
This commit is contained in:
@@ -36,6 +36,7 @@ from langchain.chains.flare.base import FlareChain
|
||||
from langchain.chains.graph_qa.arangodb import ArangoGraphQAChain
|
||||
from langchain.chains.graph_qa.base import GraphQAChain
|
||||
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
|
||||
from langchain.chains.graph_qa.falkordb import FalkorDBQAChain
|
||||
from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain
|
||||
from langchain.chains.graph_qa.kuzu import KuzuQAChain
|
||||
from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain
|
||||
@@ -85,6 +86,7 @@ __all__ = [
|
||||
"ConstitutionalChain",
|
||||
"ConversationChain",
|
||||
"ConversationalRetrievalChain",
|
||||
"FalkorDBQAChain",
|
||||
"FlareChain",
|
||||
"GraphCypherQAChain",
|
||||
"GraphQAChain",
|
||||
|
141
libs/langchain/langchain/chains/graph_qa/falkordb.py
Normal file
141
libs/langchain/langchain/chains/graph_qa/falkordb.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Question answering over a graph."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs import FalkorDBGraph
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema import BasePromptTemplate
|
||||
|
||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
||||
def extract_cypher(text: str) -> str:
|
||||
"""
|
||||
Extract Cypher code from a text.
|
||||
Args:
|
||||
text: Text to extract Cypher code from.
|
||||
|
||||
Returns:
|
||||
Cypher code extracted from the text.
|
||||
"""
|
||||
# The pattern to find Cypher code enclosed in triple backticks
|
||||
pattern = r"```(.*?)```"
|
||||
|
||||
# Find all matches in the input text
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
return matches[0] if matches else text
|
||||
|
||||
|
||||
class FalkorDBQAChain(Chain):
|
||||
"""Chain for question-answering against a graph by generating Cypher statements."""
|
||||
|
||||
graph: FalkorDBGraph = Field(exclude=True)
|
||||
cypher_generation_chain: LLMChain
|
||||
qa_chain: LLMChain
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
top_k: int = 10
|
||||
"""Number of results to return from the query"""
|
||||
return_intermediate_steps: bool = False
|
||||
"""Whether or not to return the intermediate steps along with the final answer."""
|
||||
return_direct: bool = False
|
||||
"""Whether or not to return the result of querying the graph directly."""
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
_output_keys = [self.output_key]
|
||||
return _output_keys
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "graph_cypher_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
*,
|
||||
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
|
||||
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> FalkorDBQAChain:
|
||||
"""Initialize from LLM."""
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
|
||||
|
||||
return cls(
|
||||
qa_chain=qa_chain,
|
||||
cypher_generation_chain=cypher_generation_chain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate Cypher statement, use it to look up in db and answer question."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
intermediate_steps: List = []
|
||||
|
||||
generated_cypher = self.cypher_generation_chain.run(
|
||||
{"question": question, "schema": self.graph.schema}, callbacks=callbacks
|
||||
)
|
||||
|
||||
# Extract Cypher code if it is wrapped in backticks
|
||||
generated_cypher = extract_cypher(generated_cypher)
|
||||
|
||||
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
generated_cypher, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
intermediate_steps.append({"query": generated_cypher})
|
||||
|
||||
# Retrieve and limit the number of results
|
||||
context = self.graph.query(generated_cypher)[: self.top_k]
|
||||
|
||||
if self.return_direct:
|
||||
final_result = context
|
||||
else:
|
||||
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
str(context), color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
|
||||
intermediate_steps.append({"context": context})
|
||||
|
||||
result = self.qa_chain(
|
||||
{"question": question, "context": context},
|
||||
callbacks=callbacks,
|
||||
)
|
||||
final_result = result[self.qa_chain.output_key]
|
||||
|
||||
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
||||
if self.return_intermediate_steps:
|
||||
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
|
||||
|
||||
return chain_result
|
@@ -1,6 +1,7 @@
|
||||
"""**Graphs** provide a natural language interface to graph databases."""
|
||||
|
||||
from langchain.graphs.arangodb_graph import ArangoGraph
|
||||
from langchain.graphs.falkordb_graph import FalkorDBGraph
|
||||
from langchain.graphs.hugegraph import HugeGraph
|
||||
from langchain.graphs.kuzu_graph import KuzuGraph
|
||||
from langchain.graphs.memgraph_graph import MemgraphGraph
|
||||
@@ -20,4 +21,5 @@ __all__ = [
|
||||
"HugeGraph",
|
||||
"RdfGraph",
|
||||
"ArangoGraph",
|
||||
"FalkorDBGraph",
|
||||
]
|
||||
|
67
libs/langchain/langchain/graphs/falkordb_graph.py
Normal file
67
libs/langchain/langchain/graphs/falkordb_graph.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
node_properties_query = """
|
||||
MATCH (n)
|
||||
UNWIND labels(n) as l
|
||||
UNWIND keys(n) as p
|
||||
RETURN {label:l, properties: collect(distinct p)} AS output
|
||||
"""
|
||||
|
||||
rel_properties_query = """
|
||||
MATCH ()-[r]->()
|
||||
UNWIND keys(r) as p
|
||||
RETURN {type:type(r), properties: collect(distinct p)} AS output
|
||||
"""
|
||||
|
||||
rel_query = """
|
||||
MATCH (n)-[r]->(m)
|
||||
WITH labels(n)[0] AS src, labels(m)[0] AS dst, type(r) AS type
|
||||
RETURN DISTINCT "(:" + src + ")-[:" + type + "]->(:" + dst + ")" AS output
|
||||
"""
|
||||
|
||||
|
||||
class FalkorDBGraph:
|
||||
"""FalkorDB wrapper for graph operations."""
|
||||
|
||||
def __init__(
|
||||
self, database: str, host: str = "localhost", port: int = 6379
|
||||
) -> None:
|
||||
"""Create a new FalkorDB graph wrapper instance."""
|
||||
try:
|
||||
import redis
|
||||
from redis.commands.graph import Graph
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
|
||||
self._driver = redis.Redis(host=host, port=port)
|
||||
self._graph = Graph(self._driver, database)
|
||||
|
||||
try:
|
||||
self.refresh_schema()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not refresh schema. Error: {e}")
|
||||
|
||||
@property
|
||||
def get_schema(self) -> str:
|
||||
"""Returns the schema of the FalkorDB database"""
|
||||
return self.schema
|
||||
|
||||
def refresh_schema(self) -> None:
|
||||
"""Refreshes the schema of the FalkorDB database"""
|
||||
self.schema = (
|
||||
f"Node properties: {node_properties_query}\n"
|
||||
f"Relationships properties: {rel_properties_query}\n"
|
||||
f"Relationships: {rel_query}\n"
|
||||
)
|
||||
|
||||
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
|
||||
"""Query FalkorDB database."""
|
||||
|
||||
try:
|
||||
data = self._graph.query(query, params)
|
||||
return data.result_set
|
||||
except Exception as e:
|
||||
raise ValueError("Generated Cypher Statement is not valid\n" f"{e}")
|
@@ -0,0 +1,34 @@
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain.graphs import FalkorDBGraph
|
||||
|
||||
|
||||
class TestFalkorDB(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.host = "localhost"
|
||||
self.graph = "test_falkordb"
|
||||
self.port = 6379
|
||||
|
||||
@patch("redis.Redis")
|
||||
def test_init(self, mock_client: Any) -> None:
|
||||
mock_client.return_value = MagicMock()
|
||||
FalkorDBGraph(database=self.graph, host=self.host, port=self.port)
|
||||
|
||||
@patch("redis.Redis")
|
||||
def test_execute(self, mock_client: Any) -> None:
|
||||
mock_client.return_value = MagicMock()
|
||||
graph = FalkorDBGraph(database=self.graph, host=self.host, port=self.port)
|
||||
|
||||
query = "RETURN 1"
|
||||
result = graph.query(query)
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
|
||||
@patch("redis.Redis")
|
||||
def test_refresh_schema(self, mock_client: Any) -> None:
|
||||
mock_client.return_value = MagicMock()
|
||||
graph = FalkorDBGraph(database=self.graph, host=self.host, port=self.port)
|
||||
|
||||
graph.refresh_schema()
|
||||
self.assertNotEqual(graph.get_schema, "")
|
Reference in New Issue
Block a user