diff --git a/langchain/chains/graph_qa/base.py b/langchain/chains/graph_qa/base.py index 44dceca8ed4..b194f2e146d 100644 --- a/langchain/chains/graph_qa/base.py +++ b/langchain/chains/graph_qa/base.py @@ -8,7 +8,7 @@ from pydantic import Field from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT +from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, GRAPH_QA_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities from langchain.schema import BasePromptTemplate @@ -44,7 +44,7 @@ class GraphQAChain(Chain): def from_llm( cls, llm: BaseLanguageModel, - qa_prompt: BasePromptTemplate = PROMPT, + qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT, entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT, **kwargs: Any, ) -> GraphQAChain: diff --git a/langchain/chains/graph_qa/prompts.py b/langchain/chains/graph_qa/prompts.py index ca68983b083..a0898e7aa66 100644 --- a/langchain/chains/graph_qa/prompts.py +++ b/langchain/chains/graph_qa/prompts.py @@ -23,14 +23,14 @@ ENTITY_EXTRACTION_PROMPT = PromptTemplate( input_variables=["input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE ) -prompt_template = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. +_DEFAULT_GRAPH_QA_TEMPLATE = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. {context} Question: {question} Helpful Answer:""" -PROMPT = PromptTemplate( - template=prompt_template, input_variables=["context", "question"] +GRAPH_QA_PROMPT = PromptTemplate( + template=_DEFAULT_GRAPH_QA_TEMPLATE, input_variables=["context", "question"] ) CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.