Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
d9f27e2b60 wip 2023-11-17 12:25:18 -08:00

View File

@@ -27,6 +27,21 @@ sec = KayAiRetriever.create(
wiki = WikipediaRetriever(top_k_results=5, doc_content_chars_max=2000).with_config(
run_name="wiki"
)
router_retriever = RouterRunnable(
runnables={
"medical paper": pubmed,
"scientific paper": arxiv,
"public company finances report": sec,
"general": wiki,
}
)
retriever_name = {
"medical paper": "PubMed",
"scientific paper": "ArXiv",
"public company finances report": "SEC Filings (Kay AI)",
"general": "Wikipedia",
}
llm = ChatOpenAI(model="gpt-3.5-turbo-1106")
@@ -52,14 +67,6 @@ classifier = llm.bind(
pydantic_schema=Search, attr_name="question_resource"
)
retriever_map = {
"medical paper": pubmed,
"scientific paper": arxiv,
"public company finances report": sec,
"general": wiki,
}
router_retriever = RouterRunnable(runnables=retriever_map)
def format_docs(docs):
return "\n\n".join(f"Source {i}:\n{doc.page_content}" for i, doc in enumerate(docs))
@@ -79,17 +86,18 @@ class Question(BaseModel):
__root__: str
answer_chain = (
{"question": itemgetter("input"), "sources": router_retriever | format_docs}
| prompt
| llm
| StrOutputParser()
)
chain = (
(
RunnableParallel(
{"input": RunnablePassthrough(), "key": classifier}
).with_config(run_name="classify")
| RunnableParallel(
{"question": itemgetter("input"), "sources": router_retriever | format_docs}
).with_config(run_name="retrieve")
| prompt
| llm
| StrOutputParser()
| {"answer": answer_chain, "source": lambda x: retriever_name[x["key"]]}
)
.with_config(run_name="QA with router")
.with_types(input_type=Question)