mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
Add option to save/load graph cypher QA (#6219)
Similar as https://github.com/hwchase17/langchain/pull/5818 Added the functionality to save/load Graph Cypher QA Chain due to a user reporting the following error > raise NotImplementedError("Saving not supported for this chain type.")\nNotImplementedError: Saving not supported for this chain type.\n'
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
import os
|
||||
|
||||
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
|
||||
from langchain.chains.loading import load_chain
|
||||
from langchain.graphs import Neo4jGraph
|
||||
from langchain.llms.openai import OpenAI
|
||||
|
||||
@@ -168,3 +169,31 @@ def test_cypher_return_direct() -> None:
|
||||
output = chain.run("Who played in Pulp Fiction?")
|
||||
expected_output = [{"a.name": "Bruce Willis"}]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_cypher_save_load() -> None:
|
||||
"""Test saving and loading."""
|
||||
|
||||
FILE_PATH = "cypher.yaml"
|
||||
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, return_direct=True
|
||||
)
|
||||
|
||||
chain.save(file_path=FILE_PATH)
|
||||
qa_loaded = load_chain(FILE_PATH, graph=graph)
|
||||
|
||||
assert qa_loaded == chain
|
||||
|
Reference in New Issue
Block a user