mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 20:28:10 +00:00
make prompt a variable in vector db qa (#170)
This commit is contained in:
parent
4a4dfbfbed
commit
22bd12a097
@ -5,8 +5,9 @@ from pydantic import BaseModel, Extra
|
|||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.vector_db_qa.prompt import prompt
|
from langchain.chains.vector_db_qa.prompt import PROMPT
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
|
|
||||||
@ -29,6 +30,8 @@ class VectorDBQA(Chain, BaseModel):
|
|||||||
"""Vector Database to connect to."""
|
"""Vector Database to connect to."""
|
||||||
k: int = 4
|
k: int = 4
|
||||||
"""Number of documents to query for."""
|
"""Number of documents to query for."""
|
||||||
|
prompt: PromptTemplate = PROMPT
|
||||||
|
"""Prompt to use when questioning the documents."""
|
||||||
input_key: str = "query" #: :meta private:
|
input_key: str = "query" #: :meta private:
|
||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
|
|
||||||
@ -56,7 +59,7 @@ class VectorDBQA(Chain, BaseModel):
|
|||||||
|
|
||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
question = inputs[self.input_key]
|
question = inputs[self.input_key]
|
||||||
llm_chain = LLMChain(llm=self.llm, prompt=prompt)
|
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||||
docs = self.vectorstore.similarity_search(question, k=self.k)
|
docs = self.vectorstore.similarity_search(question, k=self.k)
|
||||||
contexts = []
|
contexts = []
|
||||||
for j, doc in enumerate(docs):
|
for j, doc in enumerate(docs):
|
||||||
|
@ -7,6 +7,6 @@ prompt_template = """Use the following pieces of context to answer the question
|
|||||||
|
|
||||||
Question: {question}
|
Question: {question}
|
||||||
Helpful Answer:"""
|
Helpful Answer:"""
|
||||||
prompt = PromptTemplate(
|
PROMPT = PromptTemplate(
|
||||||
template=prompt_template, input_variables=["context", "question"]
|
template=prompt_template, input_variables=["context", "question"]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user