From 22bd12a097155df467a046afd4ca3006955fa042 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Mon, 21 Nov 2022 19:30:40 -0800
Subject: [PATCH] make prompt a variable in vector db qa (#170)

---
 langchain/chains/vector_db_qa/base.py   | 7 +++++--
 langchain/chains/vector_db_qa/prompt.py | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py
index d54de11ca22..e0e158c75f4 100644
--- a/langchain/chains/vector_db_qa/base.py
+++ b/langchain/chains/vector_db_qa/base.py
@@ -5,8 +5,9 @@ from pydantic import BaseModel, Extra
 
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
-from langchain.chains.vector_db_qa.prompt import prompt
+from langchain.chains.vector_db_qa.prompt import PROMPT
 from langchain.llms.base import LLM
+from langchain.prompts import PromptTemplate
 from langchain.vectorstores.base import VectorStore
 
 
@@ -29,6 +30,8 @@ class VectorDBQA(Chain, BaseModel):
     """Vector Database to connect to."""
     k: int = 4
     """Number of documents to query for."""
+    prompt: PromptTemplate = PROMPT
+    """Prompt to use when questioning the documents."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
 
@@ -56,7 +59,7 @@ class VectorDBQA(Chain, BaseModel):
 
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         question = inputs[self.input_key]
-        llm_chain = LLMChain(llm=self.llm, prompt=prompt)
+        llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
         docs = self.vectorstore.similarity_search(question, k=self.k)
         contexts = []
         for j, doc in enumerate(docs):
diff --git a/langchain/chains/vector_db_qa/prompt.py b/langchain/chains/vector_db_qa/prompt.py
index 734223a3615..9ebb89eac92 100644
--- a/langchain/chains/vector_db_qa/prompt.py
+++ b/langchain/chains/vector_db_qa/prompt.py
@@ -7,6 +7,6 @@ prompt_template = """Use the following pieces of context to answer the question
 
 Question: {question}
 Helpful Answer:"""
-prompt = PromptTemplate(
+PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
 )
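
After this patch, the question-answering prompt can be overridden per chain instance instead of being hard-coded at module level; renaming the module-level `prompt` to `PROMPT` keeps a default while avoiding a name clash with the new field. Below is a minimal usage sketch (not part of the patch). It assumes the OpenAI LLM, OpenAIEmbeddings, and FAISS wrappers shipped in langchain at the time, and the dict-style chain invocation implied by `input_key`/`output_key`; the sample text and custom template wording are illustrative only.

# Usage sketch (assumption: OpenAI/FAISS wrappers and dict-style Chain call
# as available in langchain at this point; sample data is made up).
from langchain.chains.vector_db_qa.base import VectorDBQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS

# A custom prompt with the same input variables the chain fills in.
custom_prompt = PromptTemplate(
    template=(
        "Answer using only the context below. If the answer is not "
        "contained in the context, say you don't know.\n\n"
        "{context}\n\nQuestion: {question}\nAnswer:"
    ),
    input_variables=["context", "question"],
)

# Build a small vector store from example documents.
vectorstore = FAISS.from_texts(
    ["LangChain is a library for composing LLM applications."],
    OpenAIEmbeddings(),
)

# Before this patch the prompt was fixed inside _call(); now it is a pydantic
# field that defaults to PROMPT and can be overridden per instance.
qa = VectorDBQA(
    llm=OpenAI(temperature=0),
    vectorstore=vectorstore,
    prompt=custom_prompt,
    k=2,
)
result = qa({"query": "What is LangChain?"})
print(result["result"])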