mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
update multi index templates (#13569)
This commit is contained in:
parent
f4c0e3cc15
commit
790ed8be69
@ -1,3 +1,5 @@
|
|||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
@ -11,7 +13,6 @@ from langchain.retrievers import (
|
|||||||
)
|
)
|
||||||
from langchain.schema import StrOutputParser
|
from langchain.schema import StrOutputParser
|
||||||
from langchain.schema.runnable import (
|
from langchain.schema.runnable import (
|
||||||
RunnableLambda,
|
|
||||||
RunnableParallel,
|
RunnableParallel,
|
||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
)
|
)
|
||||||
@ -51,14 +52,6 @@ def fuse_retrieved_docs(input):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
retriever_map = {
|
|
||||||
"medical paper": pubmed,
|
|
||||||
"scientific paper": arxiv,
|
|
||||||
"public company finances report": sec,
|
|
||||||
"general": wiki,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def format_named_docs(named_docs):
|
def format_named_docs(named_docs):
|
||||||
return "\n\n".join(
|
return "\n\n".join(
|
||||||
f"Source: {source}\n\n{doc.page_content}" for source, doc in named_docs
|
f"Source: {source}\n\n{doc.page_content}" for source, doc in named_docs
|
||||||
@ -83,19 +76,26 @@ class Question(BaseModel):
|
|||||||
__root__: str
|
__root__: str
|
||||||
|
|
||||||
|
|
||||||
|
answer_chain = (
|
||||||
|
{
|
||||||
|
"question": itemgetter("question"),
|
||||||
|
"sources": lambda x: format_named_docs(x["sources"]),
|
||||||
|
}
|
||||||
|
| prompt
|
||||||
|
| ChatOpenAI(model="gpt-3.5-turbo-1106")
|
||||||
|
| StrOutputParser()
|
||||||
|
).with_config(run_name="answer")
|
||||||
chain = (
|
chain = (
|
||||||
(
|
(
|
||||||
RunnableParallel(
|
RunnableParallel(
|
||||||
{"question": RunnablePassthrough(), "sources": retrieve_all}
|
{"question": RunnablePassthrough(), "sources": retrieve_all}
|
||||||
).with_config(run_name="add_sources")
|
).with_config(run_name="add_sources")
|
||||||
| RunnablePassthrough.assign(
|
| RunnablePassthrough.assign(sources=fuse_retrieved_docs).with_config(
|
||||||
sources=(
|
run_name="fuse"
|
||||||
RunnableLambda(fuse_retrieved_docs) | format_named_docs
|
)
|
||||||
).with_config(run_name="fuse_and_format")
|
| RunnablePassthrough.assign(answer=answer_chain).with_config(
|
||||||
).with_config(run_name="update_sources")
|
run_name="add_answer"
|
||||||
| prompt
|
)
|
||||||
| ChatOpenAI(model="gpt-3.5-turbo-1106")
|
|
||||||
| StrOutputParser()
|
|
||||||
)
|
)
|
||||||
.with_config(run_name="QA with fused results")
|
.with_config(run_name="QA with fused results")
|
||||||
.with_types(input_type=Question)
|
.with_types(input_type=Question)
|
||||||
|
@ -28,7 +28,7 @@ wiki = WikipediaRetriever(top_k_results=5, doc_content_chars_max=2000).with_conf
|
|||||||
run_name="wiki"
|
run_name="wiki"
|
||||||
)
|
)
|
||||||
|
|
||||||
llm = ChatOpenAI(model="gpt-3.5-turbo-1106")
|
llm = ChatOpenAI(model="gpt-3.5-turbo")
|
||||||
|
|
||||||
|
|
||||||
class Search(BaseModel):
|
class Search(BaseModel):
|
||||||
@ -45,18 +45,29 @@ class Search(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
classifier = llm.bind(
|
retriever_name = {
|
||||||
functions=[convert_pydantic_to_openai_function(Search)],
|
"medical paper": "PubMed",
|
||||||
function_call={"name": "Search"},
|
"scientific paper": "ArXiv",
|
||||||
) | PydanticAttrOutputFunctionsParser(
|
"public company finances report": "SEC filings (Kay AI)",
|
||||||
pydantic_schema=Search, attr_name="question_resource"
|
"general": "Wikipedia",
|
||||||
|
}
|
||||||
|
|
||||||
|
classifier = (
|
||||||
|
llm.bind(
|
||||||
|
functions=[convert_pydantic_to_openai_function(Search)],
|
||||||
|
function_call={"name": "Search"},
|
||||||
|
)
|
||||||
|
| PydanticAttrOutputFunctionsParser(
|
||||||
|
pydantic_schema=Search, attr_name="question_resource"
|
||||||
|
)
|
||||||
|
| retriever_name.get
|
||||||
)
|
)
|
||||||
|
|
||||||
retriever_map = {
|
retriever_map = {
|
||||||
"medical paper": pubmed,
|
"PubMed": pubmed,
|
||||||
"scientific paper": arxiv,
|
"ArXiv": arxiv,
|
||||||
"public company finances report": sec,
|
"SEC filings (Kay AI)": sec,
|
||||||
"general": wiki,
|
"Wikipedia": wiki,
|
||||||
}
|
}
|
||||||
router_retriever = RouterRunnable(runnables=retriever_map)
|
router_retriever = RouterRunnable(runnables=retriever_map)
|
||||||
|
|
||||||
@ -79,17 +90,23 @@ class Question(BaseModel):
|
|||||||
__root__: str
|
__root__: str
|
||||||
|
|
||||||
|
|
||||||
|
retriever_chain = (
|
||||||
|
{"input": itemgetter("question"), "key": itemgetter("retriever_choice")}
|
||||||
|
| router_retriever
|
||||||
|
| format_docs
|
||||||
|
).with_config(run_name="retrieve")
|
||||||
|
answer_chain = (
|
||||||
|
{"sources": retriever_chain, "question": itemgetter("question")}
|
||||||
|
| prompt
|
||||||
|
| llm
|
||||||
|
| StrOutputParser()
|
||||||
|
)
|
||||||
chain = (
|
chain = (
|
||||||
(
|
(
|
||||||
RunnableParallel(
|
RunnableParallel(
|
||||||
{"input": RunnablePassthrough(), "key": classifier}
|
question=RunnablePassthrough(), retriever_choice=classifier
|
||||||
).with_config(run_name="classify")
|
).with_config(run_name="classify")
|
||||||
| RunnableParallel(
|
| RunnablePassthrough.assign(answer=answer_chain).with_config(run_name="answer")
|
||||||
{"question": itemgetter("input"), "sources": router_retriever | format_docs}
|
|
||||||
).with_config(run_name="retrieve")
|
|
||||||
| prompt
|
|
||||||
| llm
|
|
||||||
| StrOutputParser()
|
|
||||||
)
|
)
|
||||||
.with_config(run_name="QA with router")
|
.with_config(run_name="QA with router")
|
||||||
.with_types(input_type=Question)
|
.with_types(input_type=Question)
|
||||||
|
Loading…
Reference in New Issue
Block a user