mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
Expose configuration options in GraphCypherQAChain (#12159)
Allows for passing arguments into the LLM chains used by the GraphCypherQAChain. This is to address a request by a user to include memory in the Cypher creating chain. Will keep the prompt variables as-is to be backward compatible. But, would be a good idea to deprecate them and use the **kwargs variables. Added a test case. In general, I think it would be good for any chain to automatically pass in a readonlymemory(of its input) to its subchains whilist allowing for an override. But, this would be a different change.
This commit is contained in:
parent
11f13aed53
commit
f09f82541b
@ -132,13 +132,15 @@ class GraphCypherQAChain(Chain):
|
|||||||
cls,
|
cls,
|
||||||
llm: Optional[BaseLanguageModel] = None,
|
llm: Optional[BaseLanguageModel] = None,
|
||||||
*,
|
*,
|
||||||
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
|
qa_prompt: Optional[BasePromptTemplate] = None,
|
||||||
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
|
cypher_prompt: Optional[BasePromptTemplate] = None,
|
||||||
cypher_llm: Optional[BaseLanguageModel] = None,
|
cypher_llm: Optional[BaseLanguageModel] = None,
|
||||||
qa_llm: Optional[BaseLanguageModel] = None,
|
qa_llm: Optional[BaseLanguageModel] = None,
|
||||||
exclude_types: List[str] = [],
|
exclude_types: List[str] = [],
|
||||||
include_types: List[str] = [],
|
include_types: List[str] = [],
|
||||||
validate_cypher: bool = False,
|
validate_cypher: bool = False,
|
||||||
|
qa_llm_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> GraphCypherQAChain:
|
) -> GraphCypherQAChain:
|
||||||
"""Initialize from LLM."""
|
"""Initialize from LLM."""
|
||||||
@ -152,9 +154,34 @@ class GraphCypherQAChain(Chain):
|
|||||||
"You can specify up to two of 'cypher_llm', 'qa_llm'"
|
"You can specify up to two of 'cypher_llm', 'qa_llm'"
|
||||||
", and 'llm', but not all three simultaneously."
|
", and 'llm', but not all three simultaneously."
|
||||||
)
|
)
|
||||||
|
if cypher_prompt and cypher_llm_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
"Specifying cypher_prompt and cypher_llm_kwargs together is"
|
||||||
|
" not allowed. Please pass prompt via cypher_llm_kwargs."
|
||||||
|
)
|
||||||
|
if qa_prompt and qa_llm_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
"Specifying qa_prompt and qa_llm_kwargs together is"
|
||||||
|
" not allowed. Please pass prompt via qa_llm_kwargs."
|
||||||
|
)
|
||||||
|
use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
|
||||||
|
use_cypher_llm_kwargs = (
|
||||||
|
cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
|
||||||
|
)
|
||||||
|
if "prompt" not in use_qa_llm_kwargs:
|
||||||
|
use_qa_llm_kwargs["prompt"] = (
|
||||||
|
qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
|
||||||
|
)
|
||||||
|
if "prompt" not in use_cypher_llm_kwargs:
|
||||||
|
use_cypher_llm_kwargs["prompt"] = (
|
||||||
|
cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
|
||||||
|
)
|
||||||
|
|
||||||
qa_chain = LLMChain(llm=qa_llm or llm, prompt=qa_prompt)
|
qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)
|
||||||
cypher_generation_chain = LLMChain(llm=cypher_llm or llm, prompt=cypher_prompt)
|
|
||||||
|
cypher_generation_chain = LLMChain(
|
||||||
|
llm=cypher_llm or llm, **use_cypher_llm_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if exclude_types and include_types:
|
if exclude_types and include_types:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -1,9 +1,186 @@
|
|||||||
from typing import List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from langchain.chains.graph_qa.cypher import construct_schema, extract_cypher
|
from langchain.chains.graph_qa.cypher import (
|
||||||
|
GraphCypherQAChain,
|
||||||
|
construct_schema,
|
||||||
|
extract_cypher,
|
||||||
|
)
|
||||||
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
|
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
|
||||||
|
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
|
||||||
|
from langchain.graphs.graph_document import GraphDocument
|
||||||
|
from langchain.graphs.graph_store import GraphStore
|
||||||
|
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
class FakeGraphStore(GraphStore):
|
||||||
|
@property
|
||||||
|
def get_schema(self) -> str:
|
||||||
|
"""Returns the schema of the Graph database"""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_structured_schema(self) -> Dict[str, Any]:
|
||||||
|
"""Returns the schema of the Graph database"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
|
||||||
|
"""Query the graph."""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def refresh_schema(self) -> None:
|
||||||
|
"""Refreshes the graph schema information."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def add_graph_documents(
|
||||||
|
self, graph_documents: List[GraphDocument], include_source: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""Take GraphDocument as input as uses it to construct a graph."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_cypher_qa_chain_prompt_selection_1() -> None:
|
||||||
|
# Pass prompts directly. No kwargs is specified.
|
||||||
|
qa_prompt_template = "QA Prompt"
|
||||||
|
cypher_prompt_template = "Cypher Prompt"
|
||||||
|
qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
|
||||||
|
cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
|
||||||
|
chain = GraphCypherQAChain.from_llm(
|
||||||
|
llm=FakeLLM(),
|
||||||
|
graph=FakeGraphStore(),
|
||||||
|
verbose=True,
|
||||||
|
return_intermediate_steps=False,
|
||||||
|
qa_prompt=qa_prompt,
|
||||||
|
cypher_prompt=cypher_prompt,
|
||||||
|
)
|
||||||
|
assert chain.qa_chain.prompt == qa_prompt
|
||||||
|
assert chain.cypher_generation_chain.prompt == cypher_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_cypher_qa_chain_prompt_selection_2() -> None:
|
||||||
|
# Default case. Pass nothing
|
||||||
|
chain = GraphCypherQAChain.from_llm(
|
||||||
|
llm=FakeLLM(),
|
||||||
|
graph=FakeGraphStore(),
|
||||||
|
verbose=True,
|
||||||
|
return_intermediate_steps=False,
|
||||||
|
)
|
||||||
|
assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
|
||||||
|
assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_cypher_qa_chain_prompt_selection_3() -> None:
|
||||||
|
# Pass non-prompt args only to sub-chains via kwargs
|
||||||
|
memory = ConversationBufferMemory(memory_key="chat_history")
|
||||||
|
readonlymemory = ReadOnlySharedMemory(memory=memory)
|
||||||
|
chain = GraphCypherQAChain.from_llm(
|
||||||
|
llm=FakeLLM(),
|
||||||
|
graph=FakeGraphStore(),
|
||||||
|
verbose=True,
|
||||||
|
return_intermediate_steps=False,
|
||||||
|
cypher_llm_kwargs={"memory": readonlymemory},
|
||||||
|
qa_llm_kwargs={"memory": readonlymemory},
|
||||||
|
)
|
||||||
|
assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
|
||||||
|
assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_cypher_qa_chain_prompt_selection_4() -> None:
|
||||||
|
# Pass prompt, non-prompt args to subchains via kwargs
|
||||||
|
qa_prompt_template = "QA Prompt"
|
||||||
|
cypher_prompt_template = "Cypher Prompt"
|
||||||
|
memory = ConversationBufferMemory(memory_key="chat_history")
|
||||||
|
readonlymemory = ReadOnlySharedMemory(memory=memory)
|
||||||
|
qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
|
||||||
|
cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
|
||||||
|
chain = GraphCypherQAChain.from_llm(
|
||||||
|
llm=FakeLLM(),
|
||||||
|
graph=FakeGraphStore(),
|
||||||
|
verbose=True,
|
||||||
|
return_intermediate_steps=False,
|
||||||
|
cypher_llm_kwargs={"prompt": cypher_prompt, "memory": readonlymemory},
|
||||||
|
qa_llm_kwargs={"prompt": qa_prompt, "memory": readonlymemory},
|
||||||
|
)
|
||||||
|
assert chain.qa_chain.prompt == qa_prompt
|
||||||
|
assert chain.cypher_generation_chain.prompt == cypher_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_cypher_qa_chain_prompt_selection_5() -> None:
|
||||||
|
# Can't pass both prompt and kwargs at the same time
|
||||||
|
qa_prompt_template = "QA Prompt"
|
||||||
|
cypher_prompt_template = "Cypher Prompt"
|
||||||
|
memory = ConversationBufferMemory(memory_key="chat_history")
|
||||||
|
readonlymemory = ReadOnlySharedMemory(memory=memory)
|
||||||
|
qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
|
||||||
|
cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
|
||||||
|
try:
|
||||||
|
GraphCypherQAChain.from_llm(
|
||||||
|
llm=FakeLLM(),
|
||||||
|
graph=FakeGraphStore(),
|
||||||
|
verbose=True,
|
||||||
|
return_intermediate_steps=False,
|
||||||
|
qa_prompt=qa_prompt,
|
||||||
|
cypher_prompt=cypher_prompt,
|
||||||
|
cypher_llm_kwargs={"memory": readonlymemory},
|
||||||
|
qa_llm_kwargs={"memory": readonlymemory},
|
||||||
|
)
|
||||||
|
assert False
|
||||||
|
except ValueError:
|
||||||
|
assert True
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_cypher_qa_chain() -> None:
|
||||||
|
template = """You are a nice chatbot having a conversation with a human.
|
||||||
|
|
||||||
|
Schema:
|
||||||
|
{schema}
|
||||||
|
|
||||||
|
Previous conversation:
|
||||||
|
{chat_history}
|
||||||
|
|
||||||
|
New human question: {question}
|
||||||
|
Response:"""
|
||||||
|
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["schema", "question", "chat_history"], template=template
|
||||||
|
)
|
||||||
|
|
||||||
|
memory = ConversationBufferMemory(memory_key="chat_history")
|
||||||
|
readonlymemory = ReadOnlySharedMemory(memory=memory)
|
||||||
|
prompt1 = (
|
||||||
|
"You are a nice chatbot having a conversation with a human.\n\n "
|
||||||
|
"Schema:\n Node properties are the following: \n {}\nRelationships "
|
||||||
|
"properties are the following: \n {}\nRelationships are: \n[]\n\n "
|
||||||
|
"Previous conversation:\n \n\n New human question: "
|
||||||
|
"Test question\n Response:"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt2 = (
|
||||||
|
"You are a nice chatbot having a conversation with a human.\n\n "
|
||||||
|
"Schema:\n Node properties are the following: \n {}\nRelationships "
|
||||||
|
"properties are the following: \n {}\nRelationships are: \n[]\n\n "
|
||||||
|
"Previous conversation:\n Human: Test question\nAI: foo\n\n "
|
||||||
|
"New human question: Test new question\n Response:"
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = FakeLLM(queries={prompt1: "answer1", prompt2: "answer2"})
|
||||||
|
chain = GraphCypherQAChain.from_llm(
|
||||||
|
cypher_llm=llm,
|
||||||
|
qa_llm=FakeLLM(),
|
||||||
|
graph=FakeGraphStore(),
|
||||||
|
verbose=True,
|
||||||
|
return_intermediate_steps=False,
|
||||||
|
cypher_llm_kwargs={"prompt": prompt, "memory": readonlymemory},
|
||||||
|
memory=memory,
|
||||||
|
)
|
||||||
|
chain.run("Test question")
|
||||||
|
chain.run("Test new question")
|
||||||
|
# If we get here without a key error, that means memory
|
||||||
|
# was used properly to create prompts.
|
||||||
|
assert True
|
||||||
|
|
||||||
|
|
||||||
def test_no_backticks() -> None:
|
def test_no_backticks() -> None:
|
||||||
|
@ -39,9 +39,11 @@ class FakeLLM(LLM):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
print(prompt)
|
||||||
if self.sequential_responses:
|
if self.sequential_responses:
|
||||||
return self._get_next_response_in_sequence
|
return self._get_next_response_in_sequence
|
||||||
|
print(repr(prompt))
|
||||||
|
print(self.queries)
|
||||||
if self.queries is not None:
|
if self.queries is not None:
|
||||||
return self.queries[prompt]
|
return self.queries[prompt]
|
||||||
if stop is None:
|
if stop is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user