update multi index templates (#13569)

This commit is contained in:
Bagatur
2023-11-18 14:42:22 -08:00
committed by GitHub
parent f4c0e3cc15
commit 790ed8be69
2 changed files with 51 additions and 34 deletions

View File

@@ -28,7 +28,7 @@ wiki = WikipediaRetriever(top_k_results=5, doc_content_chars_max=2000).with_conf
run_name="wiki"
)
llm = ChatOpenAI(model="gpt-3.5-turbo-1106")
llm = ChatOpenAI(model="gpt-3.5-turbo")
class Search(BaseModel):
@@ -45,18 +45,29 @@ class Search(BaseModel):
)
classifier = llm.bind(
functions=[convert_pydantic_to_openai_function(Search)],
function_call={"name": "Search"},
) | PydanticAttrOutputFunctionsParser(
pydantic_schema=Search, attr_name="question_resource"
retriever_name = {
"medical paper": "PubMed",
"scientific paper": "ArXiv",
"public company finances report": "SEC filings (Kay AI)",
"general": "Wikipedia",
}
classifier = (
llm.bind(
functions=[convert_pydantic_to_openai_function(Search)],
function_call={"name": "Search"},
)
| PydanticAttrOutputFunctionsParser(
pydantic_schema=Search, attr_name="question_resource"
)
| retriever_name.get
)
retriever_map = {
"medical paper": pubmed,
"scientific paper": arxiv,
"public company finances report": sec,
"general": wiki,
"PubMed": pubmed,
"ArXiv": arxiv,
"SEC filings (Kay AI)": sec,
"Wikipedia": wiki,
}
router_retriever = RouterRunnable(runnables=retriever_map)
@@ -79,17 +90,23 @@ class Question(BaseModel):
__root__: str
retriever_chain = (
{"input": itemgetter("question"), "key": itemgetter("retriever_choice")}
| router_retriever
| format_docs
).with_config(run_name="retrieve")
answer_chain = (
{"sources": retriever_chain, "question": itemgetter("question")}
| prompt
| llm
| StrOutputParser()
)
chain = (
(
RunnableParallel(
{"input": RunnablePassthrough(), "key": classifier}
question=RunnablePassthrough(), retriever_choice=classifier
).with_config(run_name="classify")
| RunnableParallel(
{"question": itemgetter("input"), "sources": router_retriever | format_docs}
).with_config(run_name="retrieve")
| prompt
| llm
| StrOutputParser()
| RunnablePassthrough.assign(answer=answer_chain).with_config(run_name="answer")
)
.with_config(run_name="QA with router")
.with_types(input_type=Question)